machine-learning - tensorflow 中 BasicRNNCell 的 input_size 是多少?

标签 machine-learning tensorflow neural-network deep-learning recurrent-neural-network

根据BasicRNNCell的文档:

__call__(
    inputs,
    state,
    scope=None)

Args:
inputs: 2-D tensor with shape [batch_size x input_size].

似乎input_size在不同的运行中可能不同?据我对RNN的了解,input_size决定了形状(input_size,hidden_​​state_size)的内部权重矩阵W_x,并且它应该是一致的。如果我交替使用 input_size=3input_size=4 运行此单元会怎样?

最佳答案

inputs 是一个二维张量:[batch_size x input_size]

你是对的,input_size必须对应于RNN单元的num_units。但是 batch_size 可以变化,并且只需与调用的另一个参数 state 相对应。

试试这个代码:

import tensorflow as tf
from tensorflow.contrib.rnn import BasicRNNCell

dim = 10
x = tf.placeholder(tf.float32, shape=[None, dim])
y = tf.placeholder(tf.float32, shape=[4, dim])
z = tf.placeholder(tf.float32, shape=[None, dim + 1])
print('x, y, z:', x.shape, y.shape, z.shape)

cell = BasicRNNCell(dim)
state1 = cell.zero_state(batch_size=4, dtype=tf.float32)
state2 = cell.zero_state(batch_size=8, dtype=tf.float32)

out1, out2 = cell(x, state1)
print(out1.shape, out2.shape)

out1, out2 = cell(x, state2)
print(out1.shape, out2.shape)

out1, out2 = cell(y, state1)
print(out1.shape, out2.shape)

这是输出:

x, y, z: (?, 10) (4, 10) (?, 11)
(4, 10) (4, 10)
(8, 10) (8, 10)
(4, 10) (4, 10)

此单元格接受具有两种状态的 x,以及具有 state1y,并且不接受具有任何状态的 z状态。以下两个调用都会导致错误:

out1, out2 = cell(y, state2)     # ERROR: dimensions mismatch
print(out1.shape, out2.shape)

out1, out2 = cell(z, state1)     # ERROR: dimensions mismatch
print(out1.shape, out2.shape)

关于machine-learning - tensorflow 中 BasicRNNCell 的 input_size 是多少?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47034317/

相关文章:

tensorflow - 值错误 : Operation u'tpu_140462710602256/VarIsInitializedOp' has been marked as not fetchable

python - 如何用包含训练数据的矩阵替换输入数据

python - 使用我的数据而不是 20 个新闻组进行情绪分析

parsing - 从句子中找到有意义的子句子

python - Tensorflow:动态进行字母预测

machine-learning - 在扩展特征空间中,核 SVM 与线性 SVM 相比有哪些缺点?

python - TensorFlow python循环 "for"性能

带有 HDFS 的 TensorFlow 数据集 API

machine-learning - epsilon 超参数如何影响 tf.train.AdamOptimizer?

python - 使用 alexnet 和 flow from 目录来训练灰度数据集