有没有一种简单的方法可以将 PyTorch 张量的对角线置零?
例如我有:
tensor([[2.7183, 0.4005, 2.7183, 0.5236],
[0.4005, 2.7183, 0.4004, 1.3469],
[2.7183, 0.4004, 2.7183, 0.5239],
[0.5236, 1.3469, 0.5239, 2.7183]])
我想得到:
tensor([[0.0000, 0.4005, 2.7183, 0.5236],
[0.4005, 0.0000, 0.4004, 1.3469],
[2.7183, 0.4004, 0.0000, 0.5239],
[0.5236, 1.3469, 0.5239, 0.0000]])
最佳答案
我相信最简单的方法是使用 torch.diagonal
:
z = torch.randn(4,4)
torch.diagonal(z, 0).zero_()
print(z)
>>> tensor([[ 0.0000, -0.6211, 0.1120, 0.8362],
[-0.1043, 0.0000, 0.1770, 0.4197],
[ 0.7211, 0.1138, 0.0000, -0.7486],
[-0.5434, -0.8265, -0.2436, 0.0000]])
这样,代码就非常明确了,您将性能委托(delegate)给了 pytorch 的内置函数。
关于python - PyTorch 张量的零对角线?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65712349/