python - 如何从多项式拟合中提取方程?

标签 python scikit-learn regression curve-fitting

我的目标是将一些数据拟合到多项式函数中,并获得包含拟合参数值的实际方程。

我改编了this example根据我的数据,结果符合预期。

这是我的代码:

import numpy as np
import matplotlib.pyplot as plt

from sklearn.linear_model import Ridge
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline


x = np.array([0., 4., 9., 12., 16., 20., 24., 27.])
y = np.array([2.9,4.3,66.7,91.4,109.2,114.8,135.5,134.2])

x_plot = np.linspace(0, max(x), 100)
# create matrix versions of these arrays
X = x[:, np.newaxis]
X_plot = x_plot[:, np.newaxis]

plt.scatter(x, y, label="training points")

for degree in np.arange(3, 6, 1):
    model = make_pipeline(PolynomialFeatures(degree), Ridge())
    model.fit(X, y)
    y_plot = model.predict(X_plot)
    plt.plot(x_plot, y_plot, label="degree %d" % degree)

plt.legend(loc='lower left')

plt.show()

enter image description here

但是,我现在不知道从哪里提取实际方程和各个拟合的拟合参数值。我在哪里可以访问实际的拟合方程?

编辑:

变量 model 具有以下属性:

model.decision_function  model.fit_transform      model.inverse_transform  model.predict            model.predict_proba      model.set_params         model.transform          
model.fit                model.get_params         model.named_steps        model.predict_log_proba  model.score              model.steps

model.get_params 不存储所需的参数。

最佳答案

线性模型的系数存储在模型的intercept_coeff_属性中。

您可以通过调低正则化并输入已知模型来更清楚地看到这一点;例如

import numpy as np
from sklearn.linear_model import Ridge
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

x = 10 * np.random.random(100)
y = -4 + 2 * x - 3 * x ** 2

model = make_pipeline(PolynomialFeatures(2), Ridge(alpha=1E-8, fit_intercept=False))
model.fit(x[:, None], y)
ridge = model.named_steps['ridge']
print(ridge.coef_)
# array([-4.,  2., -3.])

另请注意,PolynomialFeatures 默认包含偏置项,因此在 Ridge 中拟合截距对于较小的 alpha 来说是多余的。

关于python - 如何从多项式拟合中提取方程?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/33876900/

相关文章:

python - 在scikit-learn中拟合后如何得到方程?

python - Scikit Learn 逻辑回归中预测的逆是正确的

r - 为什么 nls(非线性最小二乘法)在 R 中不起作用

r - 将回归线添加到多个散点图

python - Python 中的 SSH 动态端口转发 ('ssh -D')

python - 我如何让这个容器自行删除?

python - 无法使用 boto 将 SCP 文件发送到 AWS

python - Scikit-learn:避免高斯过程回归中的过度拟合

python - Python中LOWESS的置信区间

python - 计算列表中的重复项,并将该数字放入子列表中?