张量类似:x.shape = [3, 2, 2]
。
import torch
x = torch.tensor([
[[-0.3000, -0.2926],[-0.2705, -0.2632]],
[[-0.1821, -0.1747],[-0.1526, -0.1453]],
[[-0.0642, -0.0568],[-0.0347, -0.0274]]
])
我需要在第二个和第三个维度上使用
.max()
。我希望像这样的[-0.2632, -0.1453, -0.0274]
作为输出。我尝试使用:x.max(dim=(1,2))
,但这会导致错误。
最佳答案
现在,您可以执行此操作。 PR was merged(8月28日),现在在每晚版本中可用。
只需使用 torch.amax()
:
import torch
x = torch.tensor([
[[-0.3000, -0.2926],[-0.2705, -0.2632]],
[[-0.1821, -0.1747],[-0.1526, -0.1453]],
[[-0.0642, -0.0568],[-0.0347, -0.0274]]
])
print(torch.amax(x, dim=(1, 2)))
# Output:
# >>> tensor([-0.2632, -0.1453, -0.0274])
原始答案
截至今天(2020年4月11日),在PyTorch中无法在多个维度上执行
.min()
或.max()
。您可以关注它的open issue,看看它是否曾经实现。在您的情况下,一种解决方法是:import torch
x = torch.tensor([
[[-0.3000, -0.2926],[-0.2705, -0.2632]],
[[-0.1821, -0.1747],[-0.1526, -0.1453]],
[[-0.0642, -0.0568],[-0.0347, -0.0274]]
])
print(x.view(x.size(0), -1).max(dim=-1))
# output:
# >>> values=tensor([-0.2632, -0.1453, -0.0274]),
# >>> indices=tensor([3, 3, 3]))
因此,如果仅需要以下值:x.view(x.size(0), -1).max(dim=-1).values
。如果
x
不是连续的张量,那么.view()
将失败。在这种情况下,您应该改为使用.reshape()
。更新2020年8月26日
此功能正在PR#43092中实现,其功能将称为
amin
和amax
。他们将仅返回值。这可能很快就会合并,因此在阅读本文时,您也许可以在每晚的构建版本上访问这些功能:)玩得开心。
关于python - PyTorch Torch.max在多个维度上,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61156894/