python - 为什么在 LSTM 中添加 relu 激活后出现 Nan?

标签 python tensorflow lstm relu

我有一个简单的 LSTM 网络,大致如下所示:

lstm_activation = tf.nn.relu

cells_fw = [LSTMCell(num_units=100, activation=lstm_activation), 
            LSTMCell(num_units=10, activation=lstm_activation)]

stacked_cells_fw = MultiRNNCell(cells_fw)

_, states = tf.nn.dynamic_rnn(cell=stacked_cells_fw,
                              inputs=embedding_layer,
                              sequence_length=features['length'],
                              dtype=tf.float32)

output_states = [s.h for s in states]
states = tf.concat(output_states, 1)

我的问题是。当我不使用激活(激活=无)或使用 tanh 时,一切正常,但当我切换 relu 时,我不断收到“训练期间 NaN 损失”,这是为什么? 100% 可重复。

最佳答案

当您使用relu activation function时里面lstm cell ,保证单元的所有输出以及单元状态都严格为 >= 0 。因此,你的梯度变得非常大并且呈爆炸式增长。例如,运行以下代码片段并观察输出永远不会 < 0 .

X = np.random.rand(4,3,2)
lstm_cell = tf.nn.rnn_cell.LSTMCell(5, activation=tf.nn.relu)
hidden_states, _ = tf.nn.dynamic_rnn(cell=lstm_cell, inputs=X, dtype=tf.float64)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(hidden_states))

关于python - 为什么在 LSTM 中添加 relu 激活后出现 Nan?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55322991/

相关文章:

python - 使用琶音处理空格或逗号分隔的标记列表

python - 如何将数据保存到 csv Cantera 和错误 <cantera.composite.SolutionArray object at 0x7f4badca0fd0>

keras - LSTM 编码器-解码器推理模型

tensorflow - 为什么 LSTMCell 输入的维度必须与单元数匹配

python - 如何计算android相机捕获的图像中物体的深度

python - 获取给定列中使用的唯一字符列表

c++ - 如何使用 C++ 更改 tensorflow 中的 per_process_gpu_memory_fraction?

tensorflow - 多维 lstm tensorflow

python - 如何从自定义神经网络模型中获取对数和概率

python - keras cnn_lstm 输入层不接受一维输入