我知道 PyTorch 没有类似于 map
的函数来将函数应用于张量的每个元素。那么,如果 PyTorch 中没有类似 map
的函数,我可以做类似下面的事情吗?
if tensor_a * tensor_b.matmul(tensor_c) < 1:
return -tensor_a*tensor_b
else:
return 0
如果张量是一维的,这将起作用。但是,当 tensor_b
是 2D 时,我需要它才能工作(tensor_a
需要在 return
语句中被 unsqueeze
d ).这意味着应该返回一个二维张量,其中一些行将是 0
向量。
很高兴使用最新 Python 版本的最新功能。
最佳答案
如果我理解正确,您希望以任何一种方式(因此映射)返回一个张量,但是通过逐元素检查条件。假设 tensor_a
、tensor_b
和 tensor_c
的形状都是二维的,如“简单矩阵”,这是一个可能的解决方案。
你要找的可能是torch.where
,它非常接近基于条件的映射,它将返回一个值或另一个元素方面。
它的工作方式类似于 torch.where(condition, value_if, value_else)
,其中所有三个张量具有相同的形状(value_if
和 value_else
可以实际上是将被转换为张量的 float ,填充相同的值)。此外,condition
是一个 bool 张量,它定义了分配给输出张量的值:它是一个 bool 掩码。
为了这个例子的目的,我使用了随机张量:
>>> a = torch.rand(2, 2, dtype=float)*100
>>> b = torch.rand(2, 2, dtype=float)*0.01
>>> c = torch.rand(2, 2, dtype=float)*10
>>> torch.where(a*(b@c) < 1, -a*b, 0.)
tensor([[ 0.0000, 0.0000],
[ 0.0000, -0.0183]], dtype=torch.float64)
更一般地说,如果 tensor_a
和 tensor_b
的形状为 (m, n)
,并且 tensor_c,这将起作用由于操作限制,
的形状为 (n, m)
。在您的实验中,我猜您只有列。
关于python - 在 PyTorch 中有条件地应用张量运算,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65854970/