python - 由另一个多维张量索引多维 torch 张量

标签 python numpy pytorch tensor numpy-slicing

我在 pytorch 中有一个张量 x 比方说形状 (5,3,2,6) 和另一个形状 (5,3,2,1) 的张量 idx,其中包含第一个张量中每个元素的索引。我想用第二个张量的索引对第一个张量进行切片。我试过 x= x[idx] 但是当我真的希望它的形状为 (5,3,2) 或 (5,3,2,1) 时,我得到了一个奇怪的维度。
我会尝试举一个更简单的例子:
让我们说

x=torch.Tensor([[10,20,30],
                 [8,4,43]])
idx = torch.Tensor([[0],
                    [2]])
我想要类似的东西
y = x[idx]
'y' 输出 [[10],[43]]或类似的东西。
索引表示最后一维所需元素的位置。对于上面的示例,其中 x.shape = (2,3) 最后一个维度是列,然后 'idx' 中的索引是列。我想要这个,但超过 2 个维度

最佳答案

从我从评论中了解到,您需要 idx最后一个维度中的索引和 idx 中的每个索引对应于 x 中的类似索引(除了最后一个维度)。在这种情况下(这是 numpy 版本,您可以将其转换为火炬):

ind = np.indices(idx.shape)
ind[-1] = idx
x[tuple(ind)]
输出:
[[10]
 [43]]

关于python - 由另一个多维张量索引多维 torch 张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62762769/

相关文章:

python - 在 SQLAlchemy 中合并多个声明性基础

python - 如何在 Python 中查询输入而不输出新行(续)

python - 从 Pandas 数据帧计算 RSI

python - 如何根据模糊条件从Numpy数组中选择值?

python - PyTorch 的 `no_grad` 函数在 TensorFlow/Keras 中的等价物是什么?

python 程序不打印所有 pickle 文件数据

python - Theano/numpy 高级索引

python - numpy 有条件地用数组替换标量/ bool 值

machine-learning - 使用 PyTorch 预测网格坐标序列

python - 创建变压器无峰值掩码时如何修复 numpy 中的 "TypeError: data type not understood"