python - 如何编译具有可变输入类型的 numba jit'ed 函数?

标签 python random signature optional-parameters numba

假设我有一个函数可以接受 intNone 类型作为输入参数

import numba as nb
import numpy as np

jitkw = {"nopython": True, "nogil": True, "error_model": "numpy", "fastmath": True}


@nb.jit("f8(i8)", **jitkw)
def get_random(seed=None):
    np.random.seed(None)
    out = np.random.normal()
    return out

我希望该函数只返回一个正态分布的随机数。如果我想要可重现的结果,种子应该是 int

get_random(42)
>>> 0.4967141530112327
get_random(42)
>>> 0.4967141530112327
get_random(42)
>>> 0.4967141530112327

如果我想要随机数,seed 应该保留为None。但是,如果我不传递参数(因此种子默认为 None)或显式传递 seed=None,则 numba 会引发 TypeError

get_random()
>>> TypeError: No matching definition for argument type(s) omitted(default=None)
get_random(None)
>>> TypeError: No matching definition for argument type(s) omitted(default=None)

对于这种情况,如何编写函数,仍然声明签名并使用 nopython 模式?

我的numba版本是0.43.1

最佳答案

第一个问题是 nopython 模式下的 numba 只接受(从版本 0.43.1 开始)np.random.seed: with an integer argument only .

因此,很遗憾,您不能传入 None


第二个问题是(据我所知)没有一个“单一”签名告诉 numba 如何处理缺失值,但是你可以使用两个签名(是的,它非常冗长):

import numba as nb
import numpy as np

jitkw = {"nopython": True, "nogil": True, "error_model": "numpy", "fastmath": True}

@nb.jit(
    [nb.types.float64(nb.types.misc.Omitted(None)), 
     nb.types.float64(nb.types.int64)], 
    **jitkw)
def get_random(seed=None):
    return np.random.normal()

简单解释一下signaure的两部分:

  • nb.types.float64(nb.types.misc.Omitted(None)) 告诉 numba 在省略参数时使用 None 作为默认类型 <
  • nb.types.float64(nb.types.int64) 是需要整数的签名。

就我个人而言,我不会指定签名,而只是让 numba 自行解决。显式签名在 numba 中很少值得,而且通常情况下它们会导致代码速度变慢且灵 active 降低。

关于python - 如何编译具有可变输入类型的 numba jit'ed 函数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55806542/

相关文章:

types - 为模块定义递归签名

c - 如何使函数返回动态分配的二维数组?

python - 按下 ESCape 时如何退出 python2.x 脚本

python - PyPi站 pip 缺少sys,subprocess和timeit软件包

python - 属性错误 : 'gurobipy.LinExpr' object has no attribute '__colno__'

c++ - 如何在C++中调用类的随机函数?

linux - 如何从文件中选择随机行

python - TensorFlow 2.0 如何从 tf.keras.layers 层获取可训练变量,如 Conv2D 或 Dense

random - 线性同余发生器的 3D 表示如何工作?

java - 有没有一种方法可以快速计算字符串的签名以帮助检测字符串更改?