python - 尝试绘制一个简单的函数 - python

标签 python numpy plot machine-learning linear-regression

我实现了一个简单的线性回归,我想通过拟合非线性模型来尝试一下

具体来说,我正在尝试为函数 y = x^3 + 5 拟合一个模型

这是我的代码

import numpy as np
import numpy.matlib
import matplotlib.pyplot as plt

def predict(X,W):
    return np.dot(X,W)

def gradient(X, Y, W, regTerm=0):
    return (-np.dot(X.T, Y) + np.dot(np.dot(X.T,X),W))/(m*k) + regTerm * W /(n*k)

def cost(X, Y, W, regTerm=0):
    m, k = Y.shape
    n, k = W.shape
    Yhat = predict(X, W)
    return np.trace(np.dot(Y-Yhat,(Y-Yhat).T))/(2*m*k) + regTerm * np.trace(np.dot(W,W.T)) / (2*n*k)

def Rsquared(X, Y, W):
    m, k = Y.shape
    SSres = cost(X, Y, W)
    Ybar = np.mean(Y,axis=0)
    Ybar = np.matlib.repmat(Ybar, m, 1)
    SStot = np.trace(np.dot(Y-Ybar,(Y-Ybar).T))

    return 1-SSres/SStot

m = 10
n = 200
k = 1

trX = np.random.rand(m, n)
trX[:, 0] = 1

for i in range(2, n):
    trX[:, i] = trX[:, 1] ** i

trY = trX[:, 1] ** 3 + 5
trY = np.reshape(trY, (m, k))

W = np.random.rand(n, k)

numIter = 10000
learningRate = 0.5

for i in range(0, numIter):
    W = W - learningRate * gradient(trX, trY, W)

domain = np.linspace(0,1,100000)
powerDomain = np.copy(domain)
m = powerDomain.shape[0]
powerDomain = np.reshape(powerDomain, (m, 1))
powerDomain = np.matlib.repmat(powerDomain, 1, n)

for i in range(1, n):
    powerDomain[:, i] = powerDomain[:, 0] ** i

print(Rsquared(trX, trY, W))
plt.plot(trX[:, 1],trY,'o', domain, predict(powerDomain, W),'r')
plt.show()

我得到的 R^2 非常接近 1,这意味着我发现非常适合训练数据,但它没有显示在图中。当我绘制数据时,它通常看起来像这样:

enter image description here

看起来好像我对数据拟合不足,但是有了如此复杂的假设,有 200 个特征(意味着我允许最多 x^200 的多项式)并且只有 10 个训练示例,我应该非常明显地过度拟合数据,所以我预计红线会穿过所有蓝点并在它们之间狂野。

这不是我得到的,这让我感到困惑。 怎么了?

最佳答案

您忘记设置 powerDomain[:,0]=1,这就是您的绘图在 0 处出错的原因。是的,你过度拟合了:看看一旦你离开你的训练域,你的情节就启动得有多快。

关于python - 尝试绘制一个简单的函数 - python,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38776418/

相关文章:

python - 当 `subprocess.call` 收到终止信号时会发生什么?

python - 如何链接一组切换按钮,如单选按钮?

python - 在模板 Django 中获取 Site_ID

python - Matplotlib 绘图以错误的方式绘制

python - 在 matplotlib pyplot 中,如何按类别对条形图中的条形进行分组?

python - 可能是基本的,但打印 pandas 变量的实际名称,而不是数据框本身

Python:im2col 的实现利用了 6 维数组的优势?

python - 拒绝特殊值的二维数组的计算(知道索引)

R:使用新的点()或线()添加更新图 [xy]lims?

c++ - 稀疏数据的C++/GNUPlot热图?