python - TensorFlow: ValueError: The two structures don't have the same number of elements

Tags: python machine-learning tensorflow nlp lstm

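I am currently implementing an encoder LSTM with raw_rnn. This question is also related to another question I asked earlier (Tensorflow raw_rnn retrieve tensor of shape BATCH x DIM from embedding matrix). When I run the following code, I get this error: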

ValueError: The two structures don't have the same number of elements.

First structure (1 elements): None

Second structure (2 elements): LSTMStateTuple(c=64, h=64)

The error occurs on this line: encoder_outputs_ta, encoder_final_state, _ = tf.nn.raw_rnn(cell, loop_fn=reader_loop)
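The message comes from a structure comparison: roughly, both structures are flattened and their numbers of leaves are compared, where None counts as a single leaf and an LSTMStateTuple has two (c and h). The following is a simplified pure-Python sketch of that kind of check; flatten, assert_same_element_count and the LSTMStateTuple namedtuple are hypothetical stand-ins, not TensorFlow's implementation.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.contrib.rnn.LSTMStateTuple
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def flatten(structure):
    """Flatten nested tuples/lists into a list of leaves; None counts as one leaf."""
    if isinstance(structure, (tuple, list)):
        leaves = []
        for item in structure:
            leaves.extend(flatten(item))
        return leaves
    return [structure]


def assert_same_element_count(first, second):
    """Raise ValueError if the two structures have different numbers of leaves."""
    n1, n2 = len(flatten(first)), len(flatten(second))
    if n1 != n2:
        raise ValueError(
            "The two structures don't have the same number of elements.\n\n"
            "First structure (%d elements): %r\n\n"
            "Second structure (%d elements): %r" % (n1, first, n2, second))


# Comparing a None state with the cell's two-part state reproduces the message
try:
    assert_same_element_count(None, LSTMStateTuple(c=64, h=64))
except ValueError as err:
    print(err)
```

Returning a real LSTMStateTuple as the state (two leaves on both sides) passes this check.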

import tensorflow as tf
import numpy as np

batch_size, max_time, input_embedding_size = 5, 10, 16
vocab_size, num_units = 50, 64

encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
encoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='encoder_inputs_length')

embeddings = tf.Variable(tf.random_uniform([vocab_size + 2, input_embedding_size], -1.0, 1.0),
                         dtype=tf.float32, name='embeddings')
encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)

cell = tf.contrib.rnn.LSTMCell(num_units)
W = tf.Variable(tf.random_uniform([num_units, vocab_size], -1, 1), dtype=tf.float32, name='W_reader')
b = tf.Variable(tf.zeros([vocab_size]), dtype=tf.float32, name='b_reader')

with tf.variable_scope('ReaderNetwork'):
    def loop_fn_initial():
        init_elements_finished = (0 >= encoder_inputs_length)
        init_input = cell.zero_state(batch_size, dtype=tf.float32)
        init_cell_state = None
        init_cell_output = None
        init_loop_state = None
        return (init_elements_finished, init_input,
                init_cell_state, init_cell_output, init_loop_state)

    def loop_fn_transition(time, previous_output, previous_state, previous_loop_state):
        def get_next_input():
            return tf.ones([batch_size, input_embedding_size], dtype=tf.float32)  # TODO replace with value from embeddings

        elements_finished = (time >= encoder_inputs_length)
        finished = tf.reduce_all(elements_finished)  # boolean scalar
        # Feed zeros once every sequence is finished, otherwise the next input
        next_input = tf.cond(finished,
                             true_fn=lambda: tf.zeros([batch_size, input_embedding_size], dtype=tf.float32),
                             false_fn=get_next_input)
        state = previous_state
        output = previous_output
        loop_state = None
        return elements_finished, next_input, state, output, loop_state

    def loop_fn(time, previous_output, previous_state, previous_loop_state):
        if previous_state is None:  # time = 0
            return loop_fn_initial()
        return loop_fn_transition(time, previous_output, previous_state, previous_loop_state)

reader_loop = loop_fn
encoder_outputs_ta, encoder_final_state, _ = tf.nn.raw_rnn(cell, loop_fn=reader_loop)
outputs = encoder_outputs_ta.stack()

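def next_batch():
    # Token ids must be integers in [0, vocab_size) to match the int32 placeholder
    return {
        encoder_inputs: np.random.randint(0, vocab_size, size=(batch_size, max_time)),
        encoder_inputs_length: [max_time] * batch_size
    }
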
init = tf.global_variables_initializer()
with tf.Session() as s:
    s.run(init)  # initialize embeddings, W and b before running the graph
    outs = s.run([outputs], feed_dict=next_batch())
    print(len(outs), outs[0].shape)
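The TODO in get_next_input asks for the BATCH x DIM slice of the embedding matrix at the current time step (the subject of the linked question). With batch-major token ids, as fed by next_batch, that is the embedding row of each sequence's token at column time. Below is a pure-Python sketch of that indexing only; embedding_lookup and step_inputs are hypothetical helpers. In the graph, the analogous step would presumably be tf.nn.embedding_lookup(embeddings, encoder_inputs[:, time]), which is an assumption here rather than tested code.

```python
def embedding_lookup(embeddings, ids):
    """Return one embedding row per token id (gathers rows of the matrix)."""
    return [embeddings[i] for i in ids]


def step_inputs(embeddings, encoder_inputs, time):
    """Embeddings of the tokens at column `time`: a BATCH x DIM result.

    encoder_inputs is batch-major: one list of token ids per sequence.
    """
    ids_at_time = [row[time] for row in encoder_inputs]  # shape [batch]
    return embedding_lookup(embeddings, ids_at_time)


embeddings = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]  # vocab 4, dim 2
encoder_inputs = [[1, 2, 3], [3, 0, 1]]  # batch 2, time 3
print(step_inputs(embeddings, encoder_inputs, 1))  # [[2.0, 2.0], [0.0, 0.0]]
```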



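Answer:

In loop_fn_initial, init_input is fed to the cell as its first input, so it must have the input shape [batch_size, input_embedding_size] instead of being the cell's zero state:

init_input = tf.zeros([batch_size, input_embedding_size], dtype=tf.float32)

init_cell_state must not be None: when loop_fn is called at time 0, the state it returns is used as the initial state of the RNN, so it has to match the cell's LSTMStateTuple structure. Use the cell's zero state:

init_cell_state = cell.zero_state(batch_size, tf.float32)

The corrected loop_fn_initial:
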
def loop_fn_initial():
    init_elements_finished = (0 >= encoder_inputs_length)
    init_input = tf.zeros([batch_size, input_embedding_size], dtype=tf.float32)
    init_cell_state = cell.zero_state(batch_size, tf.float32)
    init_cell_output = None
    init_loop_state = None
    return (init_elements_finished, init_input,
            init_cell_state, init_cell_output, init_loop_state)
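To see why the initial state cannot be None, here is a minimal pure-Python sketch of the loop_fn contract that raw_rnn follows: loop_fn is first called at time 0 with cell_output, cell_state and loop_state all set to None, and the state it returns becomes the initial state. toy_cell, toy_raw_rnn and the LSTMStateTuple namedtuple are hypothetical stand-ins, not TensorFlow code: the cell is not a real LSTM, and unlike raw_rnn the driver neither masks finished batch elements nor raises TensorFlow's exact error.

```python
import math
from collections import namedtuple

# Hypothetical stand-in for tf.contrib.rnn.LSTMStateTuple
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def toy_cell(inputs, state):
    """Toy recurrent cell (NOT the LSTM equations): c += sum(x), h = tanh(c)."""
    c = [c_b + sum(x_b) for c_b, x_b in zip(state.c, inputs)]
    h = [math.tanh(c_b) for c_b in c]
    return h, LSTMStateTuple(c, h)


def toy_raw_rnn(cell, loop_fn):
    """Simplified raw_rnn-style driver (no masking of finished elements)."""
    time = 0
    # Time 0: loop_fn sees None for output, state and loop state, and the
    # state it returns is used as the initial state.
    finished, next_input, state, _, loop_state = loop_fn(time, None, None, None)
    if not isinstance(state, LSTMStateTuple):
        raise ValueError("initial state must be an LSTMStateTuple, got %r" % (state,))
    outputs = []
    while not all(finished):
        output, state = cell(next_input, state)
        outputs.append(output)
        time += 1
        finished, next_input, state, _, loop_state = loop_fn(time, output, state, loop_state)
    return outputs, state, loop_state


batch_size, input_dim = 2, 3
lengths = [2, 3]  # hypothetical per-sequence lengths


def loop_fn(time, cell_output, cell_state, loop_state):
    finished = [time >= n for n in lengths]
    # The input has the input dimension (like input_embedding_size), not the state size
    next_input = [[1.0] * input_dim for _ in range(batch_size)]
    if cell_state is None:  # time == 0: return a real initial state, not None
        zeros = [0.0] * batch_size
        return finished, next_input, LSTMStateTuple(c=zeros, h=list(zeros)), None, None
    return finished, next_input, cell_state, cell_output, loop_state


outputs, final_state, _ = toy_raw_rnn(toy_cell, loop_fn)
print(len(outputs), final_state.c)  # 3 [9.0, 9.0]
```

The loop runs until the longest sequence (length 3) is finished; returning None as the state at time 0 would instead make toy_raw_rnn raise a ValueError, mirroring the structure mismatch in the question.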
