python - 回归模型中成本函数的 L1 范数而不是 L2 范数

标签 python machine-learning regression least-squares

我想知道 Python 中是否有一个函数可以完成与 scipy.linalg.lstsq 相同的工作,但使用“最小绝对偏差”回归而不是“最小二乘”回归 (OLS)。我想使用 L1 规范,而不是 L2 规范。

事实上,我有 3d 个点,我想要它们中最合适的平面。常用的方法是像 Github 这样的最小二乘法 link .但众所周知,这并不总是最合适的,尤其是当我们的数据集中有闯入者时。最好计算最小的绝对偏差。两种方法的区别解释更多here .

它不会被诸如 MAD 之类的函数求解,因为它是一个 Ax = b 矩阵方程并且需要循环来最小化结果。我想知道是否有人知道 Python 中的相关函数 - 可能在线性代数包中 - 可以计算“最小绝对偏差”回归?

最佳答案

使用 scipy.optimize.minimize 和自定义 cost_function,这并不难。

让我们先进口必需品,

from scipy.optimize import minimize
import numpy as np

并定义自定义成本函数(以及用于获取拟合值的便利包装器),

def fit(X, params):
    return X.dot(params)


def cost_function(params, X, y):
    return np.sum(np.abs(y - fit(X, params)))

然后,如果您有一些X(设计矩阵)和y(观察),我们可以执行以下操作,

output = minimize(cost_function, x0, args=(X, y))

y_hat = fit(X, output.x)

x0 是最佳参数的一些合适的初始猜测(您可以在此处采纳@JamesPhillips 的建议,并使用 OLS 方法中的拟合参数)。

无论如何,当用一个有点人为的例子进行测试时,

X = np.asarray([np.ones((100,)), np.arange(0, 100)]).T
y = 10 + 5 * np.arange(0, 100) + 25 * np.random.random((100,))

我发现,

      fun: 629.4950595335436
 hess_inv: array([[  9.35213468e-03,  -1.66803210e-04],
       [ -1.66803210e-04,   1.24831279e-05]])
      jac: array([  0.00000000e+00,  -1.52587891e-05])
  message: 'Optimization terminated successfully.'
     nfev: 144
      nit: 11
     njev: 36
   status: 0
  success: True
        x: array([ 19.71326758,   5.07035192])

还有,

fig = plt.figure()
ax = plt.axes()

ax.plot(y, 'o', color='black')
ax.plot(y_hat, 'o', color='blue')

plt.show()

蓝色为拟合值,黑色为数据。

enter image description here

关于python - 回归模型中成本函数的 L1 范数而不是 L2 范数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51883058/

相关文章:

python - 属性错误: object has no attribute '_lazy_read'

python - 在 sklearn 中训练后是否必须再次使用 fit() ?

machine-learning - 为什么我的神经网络无法正确分类这些井字游戏模式?

python - 使用正规方程的线性回归

python - 如何在截距条件下拟合多项式(使用 np.polyfit 或其他东西)?

machine-learning - 异常: The passed model is not callable and cannot be analyzed directly with the given masker

python - 尝试从列表中删除一组元组时遇到问题?

python - Django model.foreignKey 并返回 self.text 错误

python - 跟踪快速更新的文件时程序崩溃

python - 如何解析 .shp 文件?