python - numba 的表现(我做错了什么吗?)

标签 python performance graph numba

我目前正在用Python编写一个遍历图的算法。该图连接到底层方程系统,在遍历过程中,我必须提取并存储一些索引。我一开始是用networkx实现的,但是由于方程组和连通图都很大,所以算法太慢了。

然后我切换到纯 numpy 实现。这更快了,但仍然不够快。我以为numba会更快,但似乎更慢。在测量计算时间后,我注意到主要问题出现在调用以下函数时:

@jit(nopython=True)
def add_col_to_mad_schedule_numba(rows, cols, data,  mad_schedule_col, edges, col, total_fillin_rows, total_fillin_cols,
                            total_fillin_data, nodes):

    colidx = nbfunc.where_single(cols, col)
    for k in range(len(rows[colidx])):
        if nbfunc.contained(nodes, rows[colidx][k]):
            edge_idx = nbfunc.where_single(edges[:, 0], rows[colidx][k])
            if edges[edge_idx].size != 0:
                kstart = edges[edge_idx, 0]
                ending = False
                while ending == False:
                    edges, k_filter = det_edges(kstart, edges)
                    k_filter = np.array(k_filter)
                    if k_filter.size == 0:
                        ending=True
                    else:
                        rows, cols, data, total_fillin_rows, total_fillin_cols, total_fillin_data = det_fillin(rows, cols, data, total_fillin_rows, total_fillin_cols, total_fillin_data, col, k_filter, edges)
                        first_idx, sec_idx, third_idx = mad_numba(rows, cols, edges, k_filter, col)

                        newl = np.zeros((len(edges[k_filter, 1]), 6), dtype=np.int64)
                        newl[:, 0] = edges[k_filter, 0]
                        newl[:, 1] = edges[k_filter, 1]
                        newl[:, 2] = first_idx
                        newl[:, 3] = sec_idx
                        newl[:, 4] = third_idx
                        newl[:, 5] = col


                        mad_schedule_col = np.append(mad_schedule_col, newl, axis=0)

                        kstart = edges[k_filter, 1]


    return rows, cols, data, total_fillin_rows, total_fillin_cols, total_fillin_data, mad_schedule_col[1:]

该函数被调用 n 次,其中 n 是方程系统中变量的数量。目前该函数的每次运行需要 61 毫秒,我想请问您是否可以看到由于 numba 的错误使用而出现的任何技术瓶颈。例如,我仍在函数体中创建 numpy 数组。这样的事情可能会导致性能不佳吗?

该算法确实相当耗时,因为对于系统每列(变量 k)中的每个非零条目,都会遍历有向图,直到没有后继者为止。 遍历次数并没有那么高。 while 循环中有约 3 次迭代。对于每一列,也只有 3-5 个非零条目。

我还可以提供 det_fillin() 和 mad_numba() 的内容,但我认为不会发生很多事情。我使用我自己的 numba 等价的 numpy where() 函数检索一些索引。

请注意,nbfunc 函数也表示与 numpy 函数等效的函数。 where_single() 对应于 np.where , contains() 只是检查 rows[colidx][k] 是否在节点中。所有函数均使用@jit(nopython=True)编译,并且没有错误消息。

最佳答案

问题主要出现在np.append调用中。它创建一个新的更大的数组并为每次调用复制以前的内容,这显着增加了算法的复杂性(具有线性复杂性的算法在最坏的情况下可能会变成二次)。同样的事情也适用于 Numpy。

解决此问题的一个解决方案是使用列表,以便在其中附加许多 Numpy 数组,然后将所有这些数组连接到一个新的更大的 Numpy 数组中。

另一种解决方案是直接创建一个具有合适大小的大数组,然后通过将子数组复制到大数组来填充它。这个解决方案要快得多,但它需要提前知道最终数组的大小。当代码可以并行执行时,速度特别快。

另一种解决方案是使用前一个解决方案,并进行两个小更改:使用最大可能大小创建大数组,并且仅使用其实际填充的子集( View ) 由计算函数返回。该解决方案要求大数组的大小受到限制,并且该限制不能比返回 View 的实际平均大小大很多。这通常比使用列表更快,但需要更多内存。


注释和备注

小心:请注意,mad_schedule_col = ... 不会写入/修改传入参数的数组,它会创建一个新数组和变量 mad_schedule_col 设置为引用新创建的数组。如果你想改变输入数组,那么你需要写入它。如果您事先不知道大小,那么最好的方法就是返回修改后的 mad_schedule_col

另请注意,如果没有提供签名Numba 函数会延迟编译(当函数第一次执行时),并且编译时间可能会相当慢。提供签名会导致 Numba 急切地编译函数(定义函数时)。

请注意,如果 newl.shape[0] 如果很大,那么您的算法可能会受到内存限制。如果是这样,那么使用 Numba 不会快很多,因为内存已经饱和了。

关于python - numba 的表现(我做错了什么吗?),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/70945414/

相关文章:

python - 更改二进制对角矩阵的 block 顺序

android - 加快文件数组的 MD5 检查

performance - jBoss 最大并发连接数

algorithm - 具有固定最大边长的平面图

python-2.7 - AttributeError: 'dict' 迭代集合字典时对象没有属性 (...)

python - 如何使用包含字段和值的列表过滤模型?

javascript - python 和谐吗?

python - 将 Numpy 数组保存为图像

javascript - 在 JavaScript 和 Node.js 中向 String 类添加函数会对性能产生什么影响?

algorithm - 如果形成循环的一条边已知,则列出图中形成循环的所有边的最快方法