python - NumPy:将索引选择功能应用于多个轴

标签 python arrays numpy indexing

我正在编写代码来优化取决于可变数量参数的数量。为了优化,我想一次跨多个轴应用索引选择函数,例如 numpy.argmax 和 numpy.argmin。下面是我现在正在使用的代码。是否有更内置或更有效的方法来跨任意数量的轴执行此任务,这些轴可能是顺序的,也可能不是顺序的?

def nd_arg_axes(func, array, start):
    """Applies an index selecting function over trailing axes from start."""

    n_trail = len(array.shape[start:])  # Number of trailing axes to apply to.

    indices = np.zeros((n_trail,)+array.shape[:start], dtype=np.intp)
    for i in np.ndindex(array.shape[:start]):
        indices[(Ellipsis,)+i] = np.unravel_index(func(array[i]),
                                               array.shape[start:])
    return tuple(indices)

# Test showing nd_arg_axes does indeed return the correct indices.
array = np.arange(27).reshape(3,3,3)
max_js = nd_arg_axes(np.argmax, array, 1)

(array[tuple(np.indices(array))+max_js] ==
np.squeeze(np.apply_over_axes(np.amax, array, axes=[1,2])))

最佳答案

如果您选择了尾随轴,您可以将尾随轴的形状 reshape 为 -1,并将 func 应用于 axis=-1:

def f(func, array, start):
    shape = array.shape
    tmp = array.reshape(shape[:start] + (-1,))
    indices = func(tmp, axis=-1)
    return np.unravel_index(indices, shape[start:])

关于python - NumPy:将索引选择功能应用于多个轴,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/7315747/

相关文章:

java - 这个 Java 代码片段 "work"是怎么回事?

python - 如何返回numpy结构化数组中多列的 View

python - 什么是 np.uint8?

python - 调用 plt.show() 后如何保持之前的图而不清除图形而不简单地重复调用?

python - 使用属性来回避两个模型

javascript - 将 php 关联数组转换为字符串时遇到问题

python - 树分类器的极高准确度指标

Java HashSet 到数组

python - 将其他 html 文档集成到 sphinx 文档中

python - 如何使用 DEAP 定义遵循特定顺序模式的自定义遗传算法个体