python - 在二维 tf.Variable 中使用 tf.scatter_update

标签 python matrix tensorflow

我正在关注这个 Manipulating matrix elements in tensorflow .使用 tf.scatter_update。但我的问题是: 如果我的 tf.Variable 是二维的会怎样?比方说:

a = tf.Variable(initial_value=[[0, 0, 0, 0],[0, 0, 0, 0]])

例如,我如何更新每行的第一个元素并为其分配值 1?

我试过类似的东西

for line in range(2):
    sess.run(tf.scatter_update(a[line],[0],[1]))

但它失败了(我预料到了)并给我错误:

TypeError: Input 'ref' of 'ScatterUpdate' Op requires l-value input

我该如何解决这类问题?

`

最佳答案

在 tensorflow 中,您不能更新张量,但可以更新变量。

scatter_update 运算符只能更新变量的第一个维度。 您必须始终将引用张量传递给散点更新(a 而不是 a[line])。

这是更新变量第一个元素的方法:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    a = tf.Variable(initial_value=[[0, 0, 0, 0],[0, 0, 0, 0]])
    b = tf.scatter_update(a, [0, 1], [[1, 0, 0, 0], [1, 0, 0, 0]])

with tf.Session(graph=g) as sess:
   sess.run(tf.initialize_all_variables())
   print sess.run(a)
   print sess.run(b)

输出:

[[0 0 0 0]
 [0 0 0 0]]
[[1 0 0 0]
 [1 0 0 0]]

但是必须再次更改整个张量,分配一个全新的张量可能会更快。

关于python - 在二维 tf.Variable 中使用 tf.scatter_update,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40320585/

相关文章:

python - pymc3 : Multiple observed values

java - 我的类在 jar 中找不到 python 文件

arrays - 如何在Octave/MATLAB中找到向量每两个对应元素之间的范围?

c++ - 使用 Rcpp 查找重复项

python - tensorflow 2.0 自定义训练循环的学习率

python - 无法在 tensorflow 中加载音频文件(Windows10)

python - 使用 keras 模型中的 tensorflow 图进行预测

python - 创建+读取+追加+二进制的文件模式

python - MAC 上的 AWS CLI 指向错误的 Python 版本

python - 从文本文件添加两个矩阵(Python 无模块)