python - 在 NumPy 中使用多级 bool 索引掩码

标签 python numpy

我有以下代码,它首先选择具有逻辑索引掩码的 NumPy 数组的元素:

import numpy as np

grid = np.random.rand(4,4) 
mask = grid > 0.5

我希望针对这个使用第二个 bool 掩码来挑选对象:

masklength = len(grid[mask])
prob = 0.5
# generates an random array of bools
second_mask = np.random.rand(masklength) < prob 

# this fails to act on original object
grid[mask][second_mask] = 100

这与 SO 问题中列出的问题不完全相同: Numpy array, how to select indices satisfying multiple conditions? - 因为我正在使用随机数生成,所以我不想生成一个完整的掩码,只为第一个掩码选择的元素生成。

最佳答案

使用平面索引可以避免很多麻烦:

grid.flat[np.flatnonzero(mask)[second_mask]] = 100

分解:

ind = np.flatnonzero(mask)

生成一个平面索引数组,其中 mask 为真,然后通过应用 second_mask 进一步减少:

ind = ind[second_mask] 

我们可以继续:

ind = ind[third_mask]

最后

grid.flat[ind] = 100

使用ind 索引grid 的平面版本并分配100grid.ravel()[ind] = 100 也可以,因为 ravel() 返回原始数组的平面 View 。

关于python - 在 NumPy 中使用多级 bool 索引掩码,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/7179532/

相关文章:

使用 anaconda 安装 Python 3.7 失败

python - 更新字典列表中的列表值

Python:递归删除超过x天的文件夹

python - PyQt 用字典发出信号

python - 调整列表的 numpy 数组的大小,以便所有列表都具有相同的长度,并且可以正确推断 numpy 数组的 dtype

python - 按升序对 pandas DataMatrix 进行排序

python - python 打印输出不一致

python - 根据本地时间计算 24 小时周期内每分钟的平均销售额 (HH :MM)

python - 计算落在一组 x、y、z 坐标之间的值的数量

Python/Numpy - 屏蔽数组非常慢