java - 使用 nd4j 的 S 形导数

标签 java python python-3.x numpy nd4j

我执行 sigmoid,效果很好,但 sigmoidDerivative 给出的结果与 nd4j 中的 sigmoid 相同。 Transforms.sigmoidDerivative(x)Transforms.sigmoidDerivative(x, true) 之间有什么区别?

INDArray x = Nd4j.create(new double[] { 0.1812, 0.1235, 0.8466 });
System.out.println(x);
System.out.println(Transforms.sigmoid(x));
System.out.println(Transforms.sigmoidDerivative(x));
System.out.println(Transforms.sigmoidDerivative(x, true));

给出输出:

[[    0.1812,    0.1235,    0.8466]]
[[    0.5452,    0.5308,    0.6999]]
[[    0.5452,    0.5308,    0.6999]]
[[    0.2480,    0.2490,    0.2101]]

与 python 的 numpy 比较:

>>> def sigmoid(x):
...     return 1.0 / (1 + np.exp(-x))
... 
>>> def sigmoid_derivative(x):
...     a = sigmoid(x)
...     return a * (1.0 - a)

>>> x = np.array([    0.1812,    0.1235,    0.8466])
>>> sigmoid(x)
array([0.54517646, 0.53083582, 0.69985343])
>>> sigmoid_derivative(x)
array([0.24795909, 0.24904915, 0.21005861])

Nd4j pom:

<dependency>
     <groupId>org.nd4j</groupId>
     <artifactId>nd4j-native-platform</artifactId>
     <version>1.0.0-beta3</version>
</dependency>

最佳答案

你是对的,Transforms.sigmoidDerivative(x)Transforms.sigmoidDerivative(x, true)应该给出相同的结果,这是dl4j中的一个错误。正确的行为有后一种方法。我已经提交了pull request来解决它。

关于java - 使用 nd4j 的 S 形导数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53943251/

相关文章:

python - 值错误 : invalid literal for int() with base 10 for non-digits

java - 通过命令行编译并执行尝试查找 Java 包时显示 NoClassDefFoundError

java - JOGL 纹理不起作用

java - 多线程和持久性最佳实践

python - 无法在 Windows 上导入 distutils.dir_util

python - 将 win32lfn 扩展与 Mercurial 捆绑在一起

python-3.x - ModuleNotFoundError : No module named 'past' when installing tensorboard with pytorch 1. 2

java - 给十六进制添加填充的目的是什么?

python - Django - OneToOneField 未在管理中填充

python - 高效查询多列 SQLAlchemy ORM