python - PyTorch Torch.max在多个维度上

标签 python multidimensional-array max pytorch tensor

张量类似:x.shape = [3, 2, 2]

import torch

x = torch.tensor([
    [[-0.3000, -0.2926],[-0.2705, -0.2632]],
    [[-0.1821, -0.1747],[-0.1526, -0.1453]],
    [[-0.0642, -0.0568],[-0.0347, -0.0274]]
])

我需要在第二个和第三个维度上使用.max()。我希望像这样的[-0.2632, -0.1453, -0.0274]作为输出。我尝试使用:x.max(dim=(1,2)),但这会导致错误。

最佳答案

现在,您可以执行此操作。 PR was merged(8月28日),现在在每晚版本中可用。
只需使用 torch.amax() :

import torch

x = torch.tensor([
    [[-0.3000, -0.2926],[-0.2705, -0.2632]],
    [[-0.1821, -0.1747],[-0.1526, -0.1453]],
    [[-0.0642, -0.0568],[-0.0347, -0.0274]]
])

print(torch.amax(x, dim=(1, 2)))

# Output:
# >>> tensor([-0.2632, -0.1453, -0.0274])

原始答案
截至今天(2020年4月11日),在PyTorch中无法在多个维度上执行.min().max()。您可以关注它的open issue,看看它是否曾经实现。在您的情况下,一种解决方法是:
import torch

x = torch.tensor([
    [[-0.3000, -0.2926],[-0.2705, -0.2632]],
    [[-0.1821, -0.1747],[-0.1526, -0.1453]],
    [[-0.0642, -0.0568],[-0.0347, -0.0274]]
])

print(x.view(x.size(0), -1).max(dim=-1))

# output:
# >>> values=tensor([-0.2632, -0.1453, -0.0274]),
# >>> indices=tensor([3, 3, 3]))
因此,如果仅需要以下值:x.view(x.size(0), -1).max(dim=-1).values
如果x不是连续的张量,那么.view()将失败。在这种情况下,您应该改为使用.reshape()

更新2020年8月26日
此功能正在PR#43092中实现,其功能将称为aminamax。他们将仅返回值。这可能很快就会合并,因此在阅读本文时,您也许可以在每晚的构建版本上访问这些功能:)玩得开心。

关于python - PyTorch Torch.max在多个维度上,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61156894/

相关文章:

python - pymouse 没有在 python 中与 opencv 一起运行

python - 如何将数据框列拆分为多列

javascript - 在 JavaScript 中查找多维数组中最长的字符串

excel - 未排序集的唯一编号

ruby - 在 ruby​​ 的哈希中找到最低和最高值

python - 第一个python脚本,刮板,欢迎使用建议

python - 如何告诉子进程回答 "y"到命令行问题?

c# - 在 C# 中将多维数组转换为锯齿状数组

perl - 如何将数据推送到多维数组?

hibernate - 你如何获得 "real"sql 与 hibernate 条件查询不同?