java - 有人可以检查我的异或神经网络代码有什么问题吗

标签 java machine-learning neural-network gradient-descent backpropagation

我一直在尝试创建一个 XOR 神经网络,但所有输入的输出始终会收敛到某个值(例如 1、0 或 0.5)。这是我最近的尝试:

import java.io.*;
import java.util.*;

public class Main {
    public static void main(String[] args) {
        double[][] trainingInputs = {
                {1, 1},
                {1, 0},
                {0, 1},
                {1, 1}
        };
        double[] targetOutputs = {0, 1, 1, 0};
        NeuralNetwork network = new NeuralNetwork();
        System.out.println("Training");
        for(int i=0; i<40; i++) {
            network.train(trainingInputs, targetOutputs);
        }
        for(double[] inputs : trainingInputs) {
            double output = network.feedForward(inputs);
            System.out.println(inputs[0] + " - " + inputs[1] + " : " + output);
        }
    }
}

class Neuron {
    private ArrayList<Synapse> inputs; // List di sinapsi collegate al neurone
    private double output; // output del neurone
    private double derivative; // derivata dell'output
    private double weightedSum; // somma ponderata del peso delle sinapsi e degli output collegati
    private double error; // errore
    public Neuron() {
        inputs = new ArrayList<Synapse>();
        error = 0;
    }
    // Aggiunge una sinpapsi
    public void addInput(Synapse input) {
        inputs.add(input);
    }

    public List<Synapse> getInputs() {
        return this.inputs;
    }

    public double[] getWeights() {
        double[] weights = new double[inputs.size()];

        int i = 0;
        for(Synapse synapse : inputs) {
            weights[i] = synapse.getWeight();
            i++;
        }

        return weights;
    }

    private void calculateWeightedSum() {
        weightedSum = 0;
        for(Synapse synapse : inputs) {
            weightedSum += synapse.getWeight() * synapse.getSourceNeuron().getOutput();
        }
    }

    public void activate() {
        calculateWeightedSum();
        output = sigmoid(weightedSum);
        derivative = sigmoidDerivative(output);
    }

    public double getOutput() {
        return this.output;
    }

    public void setOutput(double output) {
        this.output = output;
    }

    public double getDerivative() {
        return this.derivative;
    }

    public double getError() {
        return error;
    }

    public void setError(double error) {
        this.error = error;
    }

    public double sigmoid(double weightedSum) {
        return 1 / (1 + Math.exp(-weightedSum));
    }

    public double sigmoidDerivative(double output) {
        return output / (1 - output);
    }
}

class Synapse implements Serializable {

    private Neuron sourceNeuron; // Neurone da cui origina la sinapsi
    private double weight; // Peso della sinapsi

    public Synapse(Neuron sourceNeuron) {
        this.sourceNeuron = sourceNeuron;
        this.weight = Math.random() - 0.5;
    }

    public Neuron getSourceNeuron() {
        return sourceNeuron;
    }

    public double getWeight() {
        return weight;
    }

    public void adjustWeight(double deltaWeight) {
        this.weight += deltaWeight;
    }
}

class NeuralNetwork implements Serializable {
    Neuron[] input;
    Neuron[] hidden;
    Neuron output;
    double learningRate = 0.1;
    public NeuralNetwork() {
        input = new Neuron[2];
        hidden = new Neuron[2];
        output = new Neuron();
        for(int i=0; i<2; i++) {
            input[i] = new Neuron();
        }
        for(int i=0; i<2; i++) {
            hidden[i] = new Neuron();
        }
        for(int i=0; i<2; i++) {
            Synapse s = new Synapse(hidden[i]);
            output.addInput(s);
        }
        for(int i=0; i<2; i++) {
            for(int j=0; j<2; j++) {
                Synapse s = new Synapse(input[j]);
                hidden[i].addInput(s);
            }
        }
    }
    public void setInput(double[] inputVal) {
        for(int i=0; i<2; i++) {
            input[i].setOutput(inputVal[i]);
        }
    }
    public double feedForward(double[] inputVal) {
        setInput(inputVal);
        for(int i=0; i<2; i++) {
            hidden[i].activate();
        }
        output.activate();
        return output.getOutput();
    }
    public void train(double[][] trainingInputs, double[] targetOutputs) {
        for(int i=0; i<4; i++) {
            double[] inputs = trainingInputs[i];
            double target = targetOutputs[i];
            double currentOutput = feedForward(inputs);
            double delta = 0;
            double neuronError = 0;
            for(int j=0; j<2; j++) {
                Synapse s = output.getInputs().get(j);
                neuronError = output.getDerivative() * (target - currentOutput);
                delta = learningRate * s.getSourceNeuron().getOutput() * neuronError;
                output.setError(neuronError);
                s.adjustWeight(delta);
            }
            for(int j=0; j<2; j++) {
                for(int k=0; k<2; k++) {
                    Synapse s = hidden[j].getInputs().get(k);
                    Synapse s1 = output.getInputs().get(j);
                    delta = learningRate * s.getSourceNeuron().getOutput() * hidden[j].getDerivative() * s1.getWeight() * output.getError();
                    s.adjustWeight(delta);
                }
            }
        }
    }
}

我从 github 上其他人的实现中找到了反向传播算法,并尝试使用它,但我要么得到 0.50 左右的输出,要么只是 NaN。我不知道我是否使用了错误的算法,是否以错误的方式实现了它或其他什么。

我使用的算法是这样的: 首先我找到神经元本身的错误:

如果是输出神经元,则 NeuronError = (输出神经元的导数) * (预期输出 - 实际输出)

如果它是隐藏神经元,则神经元误差=(隐藏神经元的导数)*(输出神经元的神经元误差)*(从隐藏神经元到输出神经元的突触权重)

然后 deltaWeight = LearningRate * (突触起始神经元的神经元误差) * (突触起始神经元的输出)

最后我将 deltaWeight 添加到之前的权重上。

抱歉,文字很长,如果您不通读代码,您至少可以告诉我我的算法是否正确吗?谢谢

最佳答案

你的 sigmoid 导数是错误的,应该如下:

public double sigmoidDerivative(double output) {
        return output * (1 - output);
    }
}

正如我在评论中所说,您的训练输入中有两次 {1, 1} , 因此将其更改为 {0, 0}。

最后,将迭代次数从 40 增加到 100,000。

关于java - 有人可以检查我的异或神经网络代码有什么问题吗,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58902225/

相关文章:

java - 如果通过删除整个映射关系来更改 hbm.xml 文件,是否仍然可以使用 Hibernate 标准?

machine-learning - 在 Caffe 中使用可学习参数编写自定义 Python 层

neural-network - 通过神经网络进行时间序列预测

neural-network - 预训练的 GloVe 矢量文件(例如 glove.6B.50d.txt)中的 "unk"是什么?

java - HackerRank 最接近的数字

java - Eclipse - 错误 : Main method not found in class projectOne, 请将 main 方法定义为:public static void main(String[] args)

java - Spring Boot 应用程序在 STS 中工作但无法使用 java -jar 启动

r - 尝试在 R 中运行 kNN 时,我收到由 coercionNAs 引入的错误 NAs?

python - MongoDB + K 表示集群

neural-network - Caffe 分割网络 - softmax_loss_layer 错误