python - 在 tensorflow 中重新分配非变量张量

标签 python tensorflow

我有一个要求,我想使用 x 的更新值作为 RNN 的输入。下面的代码片段可能会详细说明您。

x = tf.placeholder("float", shape=[None,1])
RNNcell = tf.nn.rnn_cell.BasicRNNCell(....)
outputs, _ = tf.dynamic_rnn(RNNCell, tf.reshape(x, [1,-1,1]))
x = outputs[-1] * (tf.Varaibles(...) * tf.Constants(...)) 

最佳答案

@Vlad 的答案是正确的,但由于我是新成员,所以无法投票。下面的代码片段是带有 RNN 单元的 Vlads 的更新版本。

x = tf.placeholder("float", shape=[None,1])
model = tf.nn.rnn_cell.BasicRNNCell(num_units=1, activation=None)

outputs, state = tf.nn.dynamic_rnn(model, tf.reshape(x, [-1,1, 1]), dtype=tf.float32)

# output1 = model.output 

# output1 = outputs[-1]
output1 = outputs[:,-1,:]
# output1 = outputs

some_value = tf.constant([9.0],     # <-- Some tensor the output will be multiplied by
                         dtype=tf.float32)
output1 *= some_value                # <-- The output had been multiplied by `some_value`
                                     #     (with broadcasting in case of
                                     #     more than one input samples)

with tf.control_dependencies([output1]): # <-- Not necessary, but explicit control
    output2, state2 = model(output1,state)  

关于python - 在 tensorflow 中重新分配非变量张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55700024/

相关文章:

python - Pyspark:在驱动程序和工作人员上使用 ffmpeg

python docopt : "expected string or buffer"

python - 如何删除二维数组中具有特定值的行并修改列值?

python - 如何查询 arXiv 特定年份?

python - 将 Tensorflow 预测导出到 csv,但结果包含全零 - 这是因为 one-hot 结局吗?

python - 如何打印 jupyter 中的所有可用内存

tensorflow - 相当于 PyTorch 中的 tf.linalg.diag_part

tensorflow - 从 RNN 单元获取权重和偏差值

python - InvalidArgumentError : tensorflow logits and labels still not broadcastable

python - split() 和 tokenize() 之间的区别