python - 跳过 numpy 数组的第 n 个索引

标签 python numpy slice

为了进行 K 折验证,我想使用 slice 一个 numpy 数组,以便制作原始数组的 View ,但删除每个第 n 个元素。

例如:

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

如果 n = 4 那么结果将是

[1, 2, 4, 5, 6, 8, 9]

注意:需要 numpy 是因为它被用于依赖关系固定的机器学习任务。

最佳答案

使用模数的方法 #1

a[np.mod(np.arange(a.size),4)!=0]

sample 运行-

In [255]: a
Out[255]: array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [256]: a[np.mod(np.arange(a.size),4)!=0]
Out[256]: array([1, 2, 3, 5, 6, 7, 9])

使用掩码的方法#2:需求作为 View

考虑到 View 需求,如果想法是节省内存,我们可以存储等效的 bool 数组,在 Linux 系统上占用的内存将减少 8 倍。因此,这种基于掩码的方法将是这样的 -

# Create mask
mask = np.ones(a.size, dtype=bool)
mask[::4] = 0

这是内存需求统计数据 -

In [311]: mask.itemsize
Out[311]: 1

In [312]: a.itemsize
Out[312]: 8

然后,我们可以使用 bool 索引作为 View -

In [313]: a
Out[313]: array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [314]: a[mask] = 10

In [315]: a
Out[315]: array([ 0, 10, 10, 10,  4, 10, 10, 10,  8, 10])

使用 NumPy array strides 的方法 #3:作为 View 的需求

您可以使用 np.lib.stride_tricks.as_strided给定输入数组的长度是 n 的倍数来创建这样的 View 。如果它不是倍数,它仍然可以工作,但不是安全的做法,因为我们会超出为输入数组分配的内存。请注意,这样创建的 View 将是 2D

因此,获得这种 View 的实现方式是 -

def skipped_view(a, n):
    s = a.strides[0]
    strided = np.lib.stride_tricks.as_strided
    return strided(a,shape=((a.size+n-1)//n,n),strides=(n*s,s))[:,1:]

sample 运行-

In [50]: a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) # Input array

In [51]: a_out = skipped_view(a, 4)

In [52]: a_out
Out[52]: 
array([[ 1,  2,  3],
       [ 5,  6,  7],
       [ 9, 10, 11]])

In [53]: a_out[:] = 100 # Let's prove output is a view indeed

In [54]: a
Out[54]: array([  0, 100, 100, 100,   4, 100, 100, 100,   8, 100, 100, 100])

关于python - 跳过 numpy 数组的第 n 个索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40929560/

相关文章:

python - 如何将二维查找表映射到数组(python)?

python - Numpy 的 'linalg.solve' 和 'linalg.lstsq' 没有给出与 Matlab 的 '\' 或 mldivide 相同的答案

go - byte[] channel 使用

python - 字符识别的最佳算法

opencv - NumPy/OpenCV 2 : how to enumerate all pixels from region?

python - 如何根据某些分组列取消透视 Pandas 数据框?

datetime - pandas 日期时间切片 : junkdf. ix ['2015-08-03' :'2015-08-06' ] 不起作用

xml - 如何从不在重复项中使用结束标记的嵌套 xml 中获取数据?

python - 如何使用 Python 将自定义参数添加到 URL 查询字符串?

python - 将标题/描述部分添加到数据框 Pandas Python