tensorflow - Trying to implement a recurrent network with tf.scan()

Tags: tensorflow

I'm trying to implement a recurrent state tensor using tf.scan. My current code looks like this:

import tensorflow as tf
import math
import numpy as np

INPUTS = 10
HIDDEN_1 = 20
BATCH_SIZE = 3


def iterate_state(prev_state_tuple, input):
    with tf.name_scope('h1'):
        weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
        biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
        matmuladd = tf.matmul(inputs, weights) + biases
        unpacked_state, unpacked_out = tf.split(0,2,prev_state_tuple)
        prev_state = unpacked_state
        state = 0.9* prev_state + 0.1*matmuladd
        output = tf.nn.relu(state)
        return tf.concat(0,[state, output])

def data_iter():
    while True:
        idxs = np.random.rand(BATCH_SIZE, INPUTS)
        yield idxs

with tf.Graph().as_default():
    inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS))
    with tf.variable_scope('states'):
        initial_state = tf.zeros([HIDDEN_1],
                                 name='initial_state')
        initial_out = tf.zeros([HIDDEN_1],
                                 name='initial_out')
        concat_tensor = tf.concat(0,[initial_state, initial_out])
        states, output = tf.scan(iterate_state, inputs,
                                     initializer=concat_tensor, name='states')

    sess = tf.Session()
    # Run the Op to initialize the variables.
    sess.run(tf.initialize_all_variables())
    iter_ = data_iter()
    for i in xrange(0, 2):
        print ("iteration: ",i)
        input_data = iter_.next()
        out,st = sess.run([output,states], feed_dict={ inputs: input_data})

However, when I run it I get this error:

Traceback (most recent call last):
  File "cycles_in_graphs_with_scan.py", line 37, in <module>
    initializer=concat_tensor, name='states')
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 442, in __iter__
    raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.

I've tried pack/unpack and concat/split, but I get the same error.

Any ideas how to solve this?

Best answer

You get the error because tf.scan() returns a single tf.Tensor, so the line:

states, output = tf.scan(...)

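...cannot destructure (unpack) the tensor returned by tf.scan() into two values (states and output). In effect, the code tries to treat the result of tf.scan() as a list of length 2, assigning the first element to states and the second to output, but unlike a Python list or tuple, a tf.Tensor does not support this.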
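A related point, sketched here with NumPy because NumPy arrays (unlike the tf.Tensor of that TensorFlow version) are iterable: even when a result object can be unpacked, tuple unpacking walks the leading axis, which for a scan result is the time axis, not the packed state/output halves. The shapes below (3 steps, vectors of length 2 * 20) are illustrative assumptions chosen to mirror the question.

```python
import numpy as np

# Stand-in for a stacked scan result: 3 time steps, each one a packed
# [state, output] vector of length 2 * 20 (illustrative shapes only).
scan_result = np.zeros((3, 40))

# Tuple unpacking iterates over the LEADING axis (time steps here),
# so asking for exactly two names fails when there are three rows.
try:
    states, output = scan_result
except ValueError as err:
    unpack_error = str(err)

# Even when the leading axis happens to have length 2, unpacking hands
# back whole time steps, not the state/output halves of each vector.
two_steps = np.arange(2 * 40).reshape(2, 40)
first, second = two_steps  # first is time step 0, shape (40,)
```

So the fix is not to make the result iterable but to split it explicitly along the axis where state and output were packed.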

Instead, you need to extract the values from the tf.scan() result manually, for example with tf.split():

scan_result = tf.scan(...)
# Assumes values are packed together along `split_dim`.
states, output = tf.split(split_dim, 2, scan_result)

Alternatively, you can use tf.slice() or tf.unpack() to extract the relevant state and output values.
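The idea can be sketched without TensorFlow. The following is a NumPy emulation of tf.scan's basic contract (an assumption made so the example runs anywhere; it is not TensorFlow's implementation): the step function is applied along axis 0 of the inputs and every returned accumulator is stacked into one array whose axis 0 is the scan axis. Because each step packs [state, output] along its own axis 0, in the stacked result the two halves sit along axis 1, so under this packing the `split_dim` in the tf.split() snippet would be 1. The helpers `scan_like` and `step` are hypothetical names used only for illustration.

```python
import numpy as np

INPUTS, HIDDEN_1, STEPS = 10, 20, 3  # sizes mirror the question's constants

def scan_like(fn, elems, initializer):
    """Hypothetical NumPy stand-in for tf.scan: apply fn(acc, x) to each
    element along axis 0 and stack every intermediate accumulator into ONE
    array (not a tuple)."""
    acc = initializer
    results = []
    for x in elems:
        acc = fn(acc, x)
        results.append(acc)
    return np.stack(results)

rng = np.random.default_rng(0)
W = rng.standard_normal((INPUTS, HIDDEN_1)) / np.sqrt(INPUTS)
b = np.zeros(HIDDEN_1)

def step(prev, x):
    # Hypothetical analogue of iterate_state: prev packs [state, output]
    # along axis 0, so prev.shape == (2 * HIDDEN_1,) and x.shape == (INPUTS,).
    prev_state = prev[:HIDDEN_1]
    state = 0.9 * prev_state + 0.1 * (x @ W + b)
    output = np.maximum(state, 0.0)  # ReLU
    return np.concatenate([state, output])

elems = rng.random((STEPS, INPUTS))
init = np.zeros(2 * HIDDEN_1)
result = scan_like(step, elems, init)  # one array of shape (STEPS, 2 * HIDDEN_1)

# Analogue of tf.split(1, 2, result): split along axis 1, not axis 0.
states, output = np.split(result, 2, axis=1)  # each (STEPS, HIDDEN_1)

# Analogue of the tf.slice() route: take the two halves of axis 1 directly.
states_sliced = result[:, :HIDDEN_1]
output_sliced = result[:, HIDDEN_1:]
```

Splitting along axis 0 instead would divide the time steps, which is why the axis has to match how each step packed its return value.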

This question, "tensorflow - Trying to implement a recurrent network with tf.scan()", is from Stack Overflow: https://stackoverflow.com/questions/37599327/
