我是 Numba 的新手,我需要使用 Numba 来加速一些 Pytorch 功能。但我发现即使是一个非常简单的功能也不起作用:(
import torch
import numba
@numba.njit()
def vec_add_odd_pos(a, b):
res = 0.
for pos in range(len(a)):
if pos % 2 == 0:
res += a[pos] + b[pos]
return res
x = torch.tensor([3, 4, 5.])
y = torch.tensor([-2, 0, 1.])
z = vec_add_odd_pos(x, y)
但是出现以下错误def vec_add_odd_pos(a, b):
资源 = 0。
^
此错误可能是由以下参数引起的:
谁能帮我?包含更多示例的链接也将不胜感激。谢谢。
最佳答案
正如其他人所提到的,numba 目前不支持火炬张量,只支持 numpy 张量。然而有TorchScript ,它有一个类似的目标。然后可以将您的函数重写为:
import torch
@torch.jit.script
def vec_add_odd_pos(a, b):
res = 0.
for pos in range(len(a)):
if pos % 2 == 0:
res += a[pos] + b[pos]
return res
x = torch.tensor([3, 4, 5.])
y = torch.tensor([-2, 0, 1.])
z = vec_add_odd_pos(x, y)
请注意:虽然您说您的代码片段只是一个简单的示例,但 for 循环确实很慢并且运行 TorchScript 可能对您没有太大帮助,您应该不惜一切代价避免它们,并且只有在不存在其他解决方案时才使用它们。话虽如此,以下是如何以更高效的方式实现您的功能:def vec_add_odd_pos(a, b):
evenids = torch.arange(len(a)) % 2 == 0
return (a[evenids] + b[evenids]).sum()
关于pytorch - 如何将 Numba 用于 Pytorch 张量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63169760/