python - NumPy - 这可以向量化吗?

标签 python numpy

正在编写新脚本,这是我第一次深入研究 NumPy。内存优势本质上是显而易见的,但矢量化是一个麻烦的概念。

我有 2 个 NumPy 数组代表 XY 点和框,对于每个点,我需要确定该点与哪些框相交。

这两个数组的结构如下:

>>> points
array([[40.00183, 20.005],
       [39.9975, 20.0125],
       [57.01822, 16.997]], dtype=float32)

>>> boxes
array([[40.00183, 20.005, 39.9975, 20.0125],
       [39.9975, 20.0125, 57.01822, 16.997],
       [57.01822, 16.997, 40.00183, 20.005]], dtype=float32)

这里的实际值是虚构的,事实上这些盒子甚至不是盒子,但这就是结构。 points是一个 N 维数组,形状为 (N, 2)boxes形状为(M, 4)

交叉测试的算法是:

def intersect(p: np.ndarray, b: np.ndarray) -> bool:
    '''Intersection testing using DeMorgan's Law'''
    return ( p[0] < b[2] and
             p[0] > b[0] and
             p[1] < b[3] and
             p[1] > b[1] )

我见过的所有向量化都涉及标量,我还没有见过任何涉及使用 2 个数组的函数。

最佳答案

你确实可以!碰巧,我遇到了完全相同的问题,并将解决方案编码如下:

def point_is_inside_box(point, bb):
  '''
  point: (x,y) np array of shape Nx2
  bb: (xmin,ymin,xmax,ymax) np array of shape Mx4

  Return: boolean matrix MxN where each column stands for "point n is in box m"
  '''
  # Logic: xmin <= x < xmax and ymin <= y < ymax
  point = point[None,...]
  bb = bb[...,None,:]
  return (bb[...,0] < point[...,0]) & (point[...,0] < bb[...,2]) & (bb[...,1] < point[...,1]) & (point[...,1] < bb[...,3])

本质上,这个想法是利用 numpy 的广播规则。由于输入是两个向量,因此我添加了维度。 point有形状[1,N,2]bb有形状[M,1,4] 。这样,广播将应用 <每对 (pt, box) 的运算符在数组中,生成形状为 [M,N] 的矩阵形式的结果.

关于切片:

  • ...称为 ellipsis它相当于 :根据需要填充缺失的尺寸。您可以将其视为“从我在此未明确说明的所有其他维度获取所有内容”的捷径。例如,如果 point有形状[42,2] ,我可以选择全部x值来自point[:,0]point[...,0] 。然而,如果point有形状[42,1,2] ,第二个语句仍然会选择所有 x 值,而第一个语句不起作用(需要更改为 point[:,:,0] )

  • None相当于 np.newaxis 。我基本上是告诉 numpy 在该特定位置插入一个新维度。有人可能会争论使用 np.newaxis而不是None更具可读性。他们是对的。

关于内存消耗:

假设在向数组添加额外维度时没有发生任何副本(不确定是否是这种情况,但我猜它不会发生),您将需要额外的内存 N*M bool 值,如果你有很多点和盒子,它可能会变得很多。如果是这种情况,考虑到输出矩阵可能具有很强的稀疏性,尝试使用 scipy 's sparse matrices 可能会很有趣。保持代码结构相同。但不知道这是否有效,或者是否具有高性能。

关于python - NumPy - 这可以向量化吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59967870/

相关文章:

python - 如何在Python中打印字典中的两个字符串

矩阵 x[i,j] 和 x[i][j] 两种形式的python区别

python - NumPy complex128除法与float64除法不一致

python - Linux CI 服务器上的 GAE

Python Numpy - 二维数组中的 3 维索引,无循环

python - 正则表达式删除多个空格后的字符

Python 特征向量 : differences among numpy. linalg、scipy.linalg 和 scipy.sparse.linalg

python - ndb : query(AdModel. daily_used < AdModel.daily_budget) 是否可以执行以下操作

python - 如何根据 ODR 结果计算标准误差?

python - 使用 axis=1 调用 apply 并将不同长度的列表设置为单元格值时出现 Pandas ValueError