python - 如何使用 TensorFlow 中的指定索引访问 3D 张量的元素?

标签 python tensorflow tensorflow2.0

我正在尝试以特定的索引顺序获取 3D 张量的行。以下是输入:

import tensorflow as tf

matrix = tf.constant([
    [[0, 1], [2, 3], [4, 5], [6, 7]], 
    [[8, 9], [10, 11], [12, 13], [14, 15]], 
    [[16, 17], [18, 19], [20, 21], [22, 23]], 
    [[24, 25], [26, 27], [28, 29], [30, 31]], 
    [[32, 33], [34, 35], [36, 37], [38, 39]]
])

indx = tf.constant([[3,2,1,0], [0,1,2,3], [1,0,3,2], [0,3,1,2], [1,2,3,0]])

# required output tensor:
[[[6, 7], [4, 5], [2, 3], [0, 1]],
 [[8, 9], [10, 11], [12, 13], [14, 15]],
 [[18, 19], [16, 17], [22, 23], [20, 21]],
 [[24, 25], [30, 31], [26, 27], [28, 29]],
 [[34, 35], [36, 37], [38, 39], [32, 33]]]

我正在为 tf.gather_nd() 苦苦挣扎.有什么建议吗?我可以看到它发生在这里,但我不确定 如何在不使用的情况下应用于整个矩阵 for循环或 tf.map_fn
print(tf.gather_nd(matrix[0], tf.expand_dims(indx, -1)[0]).numpy().tolist())
print(tf.gather_nd(matrix[1], tf.expand_dims(indx, -1)[1]).numpy().tolist())
print(tf.gather_nd(matrix[2], tf.expand_dims(indx, -1)[2]).numpy().tolist())
print(tf.gather_nd(matrix[3], tf.expand_dims(indx, -1)[3]).numpy().tolist())
print(tf.gather_nd(matrix[4], tf.expand_dims(indx, -1)[4]).numpy().tolist())

"""
[[6, 7], [4, 5], [2, 3], [0, 1]]
[[8, 9], [10, 11], [12, 13], [14, 15]]
[[18, 19], [16, 17], [22, 23], [20, 21]]
[[24, 25], [30, 31], [26, 27], [28, 29]]
[[34, 35], [36, 37], [38, 39], [32, 33]]
"""
编辑:我问了一个关于 numpy 的类似问题。一个聪明的索引答案确实解决了 numpy 版本,但很难将它应用于张量。请随意查看此处接受的答案:How can I get elements from 3D matrix using specified indices in numpy?

最佳答案

呃,那是愚蠢的!在 tensorflow 中已经有一个非常棒的函数可以处理多维数组; tf.gather() 查看 batch_dims更多信息的论据。

>> tf.gather(matrix, indx, batch_dims=1).numpy().tolist()
[[[6, 7], [4, 5], [2, 3], [0, 1]],
 [[8, 9], [10, 11], [12, 13], [14, 15]],
 [[18, 19], [16, 17], [22, 23], [20, 21]],
 [[24, 25], [30, 31], [26, 27], [28, 29]],
 [[34, 35], [36, 37], [38, 39], [32, 33]]]

关于python - 如何使用 TensorFlow 中的指定索引访问 3D 张量的元素?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63431408/

相关文章:

python - 从 list_files 创建 tensorflow 数据集

tensorflow - Model.fit() 是否将整个训练数据集上传到 GPU?

python - 未知层 : KerasLayer when i try to load_model

tensorflow - 如何从 TF Hub 获取 Bert tokenizer 的 vocab 文件

python - qtsql:查询将格式错误的 UTF-8 文本插入 MySQL

python - 如何打印动态变量?

python - 使用 FLask 和 matplotlib 即时生成图像

python - Tensorflow 中密集层的偏差可以设置为零吗?

Tensorflow:将allow_growth设置为true仍然会分配我所有GPU的内存

python - 如何设置随机搜索中使用的指数分布的界限?