在下面的代码中,我尝试将该函数应用于 DataFrame 的每个单元格。运行时测量表明,当矩阵大小为 1000x1000 时,Numba 代码比纯 Python 慢 6-7 倍,当矩阵大小为 10 000x10 000 时,速度慢 2-3 倍。我还多次运行代码以确保编译时间不影响整体运行时间。我错过了什么?
import time
from numba import jit
import pandas as pd
vcf = pd.DataFrame(np.full(shape=(10_000,10_000), fill_value='./.'))
time1 = time.perf_counter()
@jit(cache=True)
def jit_func(x):
if x == './.':
return 1
else:
return 0
vcf.applymap(jit_func)
print('JIT', time.perf_counter() - time1)
time1 = time.perf_counter()
vcf.applymap(lambda x: 1 if x=='./.' else 0)
print('LAMBDA', time.perf_counter() - time1)
time1 = time.perf_counter()
def python_func(x):
if x == './.':
return 1
else:
return 0
vcf.applymap(python_func)
print('PYTHON', time.perf_counter() - time1)
输出:
JIT 464.7864613599959
LAMBDA 158.36754451994784
PYTHON 122.22150028299075
最佳答案
从 Python 调用 Numba 函数比纯 Python 函数具有更高的开销。这是因为 Numba 需要检查函数是否已编译并调用包装转换函数。转换函数旨在将纯 Python 类型转换为 native 类型,以便 Numba 不会对纯 Python 类型(绑定(bind)到 GIL)进行操作,因此它可以生成更高效的代码。对于像整数这样的简单类型,这个包装函数非常快。然而,目前对于字符串来说,包装函数非常昂贵。事实上,Numba 中的大多数字符串操作目前都非常慢。 AFAIK,目前还没有计划让它们很快变得更快(Numba 旨在用于数值计算,而不是字符串计算)。此外,基于字符串的函数编译速度也明显慢。
关键点是使用字节
来代替。这需要转换(例如使用 encode
),并且 bytes
只能(安全地)用于 ASCII 字符(即没有 unicode)。话虽如此,并不是说 unicode 操作通常很慢,CPython 对此也没有那么糟糕,因为它已经过很好的优化,可以有效地计算 unicode 字符串。
以下是支持上述解释的一些基准:
import numba as nb
# Eagerly compile the function so it is compiled before being called
@njit('void()')
def fn_nb():
pass
def fn_python():
pass
%timeit fn_nb() # 56.5 ns ± 0.573 ns/loop
%timeit fn_python() # 50.7 ns ± 0.386 ns/loop => faster than fn_python
# ----------------------------------------------------------------------------
@njit('void(int64)')
def fn_nb_with_int_param(useless_param):
pass
%timeit fn_nb_with_int_param(123) # 118 ns ± 0.525 ns/loop => parameters add overhead
# ----------------------------------------------------------------------------
@njit('void(int64[:])')
def fn_nb_with_arr_param(param):
pass
arr = np.array([], dtype=np.int64)
%timeit fn_nb_with_arr_param(arr) # 242 ns ± 2.01 ns/loop => arrays are more expensive
# ----------------------------------------------------------------------------
@njit('void(unicode_type)')
def fn_nb_with_str_param(param):
pass
s = ''
%timeit fn_nb_with_str_param(s) # 1.79 µs ± 7.51 ns/loop => MUCH slower with strings
# ----------------------------------------------------------------------------
@njit('int64(unicode_type)')
def fn_nb_with_str_and_body(param):
if param == './.':
return 1
else:
return 0
s = './123'
%timeit fn_nb_with_str_and_body(s) # 1.84 µs ± 11.2 ns/loop => just a bit slower
# ----------------------------------------------------------------------------
@njit # I do not know the string signature for this one
def fn_nb_with_bytes_params(param):
pass
s = b''
fn_nb_with_bytes_params(s) # Force Numba to compile the function
%timeit fn_nb_with_bytes_params(s) # 255 ns ± 2.23 ns/loop => much faster than strings
# ----------------------------------------------------------------------------
@njit
def fn_nb_with_bytes_and_body(param):
if param == './.':
return 1
else:
return 0
s = b'./123'
fn_nb_with_bytes_and_body(s)
%timeit fn_nb_with_bytes_and_body(s) # 259 ns ± 3.42 ns/loop => still fast!
就性能而言,如果可能的话,通常最好避免像瘟疫一样的字符串(尤其是 unicode)。有一些技巧可以做到这一点。一种方法是将具有许多相等字符串的字符串列转换为分类列(它们在内部存储为整数+用于标签的 int/string 表)。
如果您确实需要计算 unicode 字符串,那么 Cython 肯定更适合。
最后,调用 Numba 函数 100_000_000 次效率不高。事实上,即使在 C/C++ 这样的本地语言中,它的效率也很低(除非该函数是内联的)。最好获取给定列的数据并调用一次 native /编译/Python 函数。目标函数可以迭代所提供列的项目。 Pandas 目前存储字符串列的效率很低。 Pandas 开发人员计划在未来改进这一点,但目前,我们必须付出转换的成本(字符串列表到 native 类型),或者直接操作 CPython 对象的成本(并不比使用纯 - Python 函数)。
关于python - 为什么我的 pandas + numba 代码比 pandas + 纯 python 代码工作得更差?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/76973012/