According to the documentation for BasicRNNCell:
__call__(
inputs,
state,
scope=None)
Args:
inputs: 2-D tensor with shape [batch_size x input_size].
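It seems that input_size could differ between runs? As far as I understand RNNs, input_size determines the internal weight matrix W_x, which has shape (input_size, hidden_state_size), so it should stay consistent. What happens if I run this cell alternately with input_size=3 and input_size=4?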
Best answer
inputs is a 2-D tensor: [batch_size x input_size].
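You are right that input_size has to stay consistent: together with the cell's num_units it fixes the shape of the weight matrix, so it cannot change from one call to the next. batch_size, however, can vary; it only has to match the other argument of the call, state.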
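The shape constraint can be written out explicitly. This is a sketch in standard vanilla-RNN notation, not taken from the TensorFlow source: W_x is the matrix named in the question, while W_h, b and the subscript t are names assumed here for illustration.

```latex
h_t = \tanh\left( x_t W_x + h_{t-1} W_h + b \right),
\qquad
x_t \in \mathbb{R}^{\text{batch\_size} \times \text{input\_size}},\;
W_x \in \mathbb{R}^{\text{input\_size} \times \text{num\_units}},\;
h_{t-1} \in \mathbb{R}^{\text{batch\_size} \times \text{num\_units}},\;
W_h \in \mathbb{R}^{\text{num\_units} \times \text{num\_units}}
```

batch_size appears only in x_t and h_{t-1}, never in a weight shape, which is why it is free to change as long as inputs and state agree on it.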
Try this code:
import tensorflow as tf
from tensorflow.contrib.rnn import BasicRNNCell
dim = 10
x = tf.placeholder(tf.float32, shape=[None, dim])
y = tf.placeholder(tf.float32, shape=[4, dim])
z = tf.placeholder(tf.float32, shape=[None, dim + 1])
print('x, y, z:', x.shape, y.shape, z.shape)
cell = BasicRNNCell(dim)
state1 = cell.zero_state(batch_size=4, dtype=tf.float32)
state2 = cell.zero_state(batch_size=8, dtype=tf.float32)
out1, out2 = cell(x, state1)
print(out1.shape, out2.shape)
out1, out2 = cell(x, state2)
print(out1.shape, out2.shape)
out1, out2 = cell(y, state1)
print(out1.shape, out2.shape)
Here is the output:
x, y, z: (?, 10) (4, 10) (?, 11)
(4, 10) (4, 10)
(8, 10) (8, 10)
(4, 10) (4, 10)
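The cell accepts x with either state, and y with state1 only (y has a fixed batch size of 4, which does not match state2). It does not accept z with any state. Both of the following calls raise an error: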
out1, out2 = cell(y, state2) # ERROR: dimensions mismatch
print(out1.shape, out2.shape)
out1, out2 = cell(z, state1) # ERROR: dimensions mismatch
print(out1.shape, out2.shape)
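To see why input_size gets locked in while batch_size stays free, here is a minimal pure-Python sketch of a basic RNN step. It is not TensorFlow code: TinyRNNCell, zeros, w_x, w_h and b are hypothetical names, and creating the weights on the first call is an assumption made to mirror the behaviour observed above.

```python
import math
import random


class TinyRNNCell:
    """Hypothetical pure-Python stand-in for a basic RNN cell (not TensorFlow)."""

    def __init__(self, num_units, seed=0):
        self.num_units = num_units
        self.input_size = None  # unknown until the first call
        self._rng = random.Random(seed)

    def _build(self, input_size):
        # Weights are created once; their shapes depend only on
        # input_size and num_units, never on batch_size.
        self.input_size = input_size
        self.w_x = [[self._rng.uniform(-0.1, 0.1) for _ in range(self.num_units)]
                    for _ in range(input_size)]
        self.w_h = [[self._rng.uniform(-0.1, 0.1) for _ in range(self.num_units)]
                    for _ in range(self.num_units)]
        self.b = [0.0] * self.num_units

    def __call__(self, inputs, state):
        # inputs: batch_size rows of length input_size
        # state:  batch_size rows of length num_units
        if len(inputs) != len(state):
            raise ValueError("batch_size of inputs and state differ")
        input_size = len(inputs[0])
        if self.input_size is None:
            self._build(input_size)
        elif input_size != self.input_size:
            raise ValueError("input_size changed after the weights were created")
        new_state = []
        for x_row, h_row in zip(inputs, state):
            row = []
            for j in range(self.num_units):
                total = self.b[j]
                total += sum(x_row[i] * self.w_x[i][j] for i in range(input_size))
                total += sum(h_row[k] * self.w_h[k][j] for k in range(self.num_units))
                row.append(math.tanh(total))
            new_state.append(row)
        # A basic RNN cell returns its new state as the output as well.
        return new_state, new_state


def zeros(rows, cols):
    """Build a rows x cols nested list filled with 0.0."""
    return [[0.0] * cols for _ in range(rows)]


cell = TinyRNNCell(num_units=10)
out, _ = cell(zeros(4, 10), zeros(4, 10))  # batch_size 4 fixes input_size to 10
print(len(out), len(out[0]))
out, _ = cell(zeros(8, 10), zeros(8, 10))  # batch_size 8 reuses the same weights
print(len(out), len(out[0]))
```

Calling this cell afterwards with rows of length 11, or with inputs and state of different batch sizes, raises ValueError, which corresponds to the two failing calls above.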
Source: machine-learning - What is the input_size of BasicRNNCell in tensorflow?, a similar question on Stack Overflow: https://stackoverflow.com/questions/47034317/