python - 用于噪声序列的 Tensorflow LSTM

标签 python machine-learning tensorflow lstm

我尝试解决原始 LSTM 论文中描述的实验 3a:http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf使用 tensorflow LSTM 并失败

摘自论文:任务是观察输入序列,然后对其进行分类。有两个类别,每个类别发生的概率为 0.5。只有一根输入线。只有前 N 个实值序列元素传达有关该类的相关信息。位置 t > N 处的序列元素由均值为零、方差为 0.2 的高斯函数生成。

他在论文中描述的网络架构: “我们使用具有 1 个输入单元、1 个输出单元和 3 个大小为 1 的单元 block 的 3 层网络。输出层仅接收来自存储单元的连接。存储单元和门单元接收来自输入单元、存储单元和门的输入门单元和输出单元为 [0; 1] 中的逻辑 sigmoid,[-1; 1] 中为 h,[-2; 2] 中为 g"

我尝试使用具有 3 个隐藏单元(T=100 且 N=3)的 LSTM 重现它,但失败了。

我使用了在线训练(即在每个序列后更新权重),如原始论文中所述

我的代码核心如下:

self.batch_size = batch_size = config.batch_size
hidden_size = 3
self._input_data = tf.placeholder(tf.float32, (1, T))
self._targets = tf.placeholder(tf.float32, [1, 1])
lstm_cell = rnn_cell.BasicLSTMCell(hidden_size , forget_bias=1.0)
cell = rnn_cell.MultiRNNCell([lstm_cell] * 1)
self._initial_state = cell.zero_state(1, tf.float32)
weights_hidden = tf.constant(1.0, shape= [config.num_features, config.n_hidden])

准备输入

inputs = []
for k in range(num_steps):
   nextitem = tf.matmul(tf.reshape(self._input_data[:, k], [1, 1]) , weights_hidden)
   inputs.append(nextitem)

outputs, states = rnn.rnn(cell, inputs, initial_state=self._initial_state)

使用最后的输出

pred = tf.sigmoid(tf.matmul(outputs[-1], tf.get_variable("weights_out", [config.n_hidden,1])) + tf.get_variable("bias_out", [1]))

self._final_state = states[-1]
self._cost = cost = tf.reduce_mean(tf.square((pred - self.targets)))
self._result = tf.abs(pred[0, 0] - self.targets[0,0])

optimizer = tf.train.GradientDescentOptimizer(learning_rate = config.learning_rate).minimize(cost)

知道为什么它无法学习吗?

我的第一直觉是为每个类创建 2 个输出,但在论文中他特别提到了仅一个输出单元。

谢谢

最佳答案

看来我需要forget_bias > 1.0。对于长序列,网络无法使用默认的forget_bias T=50,例如我需要forget_bias = 2.1

关于python - 用于噪声序列的 Tensorflow LSTM,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35264563/

相关文章:

python - 有没有办法检测用户移动的 pygame 显示窗口?

python - BasicRNNCell 偏差没有得到训练

matlab - 绘制两个多元高斯的决策边界

tensorflow - ai平台云预测不起作用,但本地预测起作用

python - Keras/Tensorflow : ValueError: Shape (? ,12) 等级必须为 1

python-3.x - 在 Tensorboard 中使用 Tensorflow v2.0 显示图形

python - GSDMM 聚类的收敛(短文本聚类)

python - 使用 python 请求登录 Facebook

python - Django内存数据库模型创建失败

python - Tensorflow 仅针对变量的某些元素进行最小化