python - 在 PyTorch 中有条件地应用张量运算

标签 python python-3.x pytorch tensor

我知道 PyTorch 没有类似于 map 的函数来将函数应用于张量的每个元素。那么,如果 PyTorch 中没有类似 map 的函数,我可以做类似下面的事情吗?

if tensor_a * tensor_b.matmul(tensor_c) < 1:
    return -tensor_a*tensor_b
else:
    return 0

如果张量是一维的,这将起作用。但是,当 tensor_b 是 2D 时,我需要它才能工作(tensor_a 需要在 return 语句中被 unsqueezed ).这意味着应该返回一个二维张量,其中一些行将是 0 向量。

很高兴使用最新 Python 版本的最新功能。

最佳答案

如果我理解正确,您希望以任何一种方式(因此映射)返回一个张量,但是通过逐元素检查条件。假设 tensor_atensor_btensor_c 的形状都是二维的,如“简单矩阵”,这是一个可能的解决方案。

你要找的可能是torch.where ,它非常接近基于条件的映射,它将返回一个值或另一个元素方面

它的工作方式类似于 torch.where(condition, value_if, value_else),其中所有三个张量具有相同的形状(value_ifvalue_else 可以实际上是将被转换为张量的 float ,填充相同的值)。此外,condition 是一个 bool 张量,它定义了分配给输出张量的值:它是一个 bool 掩码。

为了这个例子的目的,我使用了随机张量:

>>> a = torch.rand(2, 2, dtype=float)*100
>>> b = torch.rand(2, 2, dtype=float)*0.01
>>> c = torch.rand(2, 2, dtype=float)*10

>>> torch.where(a*(b@c) < 1, -a*b, 0.)
tensor([[ 0.0000,  0.0000],
        [ 0.0000, -0.0183]], dtype=torch.float64)

更一般地说,如果 tensor_atensor_b 的形状为 (m, n),并且 tensor_c,这将起作用由于操作限制, 的形状为 (n, m)。在您的实验中,我猜您只有列。

关于python - 在 PyTorch 中有条件地应用张量运算,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65854970/

相关文章:

python - 如何在 reportlab Canvas 中绘制 matplotlib 图?

Python - 如何加速具有函数和多个返回值的嵌套 for 循环

python - 尝试在 Pytorch 中加载自定义数据集

pytorch - 使用 pytorch Lightning 在回调中访问纪元结束时的所有批处理输出

python - 为什么 conv2d 在不同的批量大小下会产生不同的结果

python - numpy/scipy/ipython :Failed to interpret file as a pickle

python - 检查 OrientDB 中的类创建

python - 从文件中读取行时继续使用值,直到列表中的下一个键 - python

python - 检查字符串是否包含列表元素

python-3.x - 大小不匹配,m1 : [3584 x 28], m2 : [784 x 128] at/pytorch/aten/src/TH/generic/THTensorMath. cpp:940