python - 在 numpy 中对分区索引分组 argmax/argmin

标签 python numpy

Numpy 的 ufunc我们有一个 reduceat 在数组中的连续分区上运行它们的方法。所以不要写:

import numpy as np
a = np.array([4, 0, 6, 8, 0, 9, 8, 5, 4, 9])
split_at = [4, 5]
maxima = [max(subarray for subarray in np.split(a, split_at)]

我会写:

maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))

两者都将返回切片中的最大值 a[0:4] , a[4:5] , a[5:10] , 是 [8, 0, 9] .

我想要一个类似的函数来执行 argmax ,注意到我只想在每个分区中有一个单个最大索引:[3, 4, 5]与上述 asplit_at (尽管索引 5 和 9 都获得了最后一组中的最大值),正如

np.hstack([0, split_at]) + [np.argmax(subarray) for subarray in np.split(a, split_at)]

我将在下面发布一个可能的解决方案,但我希望看到一个无需在组上创建索引即可矢量化的解决方案。

最佳答案

此解决方案涉及在组(上例中的 [0, 0, 0, 0, 1, 2, 2, 2, 2, 2])上构建索引。

group_lengths = np.diff(np.hstack([0, split_at, len(a)]))
n_groups = len(group_lengths)
index = np.repeat(np.arange(n_groups), group_lengths)

然后我们可以使用:

maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))
all_argmax = np.flatnonzero(np.repeat(maxima, group_lengths) == a)
result = np.empty(len(group_lengths), dtype='i')
result[index[all_argmax[::-1]]] = all_argmax[::-1]

result 中得到 [3, 4, 5][::-1] 确保我们得到每个组中的 第一个 而不是最后一个 argmax。

这依赖于花式赋值中的最后一个索引决定赋值的事实,@seberg says one shouldn't rely on (使用 result = all_argmax[np.unique(index[all_argmax], return_index=True)[1]] 可以实现更安全的替代方案,它涉及对 len(maxima) 的排序~ n_groups 个元素)。

关于python - 在 numpy 中对分区索引分组 argmax/argmin,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/22124332/

相关文章:

python - 查找 numpy 切片数组的父数组

python - 根据值数组的条件在 pandas 中分配值

python - scikit-learn GridSearchCV best_score_ 是如何计算的?

python - 将unicode插入sqlite?

python - 通过返回非预定义字段。 Django 中的 Tastypie API

python - 数组随机打乱,但保持对角线固定

python - pandas 中的复杂旋转

python - 矢量化在一个数组中为另一个数组中的每个元素查找最接近的值

python - 从 python 中的多维数组中删除重复条目

python - QGIS 渲染 shapefile 时如何让 Python 等待