regression - 梯度下降不收敛

标签 regression matlab gradient-descent

这是我自己用matlab语言实现的梯度下降算法

 m = height(data_training); % number of samples
cols = {'x1', 'x2', 'x3', 'x4', 'x5', 'x6',...
    'x7', 'x8','x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15'}; 

y = data_training{:, {'y'}}';
X = [ones(m,1) data_training{:,cols}]'; 

theta = zeros(1,width(data_training));

alpha = 1e-2; % learning rate
iter = 400;

dJ = zeros(1,width(data_training));

J_seq = zeros(1, iter);

for n = 1:iter

    err = (theta*X - y);

    for j = 1:width(data_training)
        dJ(j) = 1/m*sum(err*X(j,:)');
    end

    J = 1/(2*m)*sum((theta*X-y).^2);

    theta = theta - alpha.*dJ;

    J_seq(n) = J;

    if mod(n,100) == 0
        plot(1:iter, J_seq);
    end
end

编辑 工作算法

我已将此算法应用于以下训练数据集。最后一列是输出变量。这里我们有 15 个不同的功能。

由于我不知道的原因,当我在 50 次迭代后绘制成本函数 J 以检查它是否趋于收敛时,我发现它没有收敛。你能帮我理解一下吗?是实现错误还是我应该做点什么?

36    27    71     8.1    3.34    11.4    81.5    3243     8.8    42.6    11.7     21     15     59    59     921.87
35    23    72    11.1    3.14      11    78.8    4281     3.6    50.7    14.4      8     10     39    57     997.88
44    29    74    10.4    3.21     9.8    81.6    4260     0.8    39.4    12.4      6      6     33    54     962.35
47    45    79     6.5    3.41    11.1    77.5    3125    27.1    50.2    20.6     18      8     24    56     982.29
43    35    77     7.6    3.44     9.6    84.6    6441    24.4    43.7    14.3     43     38    206    55     1071.3
53    45    80     7.7    3.45    10.2    66.8    3325    38.5    43.1    25.5     30     32     72    54     1030.4
43    30    74    10.9    3.23    12.1    83.9    4679     3.5    49.2    11.3     21     32     62    56      934.7
45    30    73     9.3    3.29    10.6      86    2140     5.3    40.4    10.5      6      4      4    56     899.53
36    24    70       9    3.31    10.5    83.2    6582     8.1    42.5    12.6     18     12     37    61     1001.9
36    27    72     9.5    3.36    10.7    79.3    4213     6.7      41    13.2     12      7     20    59     912.35
52    42    79     7.7    3.39     9.6    69.2    2302    22.2    41.3    24.2     18      8     27    56     1017.6
33    26    76     8.6     3.2    10.9    83.4    6122    16.3    44.9    10.7     88     63    278    58     1024.9
40    34    77     9.2    3.21    10.2      77    4101      13    45.7    15.1     26     26    146    57     970.47
35    28    71     8.8    3.29    11.1    86.3    3042    14.7    44.6    11.4     31     21     64    60     985.95
37    31    75       8    3.26    11.9    78.4    4259    13.1    49.6    13.9     23      9     15    58     958.84
35    46    85     7.1    3.22    11.8    79.9    1441    14.8    51.2    16.1      1      1      1    54      860.1
36    30    75     7.5    3.35    11.4    81.9    4029    12.4      44      12      6      4     16    58     936.23
15    30    73     8.2    3.15    12.2    84.2    4824     4.7    53.1    12.7     17      8     28    38     871.77
31    27    74     7.2    3.44    10.8      87    4834    15.8    43.5    13.6     52     35    124    59     959.22
30    24    72     6.5    3.53    10.8    79.5    3694    13.1    33.8    12.4     11      4     11    61     941.18
31    45    85     7.3    3.22    11.4    80.7    1844    11.5    48.1    18.5      1      1      1    53     891.71
31    24    72       9    3.37    10.9    82.8    3226     5.1    45.2    12.3      5      3     10    61     871.34
42    40    77     6.1    3.45    10.4    71.8    2269    22.7    41.4    19.5      8      3      5    53     971.12
43    27    72       9    3.25    11.5    87.1    2909     7.2    51.6     9.5      7      3     10    56     887.47
46    55    84     5.6    3.35    11.4    79.7    2647      21    46.9    17.9      6      5      1    59     952.53
39    29    76     8.7    3.23    11.4    78.6    4412    15.6    46.6    13.2     13      7     33    60     968.66
35    31    81     9.2     3.1      12    78.3    3262    12.6    48.6    13.9      7      4      4    55     919.73
43    32    74    10.1    3.38     9.5    79.2    3214     2.9    43.7      12     11      7     32    54     844.05
11    53    68     9.2    2.99    12.1    90.6    4700     7.8    48.9    12.3    648    319    130    47     861.83
30    35    71     8.3    3.37     9.9    77.4    4474    13.1    42.6    17.7     38     37    193    57     989.26
50    42    82     7.3    3.49    10.4    72.5    3497    36.7    43.3    26.4     15     10     34    59     1006.5
60    67    82      10    2.98    11.5    88.6    4657    13.6    47.3    22.4      3      1      1    60     861.44
30    20    69     8.8    3.26    11.1    85.4    2934     5.8      44     9.4     33     23    125    64     929.15
25    12    73     9.2    3.28    12.1    83.1    2095       2    51.9     9.8     20     11     26    50     857.62
45    40    80     8.3    3.32    10.1    70.3    2682      21    46.1    24.1     17     14     78    56     961.01
46    30    72    10.2    3.16    11.3    83.2    3327     8.8    45.3    12.2      4      3      8    58     923.23

最佳答案

不确定我是否遵循您的逻辑,但很明显“e”(错误)不应该平方

让我们看看您应该使用什么。

theta 是未知数的列向量,y 是测量值的列向量,X 是模型矩阵,其中每行都是一个'例子'。所以你需要找到 theta 这样:

y = X*theta 

或者等价地,使用优化方法来找到最小化当前平方误差的theta(这就是凸优化问题的原因):

e[n] = (y - X*theta[n])

e[n]^2 --> minimize 

梯度下降使用误差函数的梯度(相对于 theta)来更新 theta 向量:

theta[n+1] = theta[n] - alpha*2*X'*e[n]

(请注意,e[n] 和 theta[n] 是向量。这是数学符号 - 不是 matlab 的)

所以你会看到 e[n] 在更新方程中没有平方。

关于regression - 梯度下降不收敛,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/27575792/

相关文章:

python - cross_val_score 和 cross_val_predict 的区别

open-source - 开源(或其他)版本的 Matlab 工具箱

neural-network - 为什么我们需要显式调用 zero_grad()?

machine-learning - Keras 中的自定义损失函数用于惩罚漏报

c - 使用Mex环境时设置环境变量

sockets - 权重在此代码中更新的位置?

Python pyGPs : IndexError for regression with multi-dimensional x and z, setData(x,y) 和预测(z)

r - 访问回归训练模型

使用 lm() 和 predict() 进行滚动回归和预测

matlab - Matlab 中基于傅里叶的字符识别