python - np.random.permutation, np.random.choice 的时间表现

标签 python numpy random

在我的纯 Python 图论库中,我遇到了一个函数,其时间性能与可比较的 MATLAB 代码相比非常差,因此我尝试分析该函数中的一些操作。

我跟踪到以下结果

In [27]: timeit.timeit( 'permutation(138)[:4]', setup='from numpy.random import permutation', number=1000000)
Out[27]: 27.659916877746582

将其与 MATLAB 中的性能进行比较

>> tic; for i=1:1000000; randperm(138,4); end; toc
Elapsed time is 4.593305 seconds.

通过将其更改为 np.random.choice 而不是我最初编写的 np.random.permutation,我能够显着提高性能。

In [42]: timeit.timeit( 'choice(138, 4)', setup='from numpy.random import choice', number=1000000)
Out[42]: 18.9618501663208

但它仍然没有接近 matlab 的性能。

是否有另一种方法可以在时间性能接近 MATLAB 时间性能的纯 Python 中获得这种行为?

最佳答案

基于 this solution这展示了如何模拟 np.random.choice(..., replace=False)使用基于 argsort/argpartition 的技巧的行为,您可以重新创建 MATLAB 的 randperm(138,4),即 NumPy 的 np .random.choice(138,4, replace=False)np.argpartition作为:

np.random.rand(138).argpartition(range(4))[:4]

或用np.argsort像这样 -

np.random.rand(138).argsort()[:4]

让我们对这两个版本进行计时,以便与 MATLAB 版本进行性能比较。

在 MATLAB 上 -

>> tic; for i=1:1000000; randperm(138,4); end; toc
Elapsed time is 1.058177 seconds.

在 NumPy 上使用 np.argpartition -

In [361]: timeit.timeit( 'np.random.rand(138).argpartition(range(4))[:4]', setup='import numpy as np', number=1000000)
Out[361]: 9.063489798831142

在 NumPy 上使用 np.argsort -

In [362]: timeit.timeit( 'np.random.rand(138).argsort()[:4]', setup='import numpy as np', number=1000000)
Out[362]: 5.74625801707225

最初建议用 NumPy 的 -

In [363]: timeit.timeit( 'choice(138, 4)', setup='from numpy.random import choice', number=1000000)
Out[363]: 6.793723535243771

似乎可以使用 np.argsort 来提高边际性能。

关于python - np.random.permutation, np.random.choice 的时间表现,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35591462/

相关文章:

c# - 获取两个值之间的 n 个不同的随机数,其总和等于给定数

c++ - std::mersenne_twister_engine 和随机数生成

python - 属性错误 : 'module' object has no attribute 'load'

python - 我应该如何对这些元素进行分组,以使总体差异最小化?

python - Anaconda 3 安装错误 - 没有脚本文件夹,也没有 conda 命令提示符和快捷方式

python - 在 Python 中读取 csv 文件时获取 "newline inside string"?

numpy - 为什么 cv2.calcOpticalFlowFarneback 在简单的合成示例上失败?

python - 如何找到第一次出现的 Pandas 数据框值的显着差异?

python - 在同一地址分配的数组 Cython + Numpy

c - 指向循环结构体的指针数组