python - 在 tensorflow 中使用 bool 张量选择张量的子集

标签 python python-3.x tensorflow

我有两个 2 阶张量 arr1arr2,形状为 m 乘以 n。张量 arr2 是 bool 值;它的每一行中恰好有一个条目是 True。我想提取一个长度为 m 的新 rank-1 张量 arr3,其中 arr3 的第 i 个条目等于 arr1 的第 i 行的条目对应于 arr2 的第 i 行是等于 True

numpy 中,我可以这样做:

arr1 = np.array([[1,2],
                 [3,4]])
arr2 = np.array([[0,1],
                 [1,0]], dtype="bool")
arr3 = arr1[arr2]

我可以在 tensorflow 中做类似的事情吗?我知道我可以 eval() 我的张量,然后使用 numpy 函数,但这似乎效率不高。

question建议使用 tf.gathertf.select,但它不像我的问题那样处理折叠输出的维度。

最佳答案

您可以使用 tf.boolean_mask 吗?

from __future__ import print_function

import tensorflow as tf

with tf.Session() as sess:
    arr1 = tf.constant([[1,2],
                        [3,4]])
    arr2 = tf.constant([[False, True],
                        [True, False]])
    print(sess.run(tf.boolean_mask(arr1, arr2)))

应该给出:[2, 3]

关于python - 在 tensorflow 中使用 bool 张量选择张量的子集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42311773/

相关文章:

python - Django "./manage.py bower install"告诉我 bower 在安装时没有安装

python - Heroku 'DATABASES' 未定义

python - 了解 Sagemaker 对象检测预测的输出

python - 附加到文件无法正常工作

python - 终端输出中的可点击链接

python - Tensorflow 对象检测中的输入张量与 Python 签名不兼容

python - 具有高级 API 的 tensorflow 中的 L2 正则化

python - 子流程变量

python - 获取时间段的值

python - tf.reduce_max 比较不一致