python - 沿第 3 维从 Numpy ndarray 中删除并提取最小值

标签 python arrays numpy

我有一个形状为 (3, 4, 3) 的数组,它表示一个 3x4 像素的图像和 3 个颜色 channel (最后一个索引)。我的目标是留下一个 (3, 4, 2) 数组,其中每个像素都删除了最低的颜色 channel 。我可以进行逐像素迭代,但这会非常耗时。使用 np.argmin 我可以轻松提取最小值的索引,这样我就知道哪个颜色 channel 包含每个像素的最小值。但是,我找不到一种巧妙的索引方式来删除这些值,因此我只剩下一个 (3, 4, 2) 数组。

此外,我还尝试使用 array[:, :, indexMin)] 之类的方法来选择最小值,但无法获得所需的形状数组 (3, 4 ) 包含每个像素的最小 channel 值。我知道为此有一个函数 np.amin,但它会让我更好地理解 Numpy 数组。下面提供了我的代码结构的最小示例:

import numpy as np

arr = [[1, 2, 3, 4],
       [5, 6, 7, 8],
       [9, 10, 11, 12]]
array = np.zeros((3, 4, 3))
array[:,:,0] = arr
array[:,:,1] = np.fliplr(arr)
array[:,:,2] = np.flipud(arr)
indexMin = np.argmin(array, axis=2)

最佳答案

您需要能够正确广播到您想要的输出形状的数组。您可以使用 np.expand_dims 添加缺失的尺寸。 :

index = np.expand_dims(np.argmin(array, axis=2), axis=2)

这使得设置或提取要删除的元素变得容易:

index = list(np.indices(array.shape, sparse=True))
index[-1] = np.expand_dims(np.argmin(array, axis=2), axis=2)
minima = array[tuple(index)]

np.indices with sparse=True 返回一组范围,这些范围的形状可以在每个维度中正确广播索引。更好的选择是使用 np.take_along_axis :

index = np.expand_dims(np.argmin(array, axis=2), axis=2)
minima = np.take_along_axis(array, index, axis=2)

您可以使用这些结果来创建 mask ,例如与 np.put_along_axis :

mask = np.ones(array.shape, dtype=bool)
np.put_along_axis(mask, index, 0, axis=2)

使用掩码索引数组可以得到:

result = array[mask].reshape(*array.shape[:2], -1)

reshape 之所以有效,是因为您的像素存储在最后一个维度中,该维度在内存中应该是连续的。这意味着掩码正确地删除了三个元素中的一个,因此在内存中正确排序。这在屏蔽操作中并不常见。

另一种选择是使用 np.delete用一个乱七八糟的数组和np.ravel_multi_index :

i = np.indices(array.shape[:2], sparse=True)
index = np.ravel_multi_index((*i, np.argmin(array, axis=2)), array.shape)
result = np.delete(array.ravel(), index).reshape(*array.shape[:2], -1)

只是为了好玩,您可以利用每个像素只有三个元素的事实来创建要保留的元素的完整索引。这个想法是所有三个索引的总和是 3。因此,3 - np.argmin(array, axis=2) - np.argmax(array, axis=2) 是中位元素。如果你堆叠中位数和最大值,你会得到一个类似于 sort 给你的索引:

amax = np.argmax(array, axis=2)
amin = np.argmin(array, axis=2)
index = np.stack((np.clip(3 - (amin + amax), 0, 2), amax), axis=2)
result = np.take_along_axis(array, index, axis=2)

调用np.clip有必要处理所有元素都相等的情况,在这种情况下 argmaxargmin 都返回零。

时间

比较方法:

def remove_min_indices(array):
    index = list(np.indices(array.shape, sparse=True))
    index[-1] = np.expand_dims(np.argmin(array, axis=2), axis=2)
    mask = np.ones(array.shape, dtype=bool)
    mask[tuple(index)] = False
    return array[mask].reshape(*array.shape[:2], -1)

def remove_min_put(array):
    mask = np.ones(array.shape, dtype=bool)
    np.put_along_axis(mask, np.expand_dims(np.argmin(array, axis=2), axis=2), 0, axis=2)
    return array[mask].reshape(*array.shape[:2], -1)

def remove_min_delete(array):
    i = np.indices(array.shape[:2], sparse=True)
    index = np.ravel_multi_index((*i, np.argmin(array, axis=2)), array.shape)
    return np.delete(array.ravel(), index).reshape(*array.shape[:2], -1)

def remove_min_sort_c(array):
    return np.sort(array, axis=2)[..., 1:]

def remove_min_sort_i(array):
    array.sort(axis=2)
    return array[..., 1:]

def remove_min_median(array):
    amax = np.argmax(array, axis=2)
    amin = np.argmin(array, axis=2)
    index = np.stack((np.clip(3 - (amin + amax), 0, 2), amax), axis=2)
    return np.take_along_axis(array, index, axis=2)

测试了像 array = np.random.randint(10, size=(N, N, 3), dtype=np.uint8) for N 这样的数组在 {100, 1K, 10K, 100K, 1M} 中:

  N  |   IND   |   PUT   |   DEL   |   S_C   |   S_I   |   MED   |
-----+---------+---------+---------+---------+---------+---------+
100  | 648. µs | 658. µs | 765. µs | 497. µs | 246. µs | 905. µs |
  1K | 67.9 ms | 68.1 ms | 85.7 ms | 51.7 ms | 24.0 ms | 123. ms |
 10K | 6.86 s  | 6.86 s  | 8.72 s  | 5.17 s  | 2.39 s  | 13.2 s  |
-----+---------+---------+---------+---------+---------+---------+

如预期的那样,时间比例为 N^2。排序返回的结果与其他方法不同,但显然是最有效的。使用 put_along_axis 进行掩码似乎是对较大数组更有效的方法,而原始索引似乎对较小数组更有效。

关于python - 沿第 3 维从 Numpy ndarray 中删除并提取最小值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64336708/

相关文章:

javascript - 如何获得 L.polygon 边界,不包括内部点

python - scipy.optimize.minimize 返回零维数组?

java - 为什么无法执行包含 Java 中的 String[] 的 Julia 脚本

c++ - 像数组一样初始化类对象

python - 跳过没有装饰器语法的单元测试测试

对于 find 命令,Python 子进程未按预期工作

python - Numpy.where 与值列表一起使用

python - 为什么我们可以为 NumPy 数组的形状传递一个列表?形状不是必须是整数或整数元组吗?

python - 有没有办法将 Protocol Buffer 编译成纯 python 代码?

python - Sphinx搜索引擎是否正式支持Python API?