python - NumPy 沿多维数组轴的最近值

标签 python arrays numpy

我想创建一个函数,它将数组中沿指定轴返回给定值的最近值。

为了获取最接近值的索引,我使用以下代码,其中 arr 是多维数组,value 是要查找的值:

def nearest_index( arr, value, axis=None ):
    return ( np.abs( arr - value ) ).argmin( axis=axis )

但是我很难使用这个函数的结果来从数组中获取值。

一维数组很容易:

In [14]: arr_1 = np.random.randint( 10, 100, size=( 10, ) )

In [15]: arr_1
Out[15]: array([67, 49, 90, 29, 60, 80, 31, 55, 29, 10])

In [16]: nearest_index( arr_1, 50 )
Out[16]: 1

In [17]: arr_1[nearest_index( arr_1, 50 )]
Out[17]: 49

或者使用展平数组:

In [25]: arr_3 = np.random.randint( 10, 100, size=( 2, 3, 4, ) )

In [26]: arr_3
Out[26]:
array([[[85, 51, 74, 79],
        [63, 42, 27, 75],
        [89, 68, 80, 63]],

       [[85, 72, 74, 16],
        [85, 22, 47, 78],
        [44, 70, 98, 34]]])

In [27]: idx_flat = nearest_index( arr_3, 50, axis=None )

In [28]: idx_flat
Out[28]: 1

In [29]: idx = np.unravel_index( idx_flat, arr_3.shape )

In [30]: idx
Out[30]: (0, 0, 1)

In [31]: arr_3[idx]
Out[31]: 51

如何创建一个函数,返回沿定义轴的值?

我尝试了这个解决方案 Question ,但我只让它适用于 axis=-1

请注意,如果数组中的多个元素同样接近预期值,则只要找到结果的第一次出现,这对我来说不是问题。

最佳答案

对于多维数组,我们需要使用高级索引。因此,对于通用的 n-dim 数组和指定的轴,我们可以这样做 -

def argmin_values_along_axis(arr, value, axis):   
    argmin_idx = np.abs(arr - value).argmin(axis=axis)
    shp = arr.shape
    indx = list(np.ix_(*[np.arange(i) for i in shp]))
    indx[axis] = np.expand_dims(argmin_idx, axis=axis)
    return np.squeeze(arr[indx])

示例运行 -

In [203]: arr_3 = np.random.randint( 10, 100, size=( 2, 3, 4, ) )

In [204]: arr_3
Out[204]: 
array([[[94, 55, 26, 51],
        [82, 66, 80, 66],
        [96, 54, 93, 57]],

       [[59, 28, 95, 56],
        [47, 48, 17, 77],
        [15, 57, 57, 25]]])

In [205]: argmin_values_along_axis(arr_3, value=50, axis=0)
Out[205]: 
array([[59, 55, 26, 51],
       [47, 48, 80, 66],
       [15, 54, 57, 57]])

In [206]: argmin_values_along_axis(arr_3, value=50, axis=1)
Out[206]: 
array([[82, 54, 26, 51],
       [47, 48, 57, 56]])

In [207]: argmin_values_along_axis(arr_3, value=50, axis=2)
Out[207]: 
array([[51, 66, 54],
       [56, 48, 57]])

关于python - NumPy 沿多维数组轴的最近值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45857760/

相关文章:

python - 在Windows中连接MySQL到Python时出现缺少protobuf的错误

python - pip如何使包可执行?

java - 从 vector 或数组创建动态形式

python - 分配给二维 NumPy 数组的切片

python - 获取 pandas 中聚合的聚合

python - 您可以使用 Python 开放对 EC2 实例上的 UDP 端口的访问吗?

python - 为什么我们需要使用rabbitmq

java - 由于我的黑客等级解决方案超时而终止

c++ - C/C++ 中超大静态数组的算术运算

python - 二维 Numpy 数组的边值