pytorch - 如何在 PyTorch 中移动具有不同偏移量的张量中的列(或行)?

标签 pytorch

在 PyTorch 中,内置 torch.roll函数只能移动具有相同偏移量的列(或行)。但我想移动具有不同偏移量的列。假设输入张量是

[[1,2,3],
 [4,5,6],
 [7,8,9]]
比方说,我想偏移偏移 i对于第 i 列。因此,预期输出是
[[1,8,6],
 [4,2,9],
 [7,5,3]]
这样做的一个选项是使用 torch.roll 单独移动每一列并连接它们中的每一个。但是出于有效性和代码紧凑性的考虑,我不想介绍循环结构。有没有更好的办法?

最佳答案

我对 torch.gather 的性能表示怀疑所以我用 numpy 搜索了类似的问题,找到了 this邮政。
从 NumPy 到 Pytorch 的类似解决方案
我从@Andy L 那里得到了解决方案并将其翻译成 pytorch。但是,请谨慎对待,因为我不知道 strides 是如何工作的:

from numpy.lib.stride_tricks import as_strided
# NumPy solution:
def custom_roll(arr, r_tup):
    m = np.asarray(r_tup)
    arr_roll = arr[:, [*range(arr.shape[1]),*range(arr.shape[1]-1)]].copy() #need `copy`
    #print(arr_roll)
    strd_0, strd_1 = arr_roll.strides
    #print(strd_0, strd_1)
    n = arr.shape[1]
    result = as_strided(arr_roll, (*arr.shape, n), (strd_0 ,strd_1, strd_1))

    return result[np.arange(arr.shape[0]), (n-m)%n]

# Translated to PyTorch
def pcustom_roll(arr, r_tup):
    m = torch.tensor(r_tup)
    arr_roll = arr[:, [*range(arr.shape[1]),*range(arr.shape[1]-1)]].clone() #need `copy`
    #print(arr_roll)
    strd_0, strd_1 = arr_roll.stride()
    #print(strd_0, strd_1)
    n = arr.shape[1]
    result = torch.as_strided(arr_roll, (*arr.shape, n), (strd_0 ,strd_1, strd_1))

    return result[torch.arange(arr.shape[0]), (n-m)%n]
这也是@Daniel M 的即插即用解决方案。
def roll_by_gather(mat,dim, shifts: torch.LongTensor):
    # assumes 2D array
    n_rows, n_cols = mat.shape
    
    if dim==0:
        #print(mat)
        arange1 = torch.arange(n_rows).view((n_rows, 1)).repeat((1, n_cols))
        #print(arange1)
        arange2 = (arange1 - shifts) % n_rows
        #print(arange2)
        return torch.gather(mat, 0, arange2)
    elif dim==1:
        arange1 = torch.arange(n_cols).view(( 1,n_cols)).repeat((n_rows,1))
        #print(arange1)
        arange2 = (arange1 - shifts) % n_cols
        #print(arange2)
        return torch.gather(mat, 1, arange2)
    
基准测试
首先,我在 CPU 上运行这些方法。
令人惊讶的是,gather上面的解决方案是最快的:
n_cols = 10000
n_rows = 100
shifts = torch.randint(-100,100,size=[n_rows,1])
data = torch.arange(n_rows*n_cols).reshape(n_rows,n_cols)
npdata = np.arange(n_rows*n_cols).reshape(n_rows,n_cols)
npshifts = shifts.numpy()
%timeit roll_by_gather(data,1,shifts)
%timeit pcustom_roll(data,shifts)
%timeit custom_roll(npdata,npshifts)
>> 2.41 ms ± 68.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>> 90.4 ms ± 882 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>> 247 ms ± 6.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
在 GPU 上运行代码显示了类似的结果:
%timeit roll_by_gather(data,shifts)
%timeit pcustom_roll(data,shifts)
131 µs ± 6.79 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
3.29 ms ± 46.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
( 注意 :您需要在 torch.arange(...,device='cuda:0') 方法中使用 roll_by_gather)

关于pytorch - 如何在 PyTorch 中移动具有不同偏移量的张量中的列(或行)?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66596699/

相关文章:

python - 了解 PyTorch 预测

python - 训练时 Pytorch CUDA OutOfMemory 错误

docker - 可以在没有 GPU 的情况下运行 nvidia-docker 吗?

python - 调试神经网络丢失问题的概率不在 [0,1] 内

python - 如何使用 torch.hub.load 加载本地模型?

python - 如何从 PyTorch 的 ResNet 模型中删除最后一个 FC 层?

python - 如何创建模块列表列表

python - Pytorch和多项式线性回归问题

python - 将one-hot编码维度转换为1的位置索引

python - 通过模型微调获得异常