python - tf.where() 形状输出的解决方法

标签 python tensorflow

我想知道是否有办法绕过 tf.where() 输出的数组形状。例如,这是我尝试运行的代码:

import tensorflow as tf

with tf.Session() as sess:
    a = tf.constant([[[1,0],[0,2]],[[0,3],[4,0]]]).eval()
    c = tf.where(tf.equal(a,0)).eval()
    c = tf.multiply(100,c).eval()
    c = tf.add(a,c)

    print(c.eval())

我期望的输出是:

[[[  1,100]
  [100,  2]

 [[100,  3]
  [  4,100]]]

但是,由于 tf.where() 输出我的代码的方式是 4x3 张量而不是 2x2x2,因此出现错误。是否有另一个命令集可以用来有效地将所有零替换为 100?此方法适用于二维数组。

最佳答案

您可以创建一个 100 的张量,然后在 where 调用中使用它。根据documentation ,这个函数需要两个可选的张量来读取值。

with tf.Session() as sess:
    a = tf.constant([[[1,0],[0,2]],[[0,3],[4,0]]])
    h = tf.multiply(tf.ones(a.shape, tf.int32), 100)
    c = tf.where(tf.equal(a, 0), h, a)

    print(c.eval())

关于python - tf.where() 形状输出的解决方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45045150/

相关文章:

python - 如何防止 matplotlib 绘制相似值数组的浮点变化?

Python:识别具有最大值(value)的列

python - 如何在不使用numpy或tensorflow中的任何循环的情况下获取矩阵行?

python - 将 TFRecords 转换回 JPEG 图像

javascript - TensorFlow.js 和复杂的数据集?

tensorflow - TensorFlow 中 getter 的概念

python - 如何在python中返回通过sklearn的函数KernelDensity估计的分布的平均值(或期望值)?

python - .INI 文件中一个标题下的键、值

python - 如何在numpy中沿轴相加

TensorFlow 跨设备通信