我想知道是否有办法绕过 tf.where() 输出的数组形状。例如,这是我尝试运行的代码:
import tensorflow as tf
with tf.Session() as sess:
a = tf.constant([[[1,0],[0,2]],[[0,3],[4,0]]]).eval()
c = tf.where(tf.equal(a,0)).eval()
c = tf.multiply(100,c).eval()
c = tf.add(a,c)
print(c.eval())
我期望的输出是:
[[[ 1,100]
[100, 2]
[[100, 3]
[ 4,100]]]
但是,由于 tf.where() 输出我的代码的方式是 4x3 张量而不是 2x2x2,因此出现错误。是否有另一个命令集可以用来有效地将所有零替换为 100?此方法适用于二维数组。
最佳答案
您可以创建一个 100 的张量,然后在 where
调用中使用它。根据documentation ,这个函数需要两个可选的张量来读取值。
with tf.Session() as sess:
a = tf.constant([[[1,0],[0,2]],[[0,3],[4,0]]])
h = tf.multiply(tf.ones(a.shape, tf.int32), 100)
c = tf.where(tf.equal(a, 0), h, a)
print(c.eval())
关于python - tf.where() 形状输出的解决方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45045150/