Python numba.jit 类型

标签 python numba memoryview

我一整天都在尝试从 numba 文档中推断出类型是如何设置的。我已经有点方法了,但现在我想创建一个函数,它返回一个一维数组和一个二维数组,并获取一堆参数,我很难再进一步:

@jit
class name(object)
    @double[:,:], double[:](double[:], double, double, int64)
    def solve(self, u0, a, b, n):
        self.t = linspace(a, b, n+1)
        dt = abs((b-a)/float(n))
        u = zeros(n+1, len([u0]))
        u[0] = u0
        u = advance(u, t, n, dt)
        return u.transpose(), t.transpose()   

上面抛出了这些异常:

Traceback (most recent call last):
  File "/home/marius/dev/python/inf1100/test_ODE.py", line 2, in <module>
    from DE import *
  File "/home/marius/dev/python/inf1100/DE.py", line 13
    @double[:,:], double[:](double[:], double, double, int64)
           ^
SyntaxError: invalid syntax

如果你能告诉我出了什么问题就更好了,但是如果你能推荐一个文档,一劳永逸地严格解释这些语法就更好了。

感谢您的宝贵时间。

亲切的问候, 马吕斯

最佳答案

这是一个返回元组的方法的简单版本。这适用于在 OS X 上使用 Numba 0.11.1 的我:

import numba
import numpy as np

@numba.jit
class name(object):
    @numba.object_(numba.double[:], numba.double)
    def solve(self, x, a):
        y = np.empty(x.shape[0], dtype=np.float64)
        z = np.empty(x.shape[0], dtype=np.float64)
        for k in xrange(x.shape[0]):
            y[k] = x[k] * a
            z[k] = x[k] + a

        return y, z 

然后使用它:

C = name()
a, b = C.solve(np.arange(5, dtype=np.float64), 3.0)

ab 是:

In [24]: a
Out[24]:
array([  0.,   3.,   6.,   9.,  12.])
In [22]: b
Out[22]:
array([ 3.,  4.,  5.,  6.,  7.])

关于Python numba.jit 类型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/20714430/

相关文章:

python - 为什么创建此内存 View 仅在分配给变量时才会引发 ValueError ?

python - 遍历 n 维数组的通用函数

python - 在 anaconda 上更新 python

python - 在 cygwin 中使用 python.ctypes

python - 如何使字符串格式 float 具有固定精度和区域设置类型小数分隔符?

python - Numba `cache=True ` 无效

python - 当我运行Python程序时,pygame窗口打开片刻,然后退出

python - 使用 numba JIT 加速函数时遇到问题

python - 使用 numba 处理和操作 numpy 数组

Cython 编译器因数组崩溃——为什么?