假设我有一个张量和索引:
x = torch.tensor([1,2,3,4,5])
idx = torch.tensor([0,2,4])
如果我想选择索引中不的所有元素,我可以 manually define a Boolean mask像这样:
mask = torch.ones_like(x)
mask[idx] = 0
x[mask]
有更优雅的方法吗?
即我可以直接传递索引而不是创建掩码的语法,例如像这样:
x[~idx]
最佳答案
我找不到令人满意的解决方案来查找多维索引张量的补集,最终实现了我自己的解决方案。它可以在cuda上运行并享受快速的并行计算。
def complement_idx(idx, dim):
"""
Compute the complement: set(range(dim)) - set(idx).
idx is a multi-dimensional tensor, find the complement for its trailing dimension,
all other dimension is considered batched.
Args:
idx: input index, shape: [N, *, K]
dim: the max index for complement
"""
a = torch.arange(dim, device=idx.device)
ndim = idx.ndim
dims = idx.shape
n_idx = dims[-1]
dims = dims[:-1] + (-1, )
for i in range(1, ndim):
a = a.unsqueeze(0)
a = a.expand(*dims)
masked = torch.scatter(a, -1, idx, 0)
compl, _ = torch.sort(masked, dim=-1, descending=False)
compl = compl.permute(-1, *tuple(range(ndim - 1)))
compl = compl[n_idx:].permute(*(tuple(range(1, ndim)) + (0,)))
return compl
示例:
>>> import torch
>>> a = torch.rand(3, 4, 5)
>>> a
tensor([[[0.7849, 0.7404, 0.4112, 0.9873, 0.2937],
[0.2113, 0.9923, 0.6895, 0.1360, 0.2952],
[0.9644, 0.9577, 0.2021, 0.6050, 0.7143],
[0.0239, 0.7297, 0.3731, 0.8403, 0.5984]],
[[0.9089, 0.0945, 0.9573, 0.9475, 0.6485],
[0.7132, 0.4858, 0.0155, 0.3899, 0.8407],
[0.2327, 0.8023, 0.6278, 0.0653, 0.2215],
[0.9597, 0.5524, 0.2327, 0.1864, 0.1028]],
[[0.2334, 0.9821, 0.4420, 0.1389, 0.2663],
[0.6905, 0.2956, 0.8669, 0.6926, 0.9757],
[0.8897, 0.4707, 0.5909, 0.6522, 0.9137],
[0.6240, 0.1081, 0.6404, 0.1050, 0.6413]]])
>>> b, c = torch.topk(a, 2, dim=-1)
>>> b
tensor([[[0.9873, 0.7849],
[0.9923, 0.6895],
[0.9644, 0.9577],
[0.8403, 0.7297]],
[[0.9573, 0.9475],
[0.8407, 0.7132],
[0.8023, 0.6278],
[0.9597, 0.5524]],
[[0.9821, 0.4420],
[0.9757, 0.8669],
[0.9137, 0.8897],
[0.6413, 0.6404]]])
>>> c
tensor([[[3, 0],
[1, 2],
[0, 1],
[3, 1]],
[[2, 3],
[4, 0],
[1, 2],
[0, 1]],
[[1, 2],
[4, 2],
[4, 0],
[4, 2]]])
>>> compl = complement_idx(c, 5)
>>> compl
tensor([[[1, 2, 4],
[0, 3, 4],
[2, 3, 4],
[0, 2, 4]],
[[0, 1, 4],
[1, 2, 3],
[0, 3, 4],
[2, 3, 4]],
[[0, 3, 4],
[0, 1, 3],
[1, 2, 3],
[0, 1, 3]]])
>>> al = torch.cat([c, compl], dim=-1)
>>> al
tensor([[[3, 0, 1, 2, 4],
[1, 2, 0, 3, 4],
[0, 1, 2, 3, 4],
[3, 1, 0, 2, 4]],
[[2, 3, 0, 1, 4],
[4, 0, 1, 2, 3],
[1, 2, 0, 3, 4],
[0, 1, 2, 3, 4]],
[[1, 2, 0, 3, 4],
[4, 2, 0, 1, 3],
[4, 0, 1, 2, 3],
[4, 2, 0, 1, 3]]])
>>> al, _ = al.sort(dim=-1)
>>> al
tensor([[[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]],
[[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]],
[[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]]])
关于python - PyTorch 索引 : select complement of indices,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67157893/