我是answering a question关于pandas
interpolation
方法。 OP 希望仅在连续 np.nan 的数量为 1 的情况下使用插值。 interpolate
的 limit=1
选项将插入第一个 np.nan
并在那里停止。 OP 希望能够知道实际上有多个 np.nan
,甚至不关心第一个。
我将其归结为仅按原样执行插值
,并在事后屏蔽连续的np.nan
。
问题是:什么是通用解决方案,它采用一维数组 a
和整数 x
并生成一个 bool 掩码,其中 x 的位置为 False或多个连续的np.nan
考虑一维数组a
a = np.array([1, np.nan, np.nan, np.nan, 1, np.nan, 1, 1, np.nan, np.nan, 1, 1])
我希望对于x = 2
,掩码看起来像这样
# assume 1 for True and 0 for False
# a is [ 1. nan nan nan 1. nan 1. 1. nan nan 1. 1.]
# mask [ 1. 0. 0. 0. 1. 1. 1. 1. 0. 0. 1. 1.]
# ^
# |
# Notice that this is not masked because there is only one np.nan
我希望对于x = 3
,掩码看起来像这样
# assume 1 for True and 0 for False
# a is [ 1. nan nan nan 1. nan 1. 1. nan nan 1. 1.]
# mask [ 1. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1.]
# ^ ^ ^
# | | |
# Notice that this is not masked because there is less than 3 np.nan's
我期待学习其他人的想法;-)
最佳答案
我真的很喜欢numba对于这样容易掌握但很难“numpyfy”的问题!尽管该包对于大多数库来说可能有点太重了,但它允许编写类似“python”的函数,而不会损失太多速度:
import numpy as np
import numba as nb
import math
@nb.njit
def mask_nan_if_consecutive(arr, limit): # I'm not good at function names :(
result = np.ones_like(arr)
cnt = 0
for idx in range(len(arr)):
if math.isnan(arr[idx]):
cnt += 1
# If we just reached the limit we need to backtrack,
# otherwise just mask current.
if cnt == limit:
for subidx in range(idx-limit+1, idx+1):
result[subidx] = 0
elif cnt > limit:
result[idx] = 0
else:
cnt = 0
return result
至少如果你使用纯Python,这应该很容易理解并且应该可以工作:
>>> a = np.array([1, np.nan, np.nan, np.nan, 1, np.nan, 1, 1, np.nan, np.nan, 1, 1])
>>> mask_nan_if_consecutive(a, 1)
array([ 1., 0., 0., 0., 1., 0., 1., 1., 0., 0., 1., 1.])
>>> mask_nan_if_consecutive(a, 2)
array([ 1., 0., 0., 0., 1., 1., 1., 1., 0., 0., 1., 1.])
>>> mask_nan_if_consecutive(a, 3)
array([ 1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.])
>>> mask_nan_if_consecutive(a, 4)
array([ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
但是 @nb.njit
-decorator 真正好的一点是,这个函数会很快:
a = np.array([1, np.nan, np.nan, np.nan, 1, np.nan, 1, 1, np.nan, np.nan, 1, 1])
i = 2
res1 = mask_nan_if_consecutive(a, i)
res2 = mask_knans(a, i)
np.testing.assert_array_equal(res1, res2)
%timeit mask_nan_if_consecutive(a, i) # 100000 loops, best of 3: 6.03 µs per loop
%timeit mask_knans(a, i) # 1000 loops, best of 3: 302 µs per loop
因此,对于短数组来说,速度大约快了 50 倍,尽管差异变小,但对于较长数组来说,速度仍然更快:
a = np.array([1, np.nan, np.nan, np.nan, 1, np.nan, 1, 1, np.nan, np.nan, 1, 1]*100000)
i = 2
%timeit mask_nan_if_consecutive(a, i) # 10 loops, best of 3: 20.9 ms per loop
%timeit mask_knans(a, i) # 10 loops, best of 3: 154 ms per loop
关于python - 仅当连续 nan 超过 x 时才进行掩码,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43082316/