Python Numpy 获取 2 个二维数组之间的差异

标签 python numpy multidimensional-array set-difference

好吧,我有一个让我头疼的简单问题,基本上我有两个二维数组,充满 [x,y] 坐标,我想比较第一个和第二个并生成包含所有没有出现在第二个数组中的第一个数组的元素。这很简单,但我根本无法让它工作。大小变化很大,第一个数组可以有1000到200万个坐标,而第一个数组有1到1000个坐标。
这个操作会发生很多次,第一个数组越大,它发生的次数就越多
样本:

arr1 = np.array([[0, 3], [0, 4], [1, 3], [1, 7], ])

arr2 = np.array([[0, 3], [1, 7]])

result = np.array([[0, 4], [1, 3]])
深入:基本上我有一个分辨率可变的二进制图像,它由 0 和 1 (255) 组成,我单独分析每个像素(使用已经优化的算法),但是(故意)每次执行此函数时它只分析一小部分像素,当它完成时,它会返回这些像素的所有坐标。问题在于,当它执行时,它会运行以下代码:
ones = np.argwhere(img == 255) # ones = pixels array
它大约需要 0.02 秒,是迄今为止代码中最慢的部分。我的想法是创建此变量一次,每次函数结束时,它都会删除解析的像素并将新数组作为参数传递以继续,直到数组为空

最佳答案

不确定您打算对额外维度做什么,因为设置差异与任何过滤一样,本质上会丢失形状信息。
无论如何,NumPy 确实提供了 np.setdiff1d() 优雅地解决这个问题。

编辑 通过提供的说明,您似乎正在寻找一种计算给定轴上的集合差的方法,即集合的元素实际上是数组。
no built-in专门为此在 NumPy 中,但制作一个并不太难。
为简单起见,我们假设操作轴是第一个(因此集合的元素是 arr[i] ),只有唯一的元素出现在第一个数组中,并且数组是二维的。
它们都基于这样一种想法,即渐近最佳方法是构建 set()的第二个数组,然后使用它来过滤掉第一个数组中的条目。
在 Python/NumPy 中构建此类集合的惯用方法是使用:

set(map(tuple, arr))
其中映射到 tuple卡住 arr[i] ,允许它们是可散列的,从而使它们可用于 set() .
不幸的是,由于过滤会产生不可预测大小的结果,NumPy 数组不是结果的理想容器。
要解决这个问题,可以使用:
  • 中级 list
  • import numpy as np
    
    
    def setdiff2d_list(arr1, arr2):
        delta = set(map(tuple, arr2))
        return np.array([x for x in arr1 if tuple(x) not in delta])
    
  • np.fromiter() 其次是 np.reshape()
  • import numpy as np
    
    
    def setdiff2d_iter(arr1, arr2):
        delta = set(map(tuple, arr2))
        return np.fromiter((x for xs in arr1 if tuple(xs) not in delta for x in xs), dtype=arr1.dtype).reshape(-1, arr1.shape[-1])
    
  • NumPy's advanced indexing
  • def setdiff2d_idx(arr1, arr2):
        delta = set(map(tuple, arr2))
        idx = [tuple(x) not in delta for x in arr1]
        return arr1[idx]
    
  • 将两个输入都转换为 set() (将强制输出元素的唯一性并失去排序):
  • import numpy as np
    
    
    def setdiff2d_set(arr1, arr2):
        set1 = set(map(tuple, arr1))
        set2 = set(map(tuple, arr2))
        return np.array(list(set1 - set2))
    
    或者,可以使用 broadcasting 构建高级索引。 , np.any() np.all() :
    def setdiff2d_bc(arr1, arr2):
        idx = (arr1[:, None] != arr2).any(-1).all(1)
        return arr1[idx]
    
    上述方法的某种形式最初是在 @QuangHoang's answer 中提出的。 .
    也可以在 Numba 中实现类似的方法,遵循与上述相同的想法,但使用哈希而不是实际的数组 View arr[i] (由于 Numba set() 中支持的内容的限制)和预先计算输出大小(为了速度):
    import numpy as np
    import numba as nb
    
    
    @nb.njit
    def mul_xor_hash(arr, init=65537, k=37):
        result = init
        for x in arr.view(np.uint64):
            result = (result * k) ^ x
        return result
    
    
    @nb.njit
    def setdiff2d_nb(arr1, arr2):
        # : build `delta` set using hashes
        delta = {mul_xor_hash(arr2[0])}
        for i in range(1, arr2.shape[0]):
            delta.add(mul_xor_hash(arr2[i]))
        # : compute the size of the result
        n = 0
        for i in range(arr1.shape[0]):
            if mul_xor_hash(arr1[i]) not in delta:
                n += 1
        # : build the result
        result = np.empty((n, arr1.shape[-1]), dtype=arr1.dtype)
        j = 0
        for i in range(arr1.shape[0]):
            if mul_xor_hash(arr1[i]) not in delta:
                result[j] = arr1[i]
                j += 1
        return result
    
    虽然它们都给出了相同的结果:
    funcs = setdiff2d_iter, setdiff2d_list, setdiff2d_idx, setdiff2d_set, setdiff2d_bc, setdiff2d_nb
    
    arr1 = np.array([[0, 3], [0, 4], [1, 3], [1, 7]])
    print(arr1)
    # [[0 3]
    #  [0 4]
    #  [1 3]
    #  [1 7]]
    
    arr2 = np.array([[0, 3], [1, 7], [4, 0]])
    print(arr2)
    # [[0 3]
    #  [1 7]
    #  [4 0]]
    
    result = funcs[0](arr1, arr2)
    print(result)
    # [[0 4]
    #  [1 3]]
    
    for func in funcs:
        print(f'{func.__name__:>24s}', np.all(result == func(arr1, arr2)))
    #           setdiff2d_iter True
    #           setdiff2d_list True
    #            setdiff2d_idx True
    #            setdiff2d_set False  # because of ordering
    #             setdiff2d_bc True
    #             setdiff2d_nb True
    
    他们的表现似乎各不相同:
    for func in funcs:
        print(f'{func.__name__:>24s}', end='  ')
        %timeit func(arr1, arr2)
    #           setdiff2d_iter  16.3 µs ± 719 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    #           setdiff2d_list  14.9 µs ± 528 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    #            setdiff2d_idx  17.8 µs ± 1.75 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    #            setdiff2d_set  17.5 µs ± 1.31 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    #             setdiff2d_bc  9.45 µs ± 405 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    #             setdiff2d_nb  1.58 µs ± 51.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
    
    提议的基于 Numba 的方法似乎以相当大的幅度优于其他方法(使用给定输入大约 10 倍)。
    使用更大的输入观察到类似的时间:
    np.random.seed(42)
    
    arr1 = np.random.randint(0, 100, (1000, 2))
    arr2 = np.random.randint(0, 100, (1000, 2))
    print(setdiff2d_nb(arr1, arr2).shape)
    # (736, 2)
    
    
    for func in funcs:
        print(f'{func.__name__:>24s}', end='  ')
        %timeit func(arr1, arr2)
    #           setdiff2d_iter  3.51 ms ± 75.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    #           setdiff2d_list  2.92 ms ± 32.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    #            setdiff2d_idx  2.61 ms ± 38.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    #            setdiff2d_set  3.52 ms ± 67.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    #             setdiff2d_bc  25.6 ms ± 198 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    #             setdiff2d_nb  192 µs ± 1.66 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    
    (作为旁注,setdiff2d_bc() 受第二个输入大小的负面影响最大)。

    关于Python Numpy 获取 2 个二维数组之间的差异,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66674537/

    相关文章:

    python - 根据标题中是否包含特定字母来删除 Dataframe 中的列

    python - 如何为 openCV SVM 格式化数据

    python - 类型错误 : only length-1 arrays can be converted to Python scalars with NUMPY

    c++ - 将多维数组传递给参数类型为 double * 的函数

    python 在 mac os 10.10.1 上安装 lxml

    python - 我如何知道登录网络 session 的要求?

    python - Tornado 异步 http 客户端 block

    python - 在 numpy 中从一个数组到另一个数组获取距下一个日期的天数

    javascript - 多维数组javascript的最大值

    python - 根据两列的值选择 numpy ndarray 中的行