python - 如何将 PyTorch 张量的每一行中的重复值清零?

标签 python numpy pytorch torch

我想编写一个函数来实现 this question 中描述的行为.

也就是说,我想将 PyTorch 中矩阵每一行中的重复值清零。例如,给定一个矩阵

torch.Tensor(([1, 2, 3, 4, 3, 3, 4],
              [1, 6, 3, 5, 3, 5, 4]])

我想得到

torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
              [1, 6, 3, 5, 0, 0, 4]])

torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
              [1, 6, 3, 5, 4, 0, 0]])

根据链接的问题,仅 torch.unique() 是不够的。我想知道如何在没有循环的情况下实现这个功能。

最佳答案

x = torch.tensor([
    [1, 2, 3, 4, 3, 3, 4],
    [1, 6, 3, 5, 3, 5, 4]
], dtype=torch.long)

# sorting the rows so that duplicate values appear together
# e.g., first row: [1, 2, 3, 3, 3, 4, 4]
y, indices = x.sort(dim=-1)

# subtracting, so duplicate values will become 0
# e.g., first row: [1, 2, 3, 0, 0, 4, 0]
y[:, 1:] *= ((y[:, 1:] - y[:, :-1]) !=0).long()

# retrieving the original indices of elements
indices = indices.sort(dim=-1)[1]

# re-organizing the rows following original order
# e.g., first row: [1, 2, 3, 4, 0, 0, 0]
result = torch.gather(y, 1, indices)

print(result) # => output

输出

tensor([[1, 2, 3, 4, 0, 0, 0],
        [1, 6, 3, 5, 0, 0, 4]])

关于python - 如何将 PyTorch 张量的每一行中的重复值清零?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62300404/

相关文章:

python - 训练CNN时出错: "RuntimeError: The size of tensor a (10) must match the size of tensor b (64) at non-singleton dimension 1"

python - 在 pytorch 中使用双线性插值移动图像

python - 为什么 pandas 按位置索引子集会出错?

python - 卸载应用程序 A,因为应用程序 B 在一次旧迁移中具有依赖项

python - 从 python 字典中打印列

Python 绘制 for 循环内的 for 循环生成的数据

python - 将不同的列加入其中之一 - python

python - Pytorch几何: RuntimeError: expected scalar type Long but found Float

python - 如何使用 Pandas 创建一个新列来识别时间字段中的接近度?

python - Python 中带键的 sort_values()