python - sklearn中的LinearRegression方法中,fit_intercept参数到底是干什么用的?

标签 python scikit-learn linear-regression

<分区>

sklearn.linear_model.LinearRegression 方法中,有一个参数是fit_intercept = TRUEfit_intercept = FALSE。我想知道如果我们将它设置为 TRUE,它是否会向您的数据集添加一个全 1 的附加截距列?如果我已经有一个包含一列 1 的数据集,fit_intercept = FALSE 是否说明了这一点,还是强制它拟合零截距模型?

更新:似乎人们没有理解我的问题。问题是,如果我的预测变量数据集中已经有一列 1(1 代表截距)怎么办?那么,

  1. 如果我使用 fit_intercept = FALSE,它会删除 1 的列吗?

  2. 如果我使用 fit_intercept = TRUE,它会添加一个额外的 1 列吗?

最佳答案

fit_intercept=False 将 y 截距设置为 0。如果 fit_intercept=True,则 y 截距将由最佳拟合线确定。

from sklearn.linear_model import LinearRegression
from sklearn.datasets import make_regression
import numpy as np
import matplotlib.pyplot as plt

bias = 100

X = np.arange(1000).reshape(-1,1)
y_true = np.ravel(X.dot(0.3) + bias)
noise = np.random.normal(0, 60, 1000)
y = y_true + noise

lr_fi_true = LinearRegression(fit_intercept=True)
lr_fi_false = LinearRegression(fit_intercept=False)

lr_fi_true.fit(X, y)
lr_fi_false.fit(X, y)

print('Intercept when fit_intercept=True : {:.5f}'.format(lr_fi_true.intercept_))
print('Intercept when fit_intercept=False : {:.5f}'.format(lr_fi_false.intercept_))

lr_fi_true_yhat = np.dot(X, lr_fi_true.coef_) + lr_fi_true.intercept_
lr_fi_false_yhat = np.dot(X, lr_fi_false.coef_) + lr_fi_false.intercept_

plt.scatter(X, y, label='Actual points')
plt.plot(X, lr_fi_true_yhat, 'r--', label='fit_intercept=True')
plt.plot(X, lr_fi_false_yhat, 'r-', label='fit_intercept=False')
plt.legend()

plt.vlines(0, 0, y.max())
plt.hlines(bias, X.min(), X.max())
plt.hlines(0, X.min(), X.max())

plt.show()

这个例子打印:

Intercept when fit_intercept=True : 100.32210
Intercept when fit_intercept=False : 0.00000

fit_intercept 的作用在视觉上变得很清楚。当 fit_intercept=True 时,允许最佳拟合线“拟合”y 轴(在本例中接近 100)。当 fit_intercept=False 时,拦截被强制到原点 (0, 0)。

fit_intercept in sklearn


What happens if I include a column of ones or zeros and set fit_intercept to True or False?

下面显示了如何检查它的示例。

from sklearn.linear_model import LinearRegression
from sklearn.datasets import make_regression
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(1)
bias = 100

X = np.arange(1000).reshape(-1,1)
y_true = np.ravel(X.dot(0.3) + bias)
noise = np.random.normal(0, 60, 1000)
y = y_true + noise

# with column of ones
X_with_ones = np.hstack((np.ones((X.shape[0], 1)), X))

for b,data in ((True, X), (False, X), (True, X_with_ones), (False, X_with_ones)):
  lr = LinearRegression(fit_intercept=b)
  lr.fit(data, y)

  print(lr.intercept_, lr.coef_)

外卖:

# fit_intercept=True, no column of zeros or ones
104.156765787 [ 0.29634031]
# fit_intercept=False, no column of zeros or ones
0.0 [ 0.45265361]
# fit_intercept=True, column of zeros or ones
104.156765787 [ 0.          0.29634031]
# fit_intercept=False, column of zeros or ones
0.0 [ 104.15676579    0.29634031]

关于python - sklearn中的LinearRegression方法中,fit_intercept参数到底是干什么用的?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46779605/

相关文章:

python - Windows 上的 Pygame 安装 - 错误 : Unable to find vcvarsall. bat

python - 搜索可用于期望最大化(EM)算法的 python 源代码?

python - 使用 Python 的 scikit-learn 中随机森林算法的置信度与概率

tensorflow - 最小化前馈神经网络的tensorflow.js中的损失

python - 尝试拟合回归模型时出现 ValueError

python - 索引错误: list assignment index out of range in Python

python - 每行最小值,Python Pandas

python - 无法更新标签文本

python - 绘制 DecisionTreeClassifier 的多类 ROC 曲线

scikit-learn - 不了解错误消息(基本sklearn命令)