我正在尝试实现这样的东西 https://www.youtube.com/watch?v=Fp9kzoAxsA4这是一个使用 DL4J 库的 GANN(遗传算法神经网络)。
遗传学习变量:
- 基因:生物神经网络权重
- 健康度:移动的总距离。
每个生物的神经网络层:
这是我的生物对象的 createBrain
方法:
private void createBrain() {
Layer inputLayer = new DenseLayer.Builder()
// 5 eye sensors
.nIn(5)
.nOut(5)
// How do I initialize custom weights using creature genes (this.genes)?
// .weightInit(WeightInit.ZERO)
.activation(Activation.RELU)
.build();
Layer outputLayer = new OutputLayer.Builder()
.nIn(5)
.nOut(1)
.activation(Activation.IDENTITY)
.lossFunction(LossFunctions.LossFunction.MSE)
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(6)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.iterations(1)
.learningRate(0.006)
.updater(Updater.NESTEROVS).momentum(0.9)
.list()
.layer(0,inputLayer)
.layer(1, outputLayer)
.pretrain(false).backprop(true)
.build();
this.brain = new MultiLayerNetwork(conf);
this.brain.init();
}
如果它可能有帮助,我已经推送到这个仓库 https://github.com/kareem3d/GeneticNeuralNetwork
这是 Creature 类 https://github.com/kareem3d/GeneticNeuralNetwork/blob/master/src/main/java/com/mycompany/gaan/Creature.java
我是一名机器学习学生,因此如果您发现任何明显的错误,请告诉我,谢谢:)
最佳答案
我不知道你是否可以在层配置中设置权重(我在API文档中看不到),但你可以在初始化模型后获取并设置网络参数。
要为图层单独设置它们,您可以按照此示例进行操作;
Iterator paramap_iterator = convolutionalEncoder.paramTable().entrySet().iterator();
while(paramap_iterator.hasNext()) {
Map.Entry<String, INDArray> me = (Map.Entry<String, INDArray>) paramap_iterator.next();
System.out.println(me.getKey());//print key
System.out.println(Arrays.toString(me.getValue().shape()));//print shape of INDArray
convolutionalEncoder.setParam(me.getKey(), Nd4j.rand(me.getValue().shape()));//set some random values
}
如果您想一次设置网络的所有参数,您可以使用 setParams()
和 params()
,例如;
INDArray all_params = convolutionalEncoder.params();
convolutionalEncoder.setParams(Nd4j.rand(all_params.shape()));//set random values with the same shape
您可以查看API了解更多信息; https://deeplearning4j.org/doc/org/deeplearning4j/nn/api/Model.html#params--
关于machine-learning - 在 deeplearning4j 中初始化自定义权重,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42806761/