我有一个神经网络的基本框架来识别数字,但我在训练它时遇到了一些问题。我的反向传播适用于小型数据集,但当我有超过 50 个数据点时,返回值开始收敛到 0。当我有数千个数据集时,我得到 NaN 的成本和返回。
基本结构:3层:784 : 15 : 1
784 是每个数据集的像素数,我的隐藏层有 15 个神经元,一个输出神经元返回 0 到 1 的值(当你乘以 10 时得到一个数字)。
public class NetworkManager {
int inputSize;
int hiddenSize;
int outputSize;
public Matrix W1;
public Matrix W2;
public NetworkManager(int input, int hidden, int output) {
inputSize = input;
hiddenSize = hidden;
outputSize = output;
W1 = new Matrix(inputSize, hiddenSize);
W2 = new Matrix(hiddenSize, output);
}
Matrix z2, z3;
Matrix a2;
public Matrix forward(Matrix X) {
z2 = X.dot(W1);
a2 = sigmoid(z2);
z3 = a2.dot(W2);
Matrix yHat = sigmoid(z3);
return yHat;
}
public double costFunction(Matrix X, Matrix y) {
Matrix yHat = forward(X);
Matrix cost = yHat.sub(y);
cost = cost.mult(cost);
double returnValue = 0;
int i = 0;
while (i < cost.m.length) {
returnValue += cost.m[i][0];
i++;
}
return returnValue;
}
Matrix yHat;
public Matrix[] costFunctionPrime(Matrix X, Matrix y) {
yHat = forward(X);
Matrix delta3 = (yHat.sub(y)).mult(sigmoidPrime(z3));
Matrix dJdW2 = a2.t().dot(delta3);
Matrix delta2 = (delta3.dot(W2.t())).mult(sigmoidPrime(z2));
Matrix dJdW1 = X.t().dot(delta2);
return new Matrix[]{dJdW1, dJdW2};
}
}
这是网络框架的代码。我将长度为 784 的 double 组传递给 forward 方法。
int t = 0;
while (t < 10000) {
dJdW = Nn.costFunctionPrime(X, y);
Nn.W1 = Nn.W1.sub(dJdW[0].scalar(3));
Nn.W2 = Nn.W2.sub(dJdW[1].scalar(3));
t++;
}
我称之为调整权重。对于小集合,成本很好地收敛到 0,但大集合则不然(与 100 个字符相关的成本总是收敛到 13)。如果集合太大,第一次调整有效(并且成本下降)但第二次之后我只能得到 NaN。
为什么此实现会因更大的数据集(特别是训练)而失败,我该如何解决?我尝试了一个类似的结构,它有 10 个输出而不是 1 个输出,其中每个输出都会返回一个接近 0 或 1 的值,就像 boolean 值一样,但同样的事情发生了。
顺便说一句,我也在用 java 做这件事,我想知道这是否与问题有关。我想知道这是否是空间不足的问题,但我没有收到任何堆空间消息。我的反向传播方式有问题还是发生了其他事情?
编辑:我想我知道发生了什么。我认为我的反向传播函数陷入了局部最小值。对于大型数据集,有时训练成功,有时失败。因为我从随机权重开始,所以我得到随机的初始成本。我注意到的是,当成本最初超过一定数量(这取决于所涉及的数据集的数量)时,成本收敛到一个干净的数字(有时是 27,其他是 17.4)并且输出收敛到 0(这是有道理的).
开始时有人警告我成本函数的相对最小值,我开始意识到原因。所以现在问题变成了,我该如何进行梯度下降才能真正找到全局最小值?顺便说一下,我在 Java 工作。
最佳答案
这似乎是权重初始化的问题。
据我所知,您从未将权重初始化为任何特定值。因此网络发散。您至少应该使用随机初始化。
关于java - 神经网络和大数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41235524/