python - NumPy 数组的阈值像素索引

标签 python arrays numpy vectorization

我确定这个问题可以用 Google 搜索,但我不知道要使用什么关键字。我对一个具体案例很好奇,但也对一般情况下如何做感到好奇。假设我有一个 RGB 图像作为形状数组 (width, height, 3) 我想找到红色 channel 大于 100 的所有像素。我觉得 image > [100, 0, 0] 应该给我一个索引数组(如果我正在比较标量并使用灰度图像,则会给我),但这会将每个元素与列表进行比较。如何比较前两个维度,其中每个“元素”是最后一个维度?

最佳答案

要仅检测红色 channel ,您可以这样做 -

np.argwhere(image[:,:,0] > threshold)

解释:

  1. red-channelthreshold 进行比较,得到一个与没有第三轴(颜色 channel )的输入图像形状相同的 bool 数组。
  2. 使用np.argwhere获取成功匹配的索引。

如果您想查看任何 channel 是否高于某个阈值,请使用 .any(-1) (沿最后一个轴/颜色 channel 满足条件的任何元素)。

np.argwhere((image > threshold).any(-1))

样本运行

输入图片:

In [76]: image
Out[76]: 
array([[[118,  94, 109],
        [ 36, 122,   6],
        [ 85,  91,  58],
        [ 30,   2,  23]],

       [[ 32,  47,  50],
        [  1, 105, 141],
        [ 91, 120,  58],
        [129, 127, 111]]], dtype=uint8)

In [77]: threshold
Out[77]: 100

案例 #1:仅红色 channel

In [69]: np.argwhere(image[:,:,0] > threshold)
Out[69]: 
array([[0, 0],
       [1, 3]])

In [70]: image[0,0]
Out[70]: array([118,  94, 109], dtype=uint8)

In [71]: image[1,3]
Out[71]: array([129, 127, 111], dtype=uint8)

案例 #2:任意 channel

In [72]: np.argwhere((image > threshold).any(-1))
Out[72]: 
array([[0, 0],
       [0, 1],
       [1, 1],
       [1, 2],
       [1, 3]])

In [73]: image[0,1]
Out[73]: array([ 36, 122,   6], dtype=uint8)

In [74]: image[1,1]
Out[74]: array([  1, 105, 141], dtype=uint8)

In [75]: image[1,2]
Out[75]: array([ 91, 120,  58], dtype=uint8)

np.any 更快的替代方案在 np.einsum

np.einsum 可能被欺骗 来执行np.any 的工作,而且事实证明速度稍快。

因此,boolean_arr.any(-1) 将等同于 np.einsum('ijk->ij',boolean_arr)

以下是各种数据大小的相关运行时 -

In [105]: image = np.random.randint(0,255,(30,30,3)).astype('uint8')
     ...: %timeit np.argwhere((image > threshold).any(-1))
     ...: %timeit np.argwhere(np.einsum('ijk->ij',image>threshold))
     ...: out1 = np.argwhere((image > threshold).any(-1))
     ...: out2 = np.argwhere(np.einsum('ijk->ij',image>threshold))
     ...: print np.allclose(out1,out2)
     ...: 
10000 loops, best of 3: 79.2 µs per loop
10000 loops, best of 3: 56.5 µs per loop
True

In [106]: image = np.random.randint(0,255,(300,300,3)).astype('uint8')
     ...: %timeit np.argwhere((image > threshold).any(-1))
     ...: %timeit np.argwhere(np.einsum('ijk->ij',image>threshold))
     ...: out1 = np.argwhere((image > threshold).any(-1))
     ...: out2 = np.argwhere(np.einsum('ijk->ij',image>threshold))
     ...: print np.allclose(out1,out2)
     ...: 
100 loops, best of 3: 5.47 ms per loop
100 loops, best of 3: 3.69 ms per loop
True

In [107]: image = np.random.randint(0,255,(3000,3000,3)).astype('uint8')
     ...: %timeit np.argwhere((image > threshold).any(-1))
     ...: %timeit np.argwhere(np.einsum('ijk->ij',image>threshold))
     ...: out1 = np.argwhere((image > threshold).any(-1))
     ...: out2 = np.argwhere(np.einsum('ijk->ij',image>threshold))
     ...: print np.allclose(out1,out2)
     ...: 
1 loops, best of 3: 833 ms per loop
1 loops, best of 3: 640 ms per loop
True

关于python - NumPy 数组的阈值像素索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35301262/

相关文章:

c# - ArrayList 与对象数组与 T 的集合

python - 将 CSV 文件转换为 pandas 的 'flat file"

python - 当使用 eventlet 运行 celery 时,Fabric 失败并显示 Name lookup failed

python - "Browser Not Supported"使用BeautifulSoup进行网页抓取时出错

arrays - Powershell搜索文本文件以进行匹配,并在行尾添加回车符

python - ImportError : numpy. core.multiarray 在使用 mod_wsgi 时导入失败

python - 绘制两个 datetime64[ns] 之间的差异

python - pandas:组合行中的文本

python - matplotlibrc 对情节没有影响?

c - C中二维数组内的一维数组比较