我想编写一个函数来实现 this question 中描述的行为.
也就是说,我想将 PyTorch 中矩阵每一行中的重复值清零。例如,给定一个矩阵
torch.Tensor(([1, 2, 3, 4, 3, 3, 4],
[1, 6, 3, 5, 3, 5, 4]])
我想得到
torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
[1, 6, 3, 5, 0, 0, 4]])
或
torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
[1, 6, 3, 5, 4, 0, 0]])
根据链接的问题,仅 torch.unique()
是不够的。我想知道如何在没有循环的情况下实现这个功能。
最佳答案
x = torch.tensor([
[1, 2, 3, 4, 3, 3, 4],
[1, 6, 3, 5, 3, 5, 4]
], dtype=torch.long)
# sorting the rows so that duplicate values appear together
# e.g., first row: [1, 2, 3, 3, 3, 4, 4]
y, indices = x.sort(dim=-1)
# subtracting, so duplicate values will become 0
# e.g., first row: [1, 2, 3, 0, 0, 4, 0]
y[:, 1:] *= ((y[:, 1:] - y[:, :-1]) !=0).long()
# retrieving the original indices of elements
indices = indices.sort(dim=-1)[1]
# re-organizing the rows following original order
# e.g., first row: [1, 2, 3, 4, 0, 0, 0]
result = torch.gather(y, 1, indices)
print(result) # => output
输出
tensor([[1, 2, 3, 4, 0, 0, 0],
[1, 6, 3, 5, 0, 0, 4]])
关于python - 如何将 PyTorch 张量的每一行中的重复值清零?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62300404/