python - numba @njit 更新一个大字典

标签 python jit numba

我尝试将 numba 用于需要以 (int, int) 元组作为键对非常大的 (10e6) 字典进行搜索的函数。

import numpy as np
from numba import njit

myarray = np.array([[0, 0],  # 0, 1
                    [0, 1],
                    [1, 1],  # 1, 2
                    [1, 2],  # 1, 3
                    [2, 2],
                    [1, 3]]
) # a lot of this with shape~(10e6, 2)

dict_with_tuples_key = {(0, 1): 1,
                        (3, 7): 1} # ~10e6 keys 

简化版是这样的

# @njit
def update_dict(dict_with_tuples_key, myarray):
    for line in myarray:
        i, j = line
        if (i, j) in dict_with_tuples_key:
            dict_with_tuples_key[(i, j)] += 1
        else:
            dict_with_tuples_key[(i, j)] = 1
    return dict_with_tuples_key

new_dict = update_dict(dict_with_tuples_key, myarray)
print new_dict

new_dict = update_dict2(dict_with_tuples_key, myarray)
# print new_dict
# {(0, 1): 2,   # +1 already in dict_with_tuples_key
#  (0, 0): 1,   # diag
#  (1, 1): 1,   # diag
#  (2, 2): 1,   # diag
#  (1, 2): 1,   # new from myarray
#  (1, 3): 1,   # new from myarray
#  (3, 7): 1 }

@njit 似乎不接受 dict 作为函数参数?

我想知道如何重写它,特别是 if (i, j) in dict_with_tuples_key 进行搜索的部分。

最佳答案

njit意味着该函数是在 nopython 模式下编译的。 dictlisttuple 是 python 对象,因此不受支持。不作为参数,也不在函数内部。

如果您的字典键完全不同,我会考虑使用二维 numpy 数组,其中第一个轴表示字典键元组的第一个索引,第二个轴表示第二个索引。然后你可以将其重写为:

from numba import njit
import numpy as np

@njit
def update_array(array, myarray):
    elements = myarray.shape[0]
    for i in range(elements):
        array[myarray[i][0]][myarray[i][1]] += 1 
    return array


myarray = np.array([[0, 0], [0, 1], [1, 1],
                    [1, 2], [2, 2], [1, 3]])

# Calculate the size of the numpy array that replaces the dict:
lens = np.max(myarray, axis=0) # Maximum values
array = np.zeros((lens[0]+1, lens[1]+1)) # Create an empty array to hold all indexes in myarray
update_array(array, myarray)

因为你已经用元组索引了你的字典,所以转换到索引数组的问题不会很大。

关于python - numba @njit 更新一个大字典,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35241097/

相关文章:

python - numba jitted 函数中集合的正确签名是什么?

python - Numpy 使用索引数组在另一个数组中累积一个数组

python - 解读 Django 源代码

python - Tkinter 文本突出显示标记如 : [B] and [\B]

python - 随机 int64 和 float64 数字

python - 在 numba 的 jitclass 中索引多维 numpy 数组

java - 对于 HotSpot JIT, "already compiled into a big method"是什么意思?

c++ - 是否有即时编译的正则表达式引擎?

python - 优化使用 numpy 创建 3d 矩阵

python - 当类的属性包含另一个类实例时如何指定 numba jitclass?