python - 我怎样才能提高这个numpy循环的效率

标签 python optimization numpy

我有一个包含标签的 numpy 数组。我想根据标签的大小和边界框为每个标签计算一个数字。我怎样才能更有效地编写它,以便在大型阵列(~15000 个标签)上使用它是现实的?

A = array([[ 1, 1, 0, 3, 3],
           [ 1, 1, 0, 0, 0],
           [ 1, 0, 0, 2, 2],
           [ 1, 0, 2, 2, 2]] )

B = zeros( 4 )

for label in range(1, 4):
    # get the bounding box of the label
    label_points = argwhere( A == label )
    (y0, x0), (y1, x1) = label_points.min(0), label_points.max(0) + 1

    # assume I've computed the size of each label in a numpy array size_A
    B[ label ] = myfunc(y0, x0, y1, x1, size_A[label])

最佳答案

我并没有真正能够使用一些 NumPy 向量化函数有效地实现这一点,所以也许聪明的 Python 实现会更快。

def first_row(a, labels):
    d = {}
    d_setdefault = d.setdefault
    len_ = len
    num_labels = len_(labels)
    for i, row in enumerate(a):
        for label in row:
            d_setdefault(label, i)
        if len_(d) == num_labels:
            break
    return d

此函数返回一个字典,将每个标签映射到它出现的第一行的索引。将该函数应用于 AA.TA[: :-1]A.T[::-1] 还为您提供第一列以及最后一行和最后一列。

如果您更喜欢列表而不是字典,可以使用 map(d.get, labels) 将字典转换为列表。或者,您可以从一开始就使用 NumPy 数组而不是字典,但是一旦找到所有标签,您将失去提前离开循环的能力。

我很想知道这是否(以及多少)实际上加速了您的代码,但我相信它比您的原始解决方案更快。

关于python - 我怎样才能提高这个numpy循环的效率,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/8245817/

相关文章:

python - 当模块名称中包含 '-' 破折号或连字符时如何导入模块?

Python:寻找非线性方程的多重根

python - NumPy 数组中沿给定轴的一阶差分

java - 这个正则表达式可以进一步优化吗?

python - 无法通过 pandas 中的 iloc 编辑数据框数据

python - Pandas :根据一列中字符串的特定组合选择行对

javascript - 使用 Ajax 添加好友 - Django

python - 用python在网站文章中搜索关键词

python - 在日期时间的月、日、年...上查询 Mongodb

java - 这两段代码中哪一段更好/更快/使用更少的内存?