python - 替代 `numpy.tile` 用于周期性掩码

标签 python numpy

我有一张图像,存储在 uint8 的 numpy 数组中,形状为 (planes, rows, cols)。我需要将它与存储在掩码中的值进行比较,掩码也是 uint8,形状为 (mask_rows, mask_cols)。虽然图像可能非常大,但掩码通常很小,通常是 (256, 256) 并且平铺在 image 上。为了简化代码,我们假设 rows = 100 * mask_rowscols = 100 * mask_cols

我目前处理这个阈值的方式是这样的:

out = image >= np.tile(mask, (image.shape[0], 100, 100))

在遇到 MemoryError 打脸之前,我可以用这种方式处理的最大数组比 (3, 11100, 11100) 大一点。按照我的想法,按照这种方式做事,我最多可以在内存中共存三个巨大的数组:image、平铺的 mask 和我的返回 out。但是平铺掩码是同一个小数组被复制了 10,000 多次。因此,如果我可以节省该内存,我将只使用 2/3 的内存,并且应该能够处理大 3/2 的图像,因此大小约为 (3, 13600, 13600)。顺便说一句,这与我使用

进行阈值处理时得到的结果一致
np.greater_equal(image, (image.shape[0], 100, 100), out=image)

我(失败的)尝试利用 mask 的周期性特性来处理更大的数组是用周期性线性数组索引 mask:

mask = mask[None, ...]
rows = np.tile(np.arange(mask.shape[1], (100,))).reshape(1, -1, 1)
cols = np.tile(np.arange(mask.shape[2], (100,))).reshape(1, 1, -1)
out = image >= mask[:, rows, cols]

对于小型阵列,它确实产生了与另一个阵列相同的结果,尽管速度降低了 20 倍(!!!),但对于较大的阵列,它的表现却非常糟糕。而不是 MemoryError 它最终使 python 崩溃,即使对于其他方法处理的值也没有问题。

我认为发生的事情是 numpy 实际上正在构造 (planes, rows, cols) 数组来索引 mask,所以不仅没有节省内存,但由于它是一个 int32 数组,它实际上需要四倍多的空间来存储...

关于如何解决这个问题有什么想法吗?为了省去您的麻烦,请在下面找到一些沙盒代码来尝试一下:

import numpy as np

def halftone_1(image, mask) :
    return np.greater_equal(image, np.tile(mask, (image.shape[0], 100, 100)))

def halftone_2(image, mask) :
    mask = mask[None, ...]
    rows = np.tile(np.arange(mask.shape[1]),
                   (100,)).reshape(1, -1, 1)
    cols = np.tile(np.arange(mask.shape[2]),
                   (100,)).reshape(1, 1, -1)
    return np.greater_equal(image, mask[:, rows, cols])

rows, cols, planes = 6000, 6000, 3
image = np.random.randint(-2**31, 2**31 - 1, size=(planes * rows * cols // 4))
image = image.view(dtype='uint8').reshape(planes, rows, cols)
mask = np.random.randint(256,
                         size=(1, rows // 100, cols // 100)).astype('uint8')

#np.all(halftone_1(image, mask) == halftone_2(image, mask))
#halftone_1(image, mask)
#halftone_2(image, mask)

import timeit
print timeit.timeit('halftone_1(image, mask)',
                    'from __main__ import halftone_1, image, mask',
                    number=1)
print timeit.timeit('halftone_2(image, mask)',
                    'from __main__ import halftone_2, image, mask',
                    number=1)

最佳答案

我差点给你指点一个 rolling window类型的技巧,但对于这个简单的非重叠事物,正常 reshape 也一样。 (这里的 reshape 是安全的,numpy 永远不会为他们制作副本)

def halftone_reshape(image, mask):
    # you can make up a nicer reshape code maybe, it is a bit ugly. The
    # rolling window code can do this too (but much more general then reshape).
    new_shape = np.array(zip(image.shape, mask.shape))
    new_shape[:,0] /= new_shape[:,1]
    reshaped_image = image.reshape(new_shape.ravel())

    reshaped_mask = mask[None,:,None,:,None,:]

    # and now they just broadcast:
    result_funny_shaped = reshaped_image >= reshaped_mask

    # And you can just reshape it back:
    return result_funny_shaped.reshape(image.shape)

因为时间就是一切(不是真的但是......):

In [172]: %timeit halftone_reshape(image, mask)
1 loops, best of 3: 280 ms per loop

In [173]: %timeit halftone_1(image, mask)
1 loops, best of 3: 354 ms per loop

In [174]: %timeit halftone_2(image, mask)
1 loops, best of 3: 3.1 s per loop

关于python - 替代 `numpy.tile` 用于周期性掩码,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/14466534/

相关文章:

python - 如何查找在 Django-admin 中存储为整数的自定义 ip 地址字段?

javascript - iPhone 的 Python 到 JavaScript

python - 需要一些关于如何实现基于 golang 的 restful api 应用程序的帮助

python - 以不同顺序对多个列上的结构化 Numpy 数组进行排序

Python 和 Numba : incorrect checksum for freed object

python - 是否可以使用 NumPy 或 Python 中的其他包获取 double 二进制浮点的保留特殊数字

python - 多对多使用 Flask-SQLAlchemy 返回原始 sql 而不是执行

python - 无法在 sympy 中绘制 d(e^-|t|)/dt 的傅立叶变换

python - 使用 Matplotlib 和 Numpy,有没有办法找到线性方程的所有直线交点?

python - numpy中的加权协方差矩阵