python - tensorflow:LSTM 单元中变量的初始化程序

标签 python tensorflow neural-network lstm recurrent-neural-network

我正在尝试构建一个 RNN 来预测输入数据的情绪是积极还是消极。

tf.reset_default_graph()

input_data = tf.placeholder(tf.int32, [batch_size, 40])
labels = tf.placeholder(tf.int32, [batch_size, 40])

data = tf.Variable(tf.zeros([batch_size, 40, 50]), dtype=tf.float32)
data = tf.nn.embedding_lookup(glove_embeddings_arr, input_data)

lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_units)
lstm_cell = tf.contrib.rnn.DropoutWrapper(cell = lstm_cell, output_keep_prob = 0.75)
value,state = tf.nn.dynamic_rnn(lstm_cell, data, dtype=tf.float32)

weight = tf.Variable(tf.truncated_normal([lstm_units, classes]))
bias = tf.Variable(tf.constant(0.1, shape = [classes]))
value = tf.transpose(value, [1,0,2])
last = tf.gather(value, int(value.get_shape()[0]) - 1)
prediction = (tf.matmul(last, weight) + bias)



true_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(labels,1))
accuracy = tf.reduce_mean(tf.cast(true_pred,tf.float32))

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=labels))
optimizer = tf.train.AdamOptimizer().minimize(loss)

解释器返回

ValueError: An initializer for variable rnn/basic_lstm_cell/kernel of <dtype: 'string'> is required

有人可以向我解释这个错误吗?

最佳答案

问题在于您(很可能)正在向网络提供原始输入文本。这不在您的代码片段中,但错误指示 <dtype: 'string'> :

ValueError: An initializer for variable rnn/basic_lstm_cell/kernel of <dtype: 'string'> is required

该类型是根据 LSTM 单元获取的输入推导出来的。内部 LSTM 变量( kernelbias )使用默认初始值设定项进行初始化,该初始值设定项(至少现在)只能处理 floating and integer types ,但对于其他类型则失败。在您的情况下,类型是 tf.string ,这就是您看到此错误的原因。

现在,您应该做的是将输入句子转换为真实向量。最好的方法是通过 word embedding ,例如word2vec ,但简单的单词索引也是可能的。看看this post ,特别是它们如何处理文本数据。还有一个完整的工作代码示例。

关于python - tensorflow:LSTM 单元中变量的初始化程序,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46499981/

相关文章:

python - 嵌套 for 循环 - 是否可以每个执行一次然后跳转到下一个对象?

python - 在转换后的 tflite 模型上调用 `allocate_tensors()` 时出现运行时错误

c - FANN 神经网络 - 恒定结果

recursion - 递归函数的网络模拟是什么?

python - Django 测试时未找到 Postgres 函数 json_array_elements

python - 每 20 分钟左右列出一次超出范围的索引错误

python - 在 MultiIndex 中设置级别值

tensorflow - 与 keras 调谐器相关的超参数

python - Windows 10、RTX 2070] : Failed to get convolution algorithm

java - 识别文本文件中的 "figure"模式