python - 我在 sklearn 中没有顺利得到多项式回归

标签 python matplotlib scikit-learn

对于在 python 中使用 sklearn 实现的多项式回归代码

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
import pandas as pd

m=100
X=6*np.random.rand(m,1)-3
y=0.5*X**3+X+2+np.random.rand(m,1)

poly_features=PolynomialFeatures(3)
X_poly=poly_features.fit_transform(X)
lin_reg=LinearRegression()

X_new=[[0],[3]]

lin_reg.fit(X_poly,y)

plt.plot(X,y,"b.")
plt.plot(X, lin_reg.predict(poly_features.fit_transform(X)),  "r-")

plt.show()

输出显示为

enter image description here

但我想得到一条平滑的预测线。如何获得?

最佳答案

问题是 X 数组未排序。因此,当您使用线 -r 绘制数据时,它会按照未排序的 X 数据点的顺序连接数据点。因此,您会看到一个随机的线路网络。对于带有标记的绘图来说,顺序并不重要,因为您只是绘制没有线条的点。

解决方案是对 X 数据进行排序,并将排序后的 X 数据传递给绘图命令,并相应地传递给 fit_transform

shape = X.shape
X = np.sort(X.flatten())
plt.plot(X, lin_reg.predict(poly_features.fit_transform(X.reshape((shape)))),  "r-", lw=2)

enter image description here

关于python - 我在 sklearn 中没有顺利得到多项式回归,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54160912/

相关文章:

python - 为什么该图的底部有额外的空间?

scikit-learn - 可更新的最近邻搜索

python - scikit-learn 和 mllib 在预测 python 上的区别

Python 2 到 3 个字节/字符串错误

python - 在 python 2.7 上导入 tkinter 时出现问题

python - 在python中读取文件

matplotlib - 更改 Seaborn Factorplot 中的标记大小

python - 错误: (-215) reprojectImageTo3D opencv

python - Seaborn 和 Stripplot 的误差线

python - 决策树生成具有相同类的终端叶