python - 从 3d 数组中提取 2d 面片

标签 python numpy scikit-learn

scikit-learn 的 extract_patches_2d可用于将 2D 图像 reshape 为补丁集合。 extract_patches是使用 NumPy 的 as_strided 的通用形式。

import numpy as np
from sklearn.feature_extraction import image

ex = np.arange(3 * 3).reshape(3, 3)
image.extract_patches_2d(ex, patch_size=(2, 2))
[[[0 1]
  [3 4]]

 [[1 2]
  [4 5]]

 [[3 4]
  [6 7]]

 [[4 5]
  [7 8]]]

我有一个三维数组a,想从每个“最里面”的二维数组中提取二维补丁,然后找到每个二维补丁的(与轴无关的)平均值。

a = np.arange(2 * 3 * 3).reshape(2, 3, 3)

在这种情况下,我实际上希望首先对每个 (3, 3) 内部数组调用 extract_patches_2d

patches = np.array([image.extract_patches_2d(i, patch_size=(2, 2)) for i in a])

然后找到每个最里面的二维数组(每个补丁)的平均值:

means = patches.reshape(*patches.shape[:-2], -1).mean(axis=-1)
print(means)
[[  2.   3.   5.   6.]
 [ 11.  12.  14.  15.]]

如何对其进行向量化并摆脱上面的 for 循环?这里重要的是,means 第一个维度的大小等于 a 第一个维度的大小。

最佳答案

您可以使用scikit-image as view_as_windows将这些补丁作为 View 放入输入数组中 -

from skimage.util.shape import view_as_windows

size = 2 # patch size
patches = view_as_windows(a, (1,size,size))[...,0,:,:]

这为我们提供了一个 5D 数组作为 patches,我们可以在该数组上沿最后两个轴使用 mean 缩减来实现 3D 输出 -

out = patches.mean((-2,-1))

如果需要将最终输出作为2D输出,请 reshape 以合并最后两个轴 -

out.reshape(a.shape[0],-1)

这也可以利用 sklearnextract_patches :

def inner_means(arr_3d, patch_size):
    """Axis-agnostic mean of each 2d patch.

    Maintains the first dimension of `arr_3d`.

    patch_size: tuple
        Same syntax as the parameter passed to extract_patches_2d
    """
    shape = (1,) + patch_size
    patches = image.extract_patches(arr_3d, shape)[..., 0, :, :].mean((-2, -1))
    return patches.reshape(*patches.shape[:-2], -1)


a = np.arange(2 * 3 * 3).reshape(2, 3, 3)
    print(inner_means(a, patch_size=(2, 2)))

[[  2.   3.   5.   6.]
 [ 11.  12.  14.  15.]]
<小时/>

或者,为了直接获得 block 状平均值,我们可以使用 Scipy 中的卷积工具之一。所以用 fftconvolve -

from scipy.signal import fftconvolve

out = fftconvolve(a, np.ones((1,size,size)),mode='valid')/size**2

或者使用scipy.signal.convolvescipy.ndimage.filters.uniform_filter而不进行除法。

关于python - 从 3d 数组中提取 2d 面片,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48709010/

相关文章:

python - 如何使用 python 从该矩阵中删除行/列

python-3.x - 导入错误: cannot import name 'multilabel_confusion_matrix'

python - 使用 scipy.signal.firwin 将高通滤波器应用于 WAV 文件

python - Pyramid :使用 `view_config` 注册的 View 未与路由相关联

python - 如何创建更有效的 bool 逻辑代码来将一列的多行与另一列进行比较?

python - 应用掩码来加速各种数组计算

python - 情感分析管道,使用特征选择时获取正确特征名称的问题

python - Sklearn 高斯混合锁定参数?

python - 如何找到第n次出现在列表中的项目的索引?

python - 如何在Python中记录xbox/游戏 handle Controller 的状态?