我需要编写一段代码来了解张量的任何条目是否具有特定值“2”。
这是我用于测试的代码:
sess = tf.Session()
some_values = tf.constant([1,2,3,4], dtype=tf.int32)
values_equal_two = (some_values == 2 )
print(sess.run(values_equal_two))
这是我收到的错误:
TypeError: Fetch argument False has invalid type <class 'bool'>, must be a
string or Tensor. (Can not convert a bool into a Tensor or Operation.)
令人惊讶的是,如果我将 == 运算符更改为 >=,如下所示:
sess = tf.Session()
some_values = tf.constant([1, 2, 3, 4], dtype=tf.int32)
values_equal_two = (some_values >= 2)
print(sess.run(values_equal_two))
它工作正常,并返回:
[False True True True]
我想知道问题可能是什么,或者是否可以用另一种方式完成相同的任务。预先感谢您的任何建议。
最佳答案
>=
运算符按预期工作而 ==
不按预期工作的原因是,__ge__
python 方法已经在 TensorFlow Python API 中重载,而 __eq__
尚未重载(查看此 answer )。
如果你想检查是否相等,可以使用 tf.equal
支持广播:
sess = tf.Session()
some_values = tf.constant([1, 2, 3, 4], dtype=tf.int32)
values_equal_two = tf.equal(some_values, 2)
print(sess.run(values_equal_two))
打印[False True False False]
。
关于python - 张量比较未按预期进行,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45660585/