python - NumPy: how to efficiently perform element-wise operations that involve other elements of the array (without loops)

Tags: python numpy multidimensional-array

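Hi! I am trying to refactor my code by replacing Python loops with NumPy operations to make it faster. Any clue how to do this? The code assigns a value to each element of a 2D ndarray based on the values of its neighbouring elements, and I could not find an answer for this specific case.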

This is an implementation of the 6-neighbour method for finding saddle points in a photo, described here: https://documentcloud.adobe.com/link/track?uri=urn:aaid:scds:US:978c30d2-4888-491c-85c1-3949ea6166e9

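It takes the differences between the current element and each of its six neighbours. It then counts the sign changes between consecutive differences, going around the neighbours in order, and if there are at least 4 the element is a saddle point.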
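
As a minimal sketch of that rule for a single pixel (the helper name `count_sign_changes` is hypothetical and not part of my code below; it only illustrates the counting step):

```python
import numpy as np

def count_sign_changes(center, neighbours):
    """Count sign changes around the ring of neighbour differences.

    `neighbours` lists the six neighbour values in ring order
    (up, up-right, right, down, down-left, left).
    """
    diffs = np.asarray(neighbours, dtype=int) - int(center)
    # Pair each difference with the next one, wrapping last -> first.
    nxt = np.roll(diffs, -1)
    # A strictly negative product means the two differences have opposite signs.
    return int(np.count_nonzero(diffs * nxt < 0))

# Neighbours alternate above/below the centre value 5: six sign changes.
print(count_sign_changes(5, [7, 3, 8, 2, 9, 1]))  # 6
# All neighbours are above the centre: no sign changes.
print(count_sign_changes(5, [6, 7, 8, 9, 7, 6]))  # 0
```

With this rule the first pixel would be marked as a saddle point (6 >= 4) and the second would not.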

Is it possible to do this with no loops at all?

Sorry if the question is unclear or badly formatted; this is my first question on StackOverflow.

import numpy as np

def findSaddlePoints6neibours(gray):
    gray = gray.astype(int)
    h = gray.shape[0]
    w = gray.shape[1]

    number = 0
    result = np.zeros((h, w))

    for y in range(1, h - 1):
        for x in range(1, w - 1):
            center = gray[y][x]
            neiboursDiff = []
            neiboursDiff.append(gray[y-1][x] - center)
            neiboursDiff.append(gray[y-1][x+1] - center)
            neiboursDiff.append(gray[y][x+1] - center)
            neiboursDiff.append(gray[y+1][x] - center)
            neiboursDiff.append(gray[y+1][x-1] - center)
            neiboursDiff.append(gray[y][x-1] - center)

            # Count sign changes between consecutive differences around the ring
            changes = 0
            for i in range(5):
                if (neiboursDiff[i] < 0 and neiboursDiff[i+1] > 0) or (neiboursDiff[i] > 0 and neiboursDiff[i+1] < 0):
                    changes += 1
            if (neiboursDiff[0] < 0 and neiboursDiff[5] > 0) or (neiboursDiff[0] > 0 and neiboursDiff[5] < 0):
                changes += 1
            if changes > 3:
                number += 1
                result[y][x] = 1

    return [result, number]

Best answer

Here is a vectorized solution:

import numpy as np

def findSaddlePoints6neibours_vec(gray):
    gray = np.asarray(gray, dtype=int)
    center = gray[1:-1, 1:-1]
    diffs = [
        gray[:-2, 1:-1],
        gray[:-2, 2:],
        gray[1:-1, 2:],
        gray[2:, 1:-1],
        gray[2:, :-2],
        gray[1:-1, :-2],
    ]
    diffs.append(diffs[0])
    diffs = np.stack(diffs)
    diffs -= center
    sign_changes = np.count_nonzero(diffs[:-1] * diffs[1:] < 0, axis=0)
    is_saddle = sign_changes > 3
    number = np.count_nonzero(is_saddle)
    result = np.pad(is_saddle, ((1, 1), (1, 1)), mode='constant').astype(int)
    return result, number
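
The key idea in the function above is that shifted slices of the same array line up each interior pixel with one of its neighbours. A small sketch on a toy array (the array `a` and the variable names are only for illustration):

```python
import numpy as np

a = np.arange(20).reshape(4, 5)  # a[r, c] == 5 * r + c

center = a[1:-1, 1:-1]   # interior pixels, shape (2, 3)
up = a[:-2, 1:-1]        # value one row above each interior pixel
up_right = a[:-2, 2:]    # one row above, one column to the right
left = a[1:-1, :-2]      # one column to the left

# Each shifted view has the same shape as `center`, so arithmetic lines up
# element by element without any Python loop.
print((up - center).tolist())        # [[-5, -5, -5], [-5, -5, -5]]
print((up_right - center).tolist())  # [[-4, -4, -4], [-4, -4, -4]]
print((left - center).tolist())      # [[-1, -1, -1], [-1, -1, -1]]
```

The slices themselves are views and copy no data; the copy in the function above happens in `np.stack`.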

A quick test:

import numpy as np

# Make example input
np.random.seed(100)
gray = np.random.randint(-10, 10, size=(80, 100))

# The original function
result1, number1 = findSaddlePoints6neibours(gray)
# The vectorized function
result2, number2 = findSaddlePoints6neibours_vec(gray)
# Check results match
print(number1 == number2)
# True
print(np.all(result1 == result2))
# True

# Compare run times
%timeit findSaddlePoints6neibours(gray)
# 31.1 ms ± 682 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit findSaddlePoints6neibours_vec(gray)
# 247 µs ± 1.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

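Edit:

A drawback of the function above is that it uses more memory, since it stacks seven shifted copies of the image interior into one array. If you can use Numba, you can compile the function and parallelize it to make it faster: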

import numba as nb

@nb.njit(parallel=True)
def findSaddlePoints6neibours_nb(gray):
    gray = gray.astype(np.int32)
    h = gray.shape[0]
    w = gray.shape[1]
    number = 0
    result = np.zeros((h, w), dtype=np.int32)
    for y in nb.prange(1, h - 1):
        # Allocate the scratch buffer per row so parallel iterations do not share it
        neiboursDiff = np.empty(7, dtype=np.int32)
        for x in range(1, w - 1):
            neiboursDiff[0] = gray[y-1][x]
            neiboursDiff[1] = gray[y-1][x+1]
            neiboursDiff[2] = gray[y][x+1]
            neiboursDiff[3] = gray[y+1][x]
            neiboursDiff[4] = gray[y+1][x-1]
            neiboursDiff[5] = gray[y][x-1]
            neiboursDiff[6] = neiboursDiff[0]
            neiboursDiff -= gray[y, x]
            changes = np.sum(neiboursDiff[:-1] * neiboursDiff[1:] < 0)
            is_saddle = int(changes > 3)
            number += is_saddle
            result[y, x] = is_saddle
    return result, number

Continuing the small benchmark from above:

%timeit findSaddlePoints6neibours_nb(gray)
# 114 µs ± 496 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
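
If Numba is not an option, the memory cost can also be reduced in plain NumPy by accumulating the sign-change count one neighbour pair at a time instead of stacking all differences. This is only a sketch under that assumption: the name `find_saddle_points_lowmem` is hypothetical, the function was not part of the benchmark above, and it still loops in Python over the six neighbours (not over pixels).

```python
import numpy as np

def find_saddle_points_lowmem(gray):
    gray = np.asarray(gray, dtype=int)
    center = gray[1:-1, 1:-1]
    # Shifted views of the six neighbours, in the same ring order as above.
    neighbours = [
        gray[:-2, 1:-1], gray[:-2, 2:], gray[1:-1, 2:],
        gray[2:, 1:-1], gray[2:, :-2], gray[1:-1, :-2],
    ]
    # At most 6 sign changes per pixel, so a small integer type is enough.
    changes = np.zeros(center.shape, dtype=np.int8)
    first = neighbours[0] - center
    prev = first
    for view in neighbours[1:]:
        cur = view - center
        changes += (prev * cur < 0)   # opposite signs -> one sign change
        prev = cur
    changes += (prev * first < 0)     # close the ring: last vs. first
    is_saddle = changes > 3
    # Border pixels are never saddle points, matching the original function.
    result = np.zeros(gray.shape, dtype=int)
    result[1:-1, 1:-1] = is_saddle
    return result, int(np.count_nonzero(is_saddle))

# 3x3 example whose centre has neighbours alternating above and below it.
res, n = find_saddle_points_lowmem(np.array([[0, 7, 3], [1, 5, 8], [9, 2, 0]]))
print(n)  # 1
```

Only a few interior-sized temporaries are alive at any time here, rather than a stacked array of seven.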

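Regarding python - NumPy: how to efficiently perform element-wise operations that involve other elements of the array (without loops), we found a similar question on Stack Overflow: https://stackoverflow.com/questions/53721170/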
