我按照 documentation 中提到的内容理解 torch.where() 的输出。 但是,我不明白当未给出 x 和 y 时它产生的输出,如下所示(尽管 x 的形状保持不变,但该输出的维度不断变化)。有人可以帮我理解吗?
y = torch.ones(3, 2)
x = torch.randn(3, 2)
print(x)
----------------------------
tensor([[-0.0022, 0.4871],
[ 0.0788, 0.2937],
[ 0.1909, -2.1636]])
----------------------------
print(torch.where(x > 0, x, y))
----------------------------
tensor([[1.0000, 0.4871],
[0.0788, 0.2937],
[0.1909, 1.0000]])
----------------------------
print(torch.where(x > 0))
(tensor([0, 1, 1, 2]), tensor([1, 0, 1, 0]))
最佳答案
此版本的 torch.where
旨在返回令人满意的元素索引。
print(f"Y={torch.where(x > 0)[0].numpy()}")
print(f"X={torch.where(x > 0)[1].numpy()}")
--------------------------------
Y=[0 1 1 2]
X=[1 0 1 0]
在这里您可以更好地看到矩阵中正数的坐标。
关于python - 当 x , y 未给出时如何解释 torch.where() 输出?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/75181924/