python - 比较 n 维 numpy 数组中的值

标签 python arrays numpy n-dimensional

需要比较 numpy 数组中的每个值,最大的值返回 1,其他值返回 0。我遇到了不同数量的 [] 的问题。

输入示例:

[[[0.6673975 0.33333233]]
.
.
.
[[0.33260247 0.6673975]]]

预期输出:

[[[1 0]]
.
.
.
[[0 1]]]

最佳答案

最大跨轴:

如果按照 Joe 在评论中的建议,您正在寻找沿轴的最大值,那么对于轴 axis

np.moveaxis((np.moveaxis(ar, axis, 0) == ar.max(axis)).astype(int), 0, axis)

或者,更快一点,

(ar == np.broadcast_to(np.expand_dims(ar.max(axis), axis), ar.shape)).astype(int)

应涵盖 n 维情况。

例如:

ar = np.random.randint(0, 100, (2, 3, 4))

ar
Out[157]: 
array([[[17, 28, 22, 31],
        [99, 51, 65, 65],
        [46, 24, 93,  4]],

       [[ 5, 84, 85, 79],
        [ 7, 80, 27, 25],
        [46, 80, 90,  3]]])

(ar == np.broadcast_to(np.expand_dims(ar.max(-1), -1), ar.shape)).astype(int)
Out[159]: 
array([[[0, 0, 0, 1],
        [1, 0, 0, 0],
        [0, 0, 1, 0]],

       [[0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0]]])
ar.max(-1)
Out[160]: 
array([[31, 99, 93],
       [85, 80, 90]])

整个数组的最大值:

您偶尔会尝试识别等于整个数组中最大值的元素,

(ar == ar.max()).astype(int)

应该给出您正在寻找的内容。

关于python - 比较 n 维 numpy 数组中的值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47561443/

相关文章:

python - 在Python中如何检查空字符串是否在单词中评估为true?

python 2.7 : Print a dictionary without brackets and quotation marks

java - JSONArray 在编码时丢失

python - 计算忽略 NaN 值的行的最小值

python - 尝试/除外未捕获错误

python - 在 Django 中使用一个模型过滤另一个模型

Python 检查数组中元素的值

Delphi:对类型 "array of TObject"的参数进行 SetLength()

c - 求数组中两个数的平方和

python - PyQ 中的 Kdb 数据库到 NumPy 数组