python - 用数组索引 torch 张量

标签 python indexing pytorch tensor torch

我有以下火炬张量:

tensor([[-0.2,  0.3],
    [-0.5,  0.1],
    [-0.4,  0.2]])

以及以下 numpy 数组:(如有必要,我可以将其转换为其他内容)
[1 0 1]

我想获得以下张量:
tensor([0.3, -0.5, 0.2])

即我希望 numpy 数组索引张量的每个子元素。最好不使用循环。

提前致谢

最佳答案

您可能想使用 torch.gather - “沿由dim 指定的轴收集值。”

t = torch.tensor([[-0.2,  0.3],
    [-0.5,  0.1],
    [-0.4,  0.2]])
idxs = np.array([1,0,1])

idxs = torch.from_numpy(idxs).long().unsqueeze(1)  
# or   torch.from_numpy(idxs).long().view(-1,1)

t.gather(1, idxs)
tensor([[ 0.3000],
        [-0.5000],
        [ 0.2000]])

在这里,您的索引是 numpy 数组,因此您必须将其转换为 LongTensor。

关于python - 用数组索引 torch 张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61311688/

相关文章:

python - 在python中分隔条形图组

java - 通过基于索引检索 Hashmap 元素

sql - Oracle 优化器尴尬地不喜欢使用索引

machine-learning - 计算卷积层的输出大小

python - Pip 在防火墙后不起作用

python - 如何在 django-celery 中使用 .delay() 方法?

python - 使用 selenium 和 python 从 Iframe 获取文本

python - Pandas:删除带有某些日期的字符串

lstm - torch.nn.LSTM 运行时错误

parallel-processing - 如果循环中涉及的所有张量都在 GPU 上,我的 for 循环是否并行运行?