python - 线性回归模型

标签 python python-3.x numpy machine-learning regression

以下是我使用 SGD 实现的线性回归,但获得的直线不是最合适的。我该如何改进它? enter image description here

import matplotlib.pyplot as plt
from matplotlib import style
import numpy as np

style.use("fivethirtyeight")

x=[[1],[2],[3],[4],[5],[6],[7],[8],[9],[10]]
y=[[3],[5],[9],[9],[11],[13],[16],[17],[19],[21]]

X=np.array(x)
Y=np.array(y)

learning_rate=0.015


m=1
c=2
gues=[]

for i in range(len(x)):

    guess=m*x[i][0]+c
    error=guess-y[i][0]


    if error<0:

        m=m+abs(error)*x[i][0]*learning_rate
        c=c+abs(error)*learning_rate

    if error>0:

        m=m-abs(error)*x[i][0]*learning_rate
        c=c-abs(error)*learning_rate
    gues.append([guess])
t=np.array(gues)



plt.scatter(X,Y)
plt.plot(X,t)
plt.show()


from sklearn.linear_model import LinearRegression
var=LinearRegression()
var.fit(X,Y)
plt.scatter(X,Y)
plt.plot(X,var.predict(X))
plt.show()

因为我必须最小化错误,这是(猜测-y)对错误函数 w.r.t 的偏导数到 m 给出 x 和 w.r.t c 给出一个常量。

最佳答案

您正在进行随机梯度下降,评估每个数据点的拟合度。所以最后的 mc 给你拟合关系的参数。您绘制的线是拟合线的“演变”。

这是我绘制它的方式,当我弄清楚你在做什么时对你的代码进行了一些其他调整:

import numpy as np
import matplotlib.pyplot as plt

X = np.array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])
Y = np.array([ 3,  5,  9,  9, 11, 13, 16, 17, 19, 21])

learning_rate = 0.015
m = 1
c = 2

gues = []
for xi, yi in zip(X, Y):

    guess = m * xi + c
    error = guess - yi

    m = m - error * xi * learning_rate
    c = c - error * learning_rate

    gues.append(guess)

t = np.array(gues)

# Plot the modeled line.
y_hat = m * X + c
plt.figure(figsize=(10,5))
plt.plot(X, y_hat, c='red')

# Plot the data.
plt.scatter(X, Y)

# Plot the evolution of guesses.
plt.plot(X, t)
plt.show()

enter image description here

我在代码中所做的主要修改是:跨过压缩的 XY 这样您就可以使用 then 而无需索引它们。为了简单起见,我也将它们设为一维数组。如果您直接使用渐变,而不使用 abs,则 +ve 和 -ve 情况下不需要不同的路径。

关于python - 线性回归模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48607884/

相关文章:

python - 无法理解 @property 如何知道哪个变量是属性?

Python "regex"模块 : Fuzziness value

python - 如何通过python中的给定字符串生成随机数?

python-3.x - 无法使用 Python 3.8 在 Ubuntu 18.04 上创建 virtualenv

python-2.7 - 无法让 scipy.io.wavfile.read() 工作

Python 二维高斯拟合与数据中的 NaN 值

python - Scrapy 和框架

python - 如何处理SessionNotCreatedException : Message: session not created exception from disconnected: Unable to receive message from renderer?

python - 在python中将32位二进制转换为十进制

python - 相当于 MATLAB 元胞数组的 Numpy