我相信有两种方法可以检查 torch.Tensor
的值是否全部大于 0。使用 .all()
或 all()
,一个最小的可重复示例将说明我的想法:
import torch
walls = torch.tensor([-1, 0, 1, 2])
result1 = (walls >= 0.0).all() # DIFFERENCE WITH BELOW???
result2 = all(walls >= 0.0) # DIFFERENCE WITH ABOVE???
print(result1) # Output: False
print(result2) # Output: False
all()
是内置的,所以我想我更喜欢使用它,但我在互联网上看到的大多数代码都使用 .all()
所以我担心出现意外行为。
他们的行为完全相同吗?
最佳答案
all
是 Python 内置的,这意味着它只能使用极其通用的接口(interface)。在这种情况下,all
将张量视为不透明 iterable 。它通过逐一迭代张量的元素,为每个元素构造一个 Python 对象,然后检查该 Python 对象的真实性。这很慢,而且还增加了几个不必要的低效率层。
相比之下,Tensor.all
知道 Tensor 对象是什么,并且可以直接对其进行操作。它只需要直接扫描张量内部存储即可。没有迭代器协议(protocol)函数调用,没有中间 Python 对象。
Tensor.all
在时间和内存方面总是比内置的 all
更加高效。
关于python - all() 和 .all() 之间的区别,用于检查可迭代是否在任何地方都为 True,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/76794391/