python - Improving the runtime of Python NumPy code

Tags: python arrays performance numpy numba

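I have code that reassigns bins for a large numpy array. Basically, the elements of the large array were sampled at different frequencies, and the end goal is to rebin the whole array onto the fixed bins freq_bins. The code is somewhat slow for the array I have. Is there a good way to improve its runtime? A speedup of a few times would be enough for now. Perhaps some numba magic would do it.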

import time

import numpy as np

division = 90
freq_division = 50
cd = 3000
boost_factor = np.random.rand(division, division, cd)
freq_bins = np.linspace(1, 60, freq_division)
es = np.random.randint(1, 10, size=(cd, freq_division))
final_emit = np.zeros((division, division, freq_division))
time1 = time.time()
for i in range(division):  # xrange in the Python 2 original
    # Boosted frequency and boosted emission for every (row, source, bin).
    fre_boost = np.einsum('ij, k->ijk', boost_factor[i], freq_bins)
    sky_by_cap = np.einsum('ij, jk->ijk', boost_factor[i], es)
    # Destination bin of every boosted frequency.
    freq_index = np.digitize(fre_boost, freq_bins)
    freq_index_reshaped = freq_index.reshape(division*cd, -1)
    freq_index = None
    sky_by_cap_reshaped = sky_by_cap.reshape(freq_index_reshaped.shape)
    to_bin_emit = np.zeros(freq_index_reshaped.shape)
    row_index = np.arange(freq_index_reshaped.shape[0]).reshape(-1, 1)
    # Unbuffered scatter-add of the emission into the destination bins.
    np.add.at(to_bin_emit, (row_index, freq_index_reshaped), sky_by_cap_reshaped)
    to_bin_emit = to_bin_emit.reshape(fre_boost.shape)
    # Weight each bin by its frequency, then sum over the source axis.
    to_bin_emit = np.multiply(to_bin_emit, freq_bins, out=to_bin_emit)
    final_emit[i] = np.sum(to_bin_emit, axis=1)
print(time.time()-time1)

Best answer

Keep the code simple, then optimize

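If you know which algorithm you want to write, start with a simple reference implementation. From there you can go two ways in Python: you can try to vectorize the code, or you can compile it to get good performance.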
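As an illustration of what such a reference implementation might look like for the question's loop body, here is a minimal sketch. The function name rebin_reference is hypothetical, it handles one boost_factor[i] slab at a time, and it assumes (as in the question's test data) that every boost factor is below 1, so the bin index never reaches len(freq_bins).

```python
import numpy as np


def rebin_reference(boost_factor, freq_bins, es):
    """Plain-loop reference for one 2-D slab of boost factors.

    boost_factor: (n_rows, n_src) floats, assumed to lie in [0, 1)
    freq_bins:    (n_bins,) monotonically increasing bin edges
    es:           (n_src, n_bins) emission per source and bin
    Returns an (n_rows, n_bins) array.
    """
    n_rows, n_src = boost_factor.shape
    n_bins = freq_bins.shape[0]
    out = np.zeros((n_rows, n_bins))
    for i in range(n_rows):
        for j in range(n_src):
            for k in range(n_bins):
                # Bin that the boosted frequency falls into.
                ind = int(np.digitize(boost_factor[i, j] * freq_bins[k], freq_bins))
                # Accumulate the boosted emission, weighted by the
                # frequency of the destination bin.
                out[i, ind] += boost_factor[i, j] * es[j, k] * freq_bins[ind]
    return out
```

This version is far too slow for the full-size arrays, but on tiny inputs it gives a ground truth that a vectorized or compiled version can be compared against with np.allclose.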

Even if np.einsum and np.add.at were implemented in Numba, it would be hard for any compiler to generate efficient binary code from your example.

The only thing I rewrote is a more efficient way to digitize scalar values.

Edit

The Numba source code contains a more efficient implementation of digitize for scalar values.
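To see why a binary search is a valid replacement, note that for increasing bins and right=False, digitizing a scalar amounts to finding the insertion point to the right of any equal entries. The sketch below checks this in plain Python with the standard-library bisect module. The helper name digitize_scalar is hypothetical, and NaN handling is deliberately left out.

```python
import bisect

import numpy as np


def digitize_scalar(x, bins):
    """Bin index of scalar x in an increasing list of bin edges.

    Uses a binary search (O(log n)) instead of scanning every bin.
    NaN input is not handled in this sketch.
    """
    return bisect.bisect_right(bins, x)


bins = np.linspace(1, 60, 50)
bins_list = bins.tolist()
# Values below, on, between, and above the bin edges.
for x in [0.5, 1.0, 17.3, 59.9, 60.0, 75.0]:
    assert digitize_scalar(x, bins_list) == np.digitize(x, bins)
```

With 50 bins the binary search needs at most about six comparisons per value, where a linear scan may need up to 50.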

Code

import numba as nb
import numpy as np

# From Numba source
# Copyright (c) 2012, Anaconda, Inc.
# All rights reserved.

@nb.njit(fastmath=True)
def digitize(x, bins, right=False):
    # bins are monotonically-increasing
    n = len(bins)
    lo = 0
    hi = n

    if right:
        if np.isnan(x):
            # Find the first nan (i.e. the last from the end of bins,
            # since there shouldn't be many of them in practice)
            for i in range(n, 0, -1):
                if not np.isnan(bins[i - 1]):
                    return i
            return 0
        while hi > lo:
            mid = (lo + hi) >> 1
            if bins[mid] < x:
                # mid is too low => narrow to upper bins
                lo = mid + 1
            else:
                # mid is too high, or is a NaN => narrow to lower bins
                hi = mid
    else:
        if np.isnan(x):
            # NaNs end up in the last bin
            return n
        while hi > lo:
            mid = (lo + hi) >> 1
            if bins[mid] <= x:
                # mid is too low => narrow to upper bins
                lo = mid + 1
            else:
                # mid is too high, or is a NaN => narrow to lower bins
                hi = mid

    return lo

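# Earlier, simpler linear-scan variant (the 4.14s timing below). Renamed so
# that it does not shadow the binary-search digitize above; to use it, call
# digitize_linear instead of digitize in inner_loop.
@nb.njit(fastmath=True)
def digitize_linear(value, bins):
  if value<bins[0]:
    return 0

  if value>=bins[bins.shape[0]-1]:
    return bins.shape[0]

  for l in range(1,bins.shape[0]):
    if value>=bins[l-1] and value<bins[l]:
      return l

  # Fallback so that every path returns an integer (reached only if no
  # bin matched, e.g. for NaN input).
  return bins.shape[0]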

@nb.njit(fastmath=True,parallel=True)
def inner_loop(boost_factor,freq_bins,es):
  res=np.zeros((boost_factor.shape[0],freq_bins.shape[0]),dtype=np.float64)
  for i in nb.prange(boost_factor.shape[0]):
    for j in range(boost_factor.shape[1]):
      for k in range(freq_bins.shape[0]):
        ind=nb.int64(digitize(boost_factor[i,j]*freq_bins[k],freq_bins))
        res[i,ind]+=boost_factor[i,j]*es[j,k]*freq_bins[ind]
  return res

@nb.njit(fastmath=True)
def calc_nb(division,freq_division,cd,boost_factor,freq_bins,es):
  final_emit = np.empty((division, division, freq_division),np.float64)
  for i in range(division):
    final_emit[i,:,:]=inner_loop(boost_factor[i],freq_bins,es)
  return final_emit

Performance

(Quad-core i7)
original_code: 118.5s
calc_nb: 4.14s
# with the digitize implementation from the Numba source
calc_nb: 2.66s
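The answer does not say whether these timings include Numba's compilation, which happens on the first call of a jitted function. One way to measure only the steady-state runtime is to make a warm-up call before timing. The helper below is a minimal sketch with a hypothetical name, time_after_warmup, and it works for any callable.

```python
import time


def time_after_warmup(func, *args, repeats=3):
    """Return the best wall-clock time of func(*args) in seconds.

    One untimed warm-up call absorbs one-time costs such as JIT
    compilation; the minimum over `repeats` timed calls is returned.
    """
    func(*args)  # warm-up call, result discarded
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        func(*args)
        best = min(best, time.perf_counter() - start)
    return best
```

For the code above, a call might look like time_after_warmup(calc_nb, division, freq_division, cd, boost_factor, freq_bins, es).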

This question and its answer about improving the runtime of Python NumPy code can be found on Stack Overflow: https://stackoverflow.com/questions/50459216/
