python - Calculate histograms along an axis

Tags: python performance numpy scipy vectorization

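Is there a way to calculate many histograms along one axis of an nD array? My current approach uses a for loop to iterate over all the other axes and calls numpy.histogram() on each resulting 1D array: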

import numpy
import itertools
data = numpy.random.rand(4, 5, 6)

# axis=-1, place `200001` and `[slice(None)]` on any other position to process along other axes
out = numpy.zeros((4, 5, 200001), dtype="int64")
indices = [
    numpy.arange(4), numpy.arange(5), [slice(None)]
]

# Iterate over all axes, calculate histogram for each cell
for idx in itertools.product(*indices):
    out[idx] = numpy.histogram(
        data[idx],
        bins=2 * 100000 + 1,
        range=(-100000 - 0.5, 100000 + 0.5),
    )[0]

out.shape  # (4, 5, 200001)

Needless to say, this is very slow, but I could not find a way to solve this with numpy.histogram, numpy.histogram2d or numpy.histogramdd.
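For reference, here is a minimal sketch of the same loop written more compactly with np.apply_along_axis, which also works for any axis. It still calls np.histogram once per 1D slice (the loop just moves inside apply_along_axis), so it is not expected to be faster; the helper name hist_apply_along_axis is hypothetical.

```python
import numpy as np

def hist_apply_along_axis(data, n_bins, range_limits, axis=-1):
    # Hypothetical helper: one np.histogram call per 1D slice along `axis`.
    # This is a readability change, not a speedup.
    def one_hist(a):
        return np.histogram(a, bins=n_bins, range=range_limits)[0]
    # apply_along_axis replaces the `axis` dimension with the n_bins counts
    return np.apply_along_axis(one_hist, axis, data)

data = np.random.rand(4, 5, 6)
out = hist_apply_along_axis(data, n_bins=10, range_limits=(0.0, 1.0), axis=-1)
print(out.shape)  # (4, 5, 10)
```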

Best answer

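Here's a vectorized approach using the efficient tools np.searchsorted and np.bincount. searchsorted tells us which bin each element falls into, given the bin edges, and bincount does the counting for us.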
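To make the two steps concrete, here is a small toy sketch (the data and sizes are made up for illustration) of how searchsorted assigns bin indices and how offsetting each row by row * n_bins lets a single bincount produce one histogram per row:

```python
import numpy as np

# Toy data: 2 rows (one histogram each) of 4 values, 3 bins over [0, 3]
n_bins = 3
edges = np.linspace(0.0, 3.0, n_bins + 1)      # [0., 1., 2., 3.]
rows = np.array([[0.5, 1.5, 1.7, 2.2],
                 [0.1, 0.2, 2.9, 2.5]])

# searchsorted(..., side='right') - 1 gives the bin index of each value
idx = np.searchsorted(edges, rows, side='right') - 1
# idx -> [[0, 1, 1, 2], [0, 0, 2, 2]]

# Offset row r by r * n_bins so each row gets its own block of bin IDs
ids = idx + n_bins * np.arange(rows.shape[0])[:, None]
# ids -> [[0, 1, 1, 2], [3, 3, 5, 5]]

# One bincount over all IDs, then split back into one histogram per row
counts = np.bincount(ids.ravel(), minlength=n_bins * rows.shape[0])
counts = counts.reshape(rows.shape[0], n_bins)
print(counts)  # rows: [1 2 1] and [2 0 2]
```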

Implementation -

import numpy as np

def hist_laxis(data, n_bins, range_limits):
    # Setup bins and determine the bin location for each element for the bins
    R = range_limits
    N = data.shape[-1]
    bins = np.linspace(R[0],R[1],n_bins+1)
    data2D = data.reshape(-1,N)
    idx = np.searchsorted(bins, data2D,'right')-1

    # Like np.histogram, count values equal to the upper limit in the last bin
    idx[data2D == R[1]] = n_bins-1

    # Some elements would be out of range, so get a mask for those
    bad_mask = (idx==-1) | (idx==n_bins)

    # We need to use bincount to get bin based counts. To have unique IDs for
    # each row and not get confused by the ones from other rows, we need to 
    # offset each row by a scale (using the number of bins for this).
    scaled_idx = n_bins*np.arange(data2D.shape[0])[:,None] + idx

    # Set the bad ones to be last possible index+1 : n_bins*data2D.shape[0]
    limit = n_bins*data2D.shape[0]
    scaled_idx[bad_mask] = limit

    # Get the counts (dropping the out-of-range slot) and reshape to multi-dim
    counts = np.bincount(scaled_idx.ravel(),minlength=limit+1)[:-1]
    return counts.reshape(data.shape[:-1] + (n_bins,))
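hist_laxis only works along the last axis, while the question asks about arbitrary axes. A minimal sketch, assuming it is acceptable to move the target axis to the end with np.moveaxis, apply the same searchsorted + bincount idea, and move the bin axis back. The name hist_axis is hypothetical, and the body repeats the logic above so the block runs on its own:

```python
import numpy as np

def hist_axis(data, n_bins, range_limits, axis=-1):
    """Histogram every 1D slice of `data` along `axis` (hypothetical helper).

    The output has the `axis` dimension replaced by `n_bins` counts.
    """
    lo, hi = range_limits
    # Bring the histogram axis to the end and flatten the rest into rows
    moved = np.moveaxis(np.asarray(data, dtype=float), axis, -1)
    lead_shape = moved.shape[:-1]
    rows = moved.reshape(-1, moved.shape[-1])
    n_rows = rows.shape[0]

    # Bin index of every element (same edges np.histogram would use)
    edges = np.linspace(lo, hi, n_bins + 1)
    idx = np.searchsorted(edges, rows, side='right') - 1
    # Match np.histogram: the right-most edge belongs to the last bin
    idx[rows == hi] = n_bins - 1

    # Give each row its own block of IDs; out-of-range values go to a
    # single sentinel ID that is dropped after counting
    bad = (idx < 0) | (idx >= n_bins)
    ids = idx + n_bins * np.arange(n_rows)[:, None]
    ids[bad] = n_bins * n_rows

    counts = np.bincount(ids.ravel(), minlength=n_bins * n_rows + 1)[:-1]
    counts = counts.reshape(lead_shape + (n_bins,))
    # Put the bin axis back where the original axis was
    return np.moveaxis(counts, -1, axis)
```

For example, hist_axis(data, 21, (-2.5, 2.5), axis=0) on a (400, 500, 6) array would return shape (21, 500, 6).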

Runtime test

Original approach -

import itertools
import numpy as np

def org_app(data, n_bins, range_limits):
    R = range_limits
    m,n = data.shape[:2]
    out = np.zeros((m, n, n_bins), dtype="int64")
    indices = [
        np.arange(m), np.arange(n), [slice(None)]
    ]

    # Iterate over all axes, calculate histogram for each cell
    for idx in itertools.product(*indices):
        out[idx] = np.histogram(
            data[idx],
            bins=n_bins,
            range=(R[0], R[1]),
        )[0]
    return out

Timings and verification -

In [2]: data = np.random.randn(4, 5, 6)
   ...: out1 = org_app(data, n_bins=200001, range_limits=(- 2.5, 2.5))
   ...: out2 = hist_laxis(data, n_bins=200001, range_limits=(- 2.5, 2.5))
   ...: print(np.allclose(out1, out2))
   ...: 
True

In [3]: %timeit org_app(data, n_bins=200001, range_limits=(- 2.5, 2.5))
10 loops, best of 3: 39.3 ms per loop

In [4]: %timeit hist_laxis(data, n_bins=200001, range_limits=(- 2.5, 2.5))
100 loops, best of 3: 3.17 ms per loop

The loop solution iterates over the first two axes, so let's increase their lengths, since that shows how much the vectorization pays off -

In [59]: data = np.random.randn(400, 500, 6)

In [60]: %timeit org_app(data, n_bins=21, range_limits=(- 2.5, 2.5))
1 loops, best of 3: 9.59 s per loop

In [61]: %timeit hist_laxis(data, n_bins=21, range_limits=(- 2.5, 2.5))
10 loops, best of 3: 44.2 ms per loop

In [62]: 9590/44.2          # Speedup number
Out[62]: 216.9683257918552

Regarding python - Calculate histograms along an axis, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/44152436/
