python - PyTorch indexing: select complement of indices

Tags: python numpy pytorch matrix-indexing

Suppose I have a tensor and a tensor of indices:

import torch

x = torch.tensor([1, 2, 3, 4, 5])
idx = torch.tensor([0, 2, 4])

If I want to select all elements that are not in the index, I can manually define a Boolean mask like this:

mask = torch.ones_like(x, dtype=torch.bool)  # must be bool; an int64 mask would be read as positions
mask[idx] = False

x[mask]
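
The dtype of the mask matters here: an integer tensor of 0s and 1s is interpreted as a list of positions, while a `torch.bool` tensor is interpreted as a mask. A minimal sketch of the difference, using the same `x` and `idx` as above (the variable names `int_mask`, `bool_mask`, `as_positions` and `as_mask` are illustrative only):

```python
import torch

x = torch.tensor([1, 2, 3, 4, 5])
idx = torch.tensor([0, 2, 4])

# Integer mask: ones_like(x) inherits x's int64 dtype, so x[int_mask]
# treats the 0/1 entries as positions and returns x[0] and x[1] repeatedly.
int_mask = torch.ones_like(x)
int_mask[idx] = 0
as_positions = x[int_mask]    # tensor([1, 2, 1, 2, 1])

# Boolean mask: True entries are kept, which gives the complement of idx.
bool_mask = torch.ones_like(x, dtype=torch.bool)
bool_mask[idx] = False
as_mask = x[bool_mask]        # tensor([2, 4])
```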

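Is there a more elegant way?

That is, a syntax where I can pass the indices directly instead of creating a mask, for example something like this: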

x[~idx]
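
For context, `x[~idx]` already has a meaning for an integer tensor: `~` is bitwise NOT, so each index `i` becomes `-i - 1`, and the result is ordinary negative indexing rather than a complement. The sketch below shows this, followed by a compact 1-D alternative built on `torch.isin` (available in PyTorch 1.10 and later). The helper name `select_complement` is hypothetical and not part of PyTorch.

```python
import torch

x = torch.tensor([1, 2, 3, 4, 5])
idx = torch.tensor([0, 2, 4])

# Bitwise NOT on an integer tensor maps i -> -i - 1,
# so this indexes from the end instead of taking a complement.
negated = ~idx                 # tensor([-1, -3, -5])
not_complement = x[negated]    # tensor([5, 3, 1])


def select_complement(x, idx):
    """Return the elements of 1-D tensor x whose positions are not in idx (hypothetical helper)."""
    positions = torch.arange(x.numel(), device=x.device)
    keep = ~torch.isin(positions, idx.to(x.device))  # True where the position is kept
    return x[keep]


complement = select_complement(x, idx)  # tensor([2, 4])
```

This still builds a mask internally; it only hides the two-step construction behind one call.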

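Best answer

I couldn't find a satisfactory solution for computing the complement of a multi-dimensional index tensor, so I ended up implementing my own. It runs on CUDA and benefits from fast parallel computation.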

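import torch

def complement_idx(idx, dim):
    """
    Compute the complement of idx along its last dimension:
    set(range(dim)) - set(idx).

    idx is a multi-dimensional tensor; the complement is taken over its trailing
    dimension, and all other dimensions are treated as batch dimensions.
    Indices must be unique within each row of the trailing dimension.

    Args:
        idx: input indices, shape: [N, *, K]
        dim: size of the index range; valid indices are 0 .. dim - 1
    Returns:
        complement indices in ascending order, shape: [N, *, dim - K]
    """
    a = torch.arange(dim, device=idx.device)
    ndim = idx.ndim
    dims = idx.shape
    n_idx = dims[-1]
    dims = dims[:-1] + (-1, )
    # Add leading singleton dimensions so `a` can be expanded over the batch dimensions.
    for i in range(1, ndim):
        a = a.unsqueeze(0)
    a = a.expand(*dims)
    # Overwrite the selected positions with 0, then sort so those zeros come first.
    masked = torch.scatter(a, -1, idx, 0)
    compl, _ = torch.sort(masked, dim=-1, descending=False)
    # Drop the first n_idx entries along the last dimension; the rest is the complement.
    compl = compl.permute(-1, *tuple(range(ndim - 1)))
    compl = compl[n_idx:].permute(*(tuple(range(1, ndim)) + (0,)))
    return compl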
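
To see why the scatter-and-sort approach works, it helps to trace a single row by hand. The snippet below is a sketch for one unbatched row; the helper `complement_1d` and the sample values (`dim = 5`, indices `[3, 0]` and `[1, 2]`) are chosen for illustration only. Scattering 0 into the excluded positions and sorting pushes `K` placeholder zeros to the front, so dropping the first `K` entries leaves the complement in ascending order. This relies on the indices in a row being unique.

```python
import torch


def complement_1d(idx, dim):
    """Single-row version of the scatter-and-sort trick (illustrative helper)."""
    k = idx.shape[-1]
    positions = torch.arange(dim)                   # e.g. tensor([0, 1, 2, 3, 4])
    masked = torch.scatter(positions, -1, idx, 0)   # excluded positions become 0
    sorted_vals, _ = torch.sort(masked, dim=-1)     # the k placeholder zeros sort to the front
    return masked, sorted_vals, sorted_vals[k:]


# Case 1: index 0 is excluded, so every zero after the scatter is a placeholder.
masked, sorted_vals, compl = complement_1d(torch.tensor([3, 0]), 5)
# masked      -> tensor([0, 1, 2, 0, 4])
# sorted_vals -> tensor([0, 0, 1, 2, 4])
# compl       -> tensor([1, 2, 4])

# Case 2: index 0 is kept, so one genuine zero remains after the slice.
masked2, sorted_vals2, compl2 = complement_1d(torch.tensor([1, 2]), 5)
# masked2 -> tensor([0, 0, 0, 3, 4])
# compl2  -> tensor([0, 3, 4])
```

In both cases exactly `K` leading entries are discarded, which is why the function does not need to treat index 0 specially.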

Example:

>>> import torch
>>> a = torch.rand(3, 4, 5)
>>> a
tensor([[[0.7849, 0.7404, 0.4112, 0.9873, 0.2937],
         [0.2113, 0.9923, 0.6895, 0.1360, 0.2952],
         [0.9644, 0.9577, 0.2021, 0.6050, 0.7143],
         [0.0239, 0.7297, 0.3731, 0.8403, 0.5984]],

        [[0.9089, 0.0945, 0.9573, 0.9475, 0.6485],
         [0.7132, 0.4858, 0.0155, 0.3899, 0.8407],
         [0.2327, 0.8023, 0.6278, 0.0653, 0.2215],
         [0.9597, 0.5524, 0.2327, 0.1864, 0.1028]],

        [[0.2334, 0.9821, 0.4420, 0.1389, 0.2663],
         [0.6905, 0.2956, 0.8669, 0.6926, 0.9757],
         [0.8897, 0.4707, 0.5909, 0.6522, 0.9137],
         [0.6240, 0.1081, 0.6404, 0.1050, 0.6413]]])
>>> b, c = torch.topk(a, 2, dim=-1)
>>> b
tensor([[[0.9873, 0.7849],
         [0.9923, 0.6895],
         [0.9644, 0.9577],
         [0.8403, 0.7297]],

        [[0.9573, 0.9475],
         [0.8407, 0.7132],
         [0.8023, 0.6278],
         [0.9597, 0.5524]],

        [[0.9821, 0.4420],
         [0.9757, 0.8669],
         [0.9137, 0.8897],
         [0.6413, 0.6404]]])
>>> c
tensor([[[3, 0],
         [1, 2],
         [0, 1],
         [3, 1]],

        [[2, 3],
         [4, 0],
         [1, 2],
         [0, 1]],

        [[1, 2],
         [4, 2],
         [4, 0],
         [4, 2]]])
>>> compl = complement_idx(c, 5)
>>> compl
tensor([[[1, 2, 4],
         [0, 3, 4],
         [2, 3, 4],
         [0, 2, 4]],

        [[0, 1, 4],
         [1, 2, 3],
         [0, 3, 4],
         [2, 3, 4]],

        [[0, 3, 4],
         [0, 1, 3],
         [1, 2, 3],
         [0, 1, 3]]])
>>> al = torch.cat([c, compl], dim=-1)
>>> al
tensor([[[3, 0, 1, 2, 4],
         [1, 2, 0, 3, 4],
         [0, 1, 2, 3, 4],
         [3, 1, 0, 2, 4]],

        [[2, 3, 0, 1, 4],
         [4, 0, 1, 2, 3],
         [1, 2, 0, 3, 4],
         [0, 1, 2, 3, 4]],

        [[1, 2, 0, 3, 4],
         [4, 2, 0, 1, 3],
         [4, 0, 1, 2, 3],
         [4, 2, 0, 1, 3]]])
>>> al, _ = al.sort(dim=-1)
>>> al
tensor([[[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]],

        [[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]],

        [[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]]])
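
Once the complement indices are available, the remaining values can be selected with `torch.gather`. The sketch below is self-contained, so it includes a condensed restatement of `complement_idx` (the slicing is written as `compl[..., k:]`, which is assumed here to be equivalent to the permute-based slicing in the answer); the names `top_vals`, `top_idx` and `rest` are illustrative.

```python
import torch


def complement_idx(idx, dim):
    """Condensed restatement of the answer's function, for a self-contained example."""
    k = idx.shape[-1]
    a = torch.arange(dim, device=idx.device).expand(*idx.shape[:-1], dim)
    masked = torch.scatter(a, -1, idx, 0)      # selected positions become 0
    compl, _ = torch.sort(masked, dim=-1)      # placeholder zeros sort to the front
    return compl[..., k:]                      # drop the k placeholders


a = torch.rand(3, 4, 5)
top_vals, top_idx = torch.topk(a, 2, dim=-1)   # 2 largest values per row, shape [3, 4, 2]

compl = complement_idx(top_idx, a.shape[-1])   # indices not chosen by topk, shape [3, 4, 3]
rest = torch.gather(a, -1, compl)              # values at those indices, shape [3, 4, 3]

# Sanity check: the top-k values plus the rest reproduce every row of `a`.
merged = torch.cat([top_vals, rest], dim=-1).sort(dim=-1).values
```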

Regarding "python - PyTorch indexing: select complement of indices", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/67157893/
