java - 线性回归中的梯度下降

标签 java machine-learning linear-regression gradient-descent

我正在尝试用java实现线性回归。我的假设是 theta0 + theta1 * x[i]。 我试图找出 theta0 和 theta1 的值,以使成本函数最小。 我正在使用梯度下降来找出值 -

while(repeat until convergence)
{
   calculate theta0 and theta1 simultaneously.
}

收敛之前的重复是什么? 我知道这是局部最小值,但是我应该在 while 循环中放入的确切代码是什么?

我对机器学习非常陌生,刚刚开始编写基本算法以获得更好的理解。任何帮助将不胜感激。

最佳答案

梯度下降是一种最小化给定函数的迭代方法。我们从对解的初始猜测开始,并在该点获取函数的梯度。我们将解向梯度的负方向步进,然后重复该过程。该算法最终将在梯度为零(对应于局部最小值)的地方收敛。因此,您的工作是找出使损失函数最小化的 theta0 和 theta1 的值(例如最小二乘误差)。 术语“收敛”意味着您达到了局部最小值,并且进一步的迭代不会影响参数的值,即 theta0 和 theta1 的值保持不变。让我们看一个示例注意:假设它位于第一象限以进行解释。

enter image description here

假设您必须最小化函数 f(x) [您的情况下的成本函数]。为此,您需要找出使 f(x) 函数值最小化的 x 值。这是使用梯度下降法找出 x 值的分步过程

  1. 您选择 x 的初始值。假设它位于图中的 A 点。
  2. 计算 f(x) 相对于 x 在 A 处的梯度。
  3. 这给出了函数在 A 点的斜率。由于函数在 A 处递增,因此它将产生正值。
  4. 您从 x 的初始猜测中减去该正值并更新 x 的值。即x = x - [某个正值]。这使得 x 更接近 D [即最小值]并减少f(x)的函数值[从图中]。假设在迭代 1 后,您到达点 B。
  5. 在 B 点,重复步骤 4 中提到的相同过程,到达 C 点,最后到达 D 点。
  6. 在 D 点,由于它是局部最小值,因此当计算梯度时,您会得到 0 [或非常接近 0]。现在您尝试更新 x 的值,即 x = x - [0]。您将得到相同的 x [或与前一个 x 非常接近的值]。这种情况称为“收敛”。 上述步骤适用于增加坡度,但对于减小坡度同样有效。例如,G 点的梯度会导致某个负值。当您更新 x 时,即 x = x - [ 负值] = x - [ - 某个正值] = x + 某个正值。这会增加 x 的值,并使 x 接近点 F [或接近最小值]。

有多种方法可以解决这种梯度下降问题。正如 @mattnedrich 所说,两种基本方法是

  1. 使用固定的迭代次数 N,此伪代码将为

    iter = 0
    while (iter < N) {
      theta0 = theta0 - gradient with respect to theta0
      theta1 = theta1 - gradient with respect to theta1
      iter++
    }
    
  2. 重复直到 theta0 和 theta1 的两个连续值几乎相同。伪代码由@Gerwin 在另一个答案中给出。

梯度下降是线性回归中最小化函数的方法之一。也存在直接的解决方案。批处理(也称为正规方程)可用于一步找出 theta0 和 theta1 的值。如果X是输入矩阵,y是输出 vector ,theta是你想要计算的参数,那么对于平方误差方法,你可以使用这个矩阵方程一步求出theta的值

theta = inverse(transpose (X)*X)*transpose(X)*y

但是由于这包含矩阵计算,显然当矩阵 X 的大小很大时,它比梯度下降的计算成本更高。 我希望这可以回答您的疑问。如果没有,请告诉我。

关于java - 线性回归中的梯度下降,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/21064030/

相关文章:

java - 重构现有模式

java - 由于数组索引导致的字符串异常

python - 如何将 GridSearchCV 中的验证集与训练集分开标准化?

machine-learning - Python 中的梯度下降实现返回 Nan

java - 计算指数增长系列中的值之和

Java ADB shell 不返回任何内容

machine-learning - 将 prop 文件转换为 arff 文件

python - 使用相同来源的余弦相似度和完全不同的结果

r - 从 R 中的线性模型列表映射 emmeans

python - 了解统计模型线性回归