python - Numba 和泊松分布的随机数

标签 python python-2.7 numpy random numba

我发现模拟中的瓶颈之一是根据泊松分布生成随机数。我原来的代码是这样的

import numpy as np
#Generating some data. In the actual code this comes from the previous
#steps in the simulation. But this gives an example of the type of data
n = 5000000
pop_n = np.array([range(500000)])

pop_n[:] = np.random.poisson(lam=n*pop_n/np.sum(pop_n))

现在,我读到 numba 可以非常简单地提高速度。我定义了这个函数

from numba import jit

@jit()
def poisson(n, pop_n, np=np):
    return np.random.poisson(lam=n*pop_n/np.sum(pop_n))

这个确实比原来的运行得更快。然而,我尝试走得更远:)当我写的时候

@jit(nopython=True)
def poisson(n, pop_n, np=np):
    return np.random.poisson(lam=n*pop_n/np.sum(pop_n))

我得到了

Failed at nopython (nopython frontend)
Invalid usage of Function(np.random.poisson) with parameters     (array(float64, 1d, C))
Known signatures:
 * (float64,) -> int64
 * () -> int64
 * parameterized

一些问题 为什么会发生此错误以及如何修复它。

还有更好的优化吗?

最佳答案

Numba 不支持数组作为 np.random.poissonlam 参数,因此您必须自己执行循环:

import numba as nb
import numpy as np

@nb.njit
def poisson(n, pop_n):
    res = np.empty_like(pop_n)
    pop_n_sum = np.sum(pop_n)
    for idx, item in enumerate(range(pop_n.shape[0])):
        res[idx] = np.random.poisson(n*pop_n[idx] / pop_n_sum)
    return res

n = 5000000
pop_n = np.array(list(range(1, 500000)), dtype=float)
poisson(n, pop_n)

但根据我的计时,这与使用纯 NumPy 一样快:

%timeit poisson(n, pop_n)
# 203 ms ± 1.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit np.random.poisson(lam=n*pop_n/np.sum(pop_n))
# 203 ms ± 3.97 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

这是因为,尽管 Numba 支持 np.random.poisson 和诸如 np.sum 之类的函数,但这些功能只是为了方便而支持,实际上并没有加快代码速度(很多) 。它可能可以在某种程度上避免函数调用开销,但考虑到它只会在纯 Python 中调用 np.random.poisson 一次,这并不算多(与创建 50 万个随机数相比,完全可以忽略不计) )。

如果您想加快一个无法用纯 NumPy 完成的循环,Numba 的速度非常快,但您不应该指望 numba(或其他任何东西)可以提供相当于同等速度的主要加速NumPy 函数。如果可以轻松地让它们更快 - NumPy 开发人员也会让它更快。 :)

关于python - Numba 和泊松分布的随机数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45548951/

相关文章:

Python - 批量选择然后从一个数据库插入到另一个数据库

python - 使用 python 编辑 outlook 分发列表

python - 只有整数、切片 (`:` )、省略号 (`...` )、numpy.newaxis (`None` ) 和整数数组是有效索引

python - 将列表从数据帧转换为 numpy 数组

python - 如何使用lark ebnf解析字符串内的~{expr}

python - 从 web.py 中解压参数列表并实例化 WTForms 对象

python - 调用保存在类属性中的函数 : different behavior with built-in function vs. 普通函数

python - pandas 将字符串列转换为日期时间,允许丢失但不无效

python - 素性测试比暴力法花费的时间更长,我该如何改进?

python - 使用python识别垃圾unicode字符串