python - Pytorch:具有沿多个轴的张量的索引或一次分散到多个索引

标签 python numpy indexing pytorch

我正在尝试更新 Pytorch 中多维张量的非常具体的索引,但我不确定如何访问正确的索引。我可以在 Numpy 中以一种非常直接的方式做到这一点:

import numpy as np
#set up the array containing the data
data = 100*np.ones((10,10,2))
data[5:,:,:] = 0
#select the data points that I want to update
idxs = np.nonzero(data.sum(2))
#generate the updates that I am going to do
updates = np.random.randint(5,size=(idxs[0].shape[0],2))
#update the data
data[idxs[0],idxs[1],:] = updates

我需要在 Pytorch 中实现它,但我不确定该怎么做。似乎我需要 scatter 函数,但它只适用于单个维度,而不是我需要的多个维度。我该怎么做?

最佳答案

除了 torch.nonzero 之外,这些操作在 PyTorch 对应项中的工作方式完全相同,默认情况下返回大小为 [z, n] 的张量(其中 z 是非零元素的数量,n 是数量维度)而不是大小为 [z]n 张量元组(就像 NumPy 所做的那样),但是可以通过设置 as_tuple=True< 来改变这种行为.

除此之外,您可以直接将其转换为 PyTorch,但您需要确保类型匹配,因为您不能分配类型为 torch.long 的张量(默认为 torch .randint) 到 torch.float 类型的张量(默认为 torch.ones)。在这种情况下,data 可能是 torch.long 类型:

#set up the array containing the data
data = 100*torch.ones((10,10,2), dtype=torch.long)
data[5:,:,:] = 0
#select the data points that I want to update
idxs = torch.nonzero(data.sum(2), as_tuple=True)
#generate the updates that I am going to do
updates = torch.randint(5,size=(idxs[0].shape[0],2))
#update the data
data[idxs[0],idxs[1],:] = updates

关于python - Pytorch:具有沿多个轴的张量的索引或一次分散到多个索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62207512/

相关文章:

python - houghcircles - 寻找参数

python - 命令行参数 - 设置运行 Python 代码时的超时限制

python - 花式索引中的多个条件

c++ - 负数或超大 STL 双端队列?

python - 10.3 % 2.5 打印 0.3 或 0.3000000000000007

python - Robot Framework - 访客界面 - 如何获取关键字的关键字子项?

python - 如何将 numpy 数组存储为 tfrecord?

python - 如何在 matlab 和 python/numpy 之间交换多维数组?

oracle - 在空表中执行缓慢的查询。 (删除大量插入后)

mysql - 索引中的列顺序