python - 更快地实现像素映射

标签 python python-3.x numpy

这个方法非常慢。简短而有趣的是,它接受一个字典phase_color_labels,该字典将任意名称映射到与 RGB 值相对应的 3 元素列表,并将输入图像的每个像素映射到它所指定的任何像素值。最接近于 phase_color_labels 字典。我还没有弄清楚是否有一个运行速度更快的矢量化版本。

image 变量只是一个 numpy 数组 [H、W、Channels]。

def map_pixels_to_discrete_values(image, phase_color_labels):
        """
        Takes an image with floating point pixel values and maps each pixel RGB value
        to a new value based on the closest Euclidean distance to one of the RGB sets
        in the phase_label input dictionary.
        """
        mapped_image = np.copy(image)
        for i in range(mapped_image.shape[0]):
            for j in range(mapped_image.shape[1]):
                min_distance = np.inf
                min_distance_label = None
                for phase_name, phase_color in phase_color_labels.items():
                    r = phase_color[0]
                    g = phase_color[1]
                    b = phase_color[2]
                    rgb_distance = (mapped_image[i, j, 0] - r)**2 + (mapped_image[i, j, 1] - g)**2 + (mapped_image[i, j, 2] - b)**2
                    if rgb_distance < min_distance:
                        min_distance = rgb_distance
                        min_distance_label = phase_name

                mapped_image[i, j, :] = phase_color_labels[min_distance_label]
        return mapped_image

最佳答案

为了使用 Numpy 快速完成任务,您通常希望避免循环并将尽可能多的工作插入 Numpy 的矩阵运算中。

我的回答的基本想法:

  1. phase_color_labels 获取颜色作为 ndarrayphase_colors
  2. 使用 Numpy 的 broadcasting计算“外部距离”——数组中每个图像与phase_colors中每种颜色之间的欧氏距离。
  3. 找到每个像素距离最近的颜色索引,并将其用作 phase_colors 中的索引。
phase_colors = np.array([color for color in phase_color_labels.values()])

distances = np.sqrt(np.sum((image[:,:,np.newaxis,:] - phase_colors) ** 2, axis=3))
min_indices = distances.argmin(2)

mapped_image = phase_colors[min_indices]

第三行需要一些额外的解释。首先,请注意 phase_namesphase_colors 都具有形状 (L, C),其中 L 是标签,C 是 channel 数。

  • image[:,:,np.newaxis,:] 在第二个和第三个轴之间插入一个新轴,因此生成的数组的形状为 (H, W, 1, C )
  • 当从形状为 (H, W, 1, C) 的数组中减去形状为 (L, C) 的数组时,Numpy 将数组广播为形状 (H,W,L,C)。您可以找到有关 Numpy 广播语义的更多详细信息 here .
  • 然后,沿轴 3 求和会生成形状为 (H, W, L) 的数组。
  • (平方或开平方均不会影响数组的形状。)

在第四行中,在轴 2 上使用 argmin,然后将数组缩减为 (H, W) 形状,每个值都是缩减轴的索引 L——换句话说,是 phase_colors 的索引。

<小时/>

作为一项额外的改进,由于平方根是单调递增函数,因此它不会改变最小距离,因此您可以将其完全删除。

请注意,对于大型图像phase_color_labels,广播的内存消耗可能会很明显,这也可能会导致性能问题。

关于python - 更快地实现像素映射,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55463913/

相关文章:

python - 将 Pandas df 乘以 2 个常量,创建 2 个 df

python - 什么更快 : Python 3's ' len' or numpys shape?

arrays - 在 Python 3 中将查找表应用于 NUMPY 数组的最有效方法

Python - 从 URL 获取标题信息

python - __init__() 恰好接受 Airflow dag 任务中给出的 2 个参数 1

python - 通过字符串查找解析全局命名空间中的 Python 对象

python - 在 Python3 中使用 for 循环为 vigenere 密码创建 2D 列表

python - 如何在 python 循环中将子标题添加到 pandas Dataframe?

python - 将其标准化以写入.wav文件后,无法获取原始的numpy数组

python - 将 wlst 命令重定向到 python 脚本中的文件