python - 平衡 numpy 数组与过采样

标签 python arrays numpy

请帮我找到一种干净的方法来从现有数组中创建一个新数组。如果任何类的示例数小于该类中的最大示例数,则应该进行过采样。样本应该从原始数组中提取(随机或顺序都没有区别)

比方说,初始数组是这样的:

[  2,  29,  30,   1]
[  5,  50,  46,   0]
[  1,   7,  89,   1]
[  0,  10,  92,   9]
[  4,  11,   8,   1]
[  3,  92,   1,   0]

最后一列包含类:

classes = [ 0,  1,  9]

类的分布如下:

distrib = [2, 3, 1]

我需要的是创建一个新数组,其中所有类的样本数量相等,从原始数组中随机抽取,例如

[  5,  50,  46,   0]
[  3,  92,   1,   0]
[  5,  50,  46,   0] # one example added
[  2,  29,  30,   1]
[  1,   7,  89,   1]
[  4,  11,   8,   1]
[  0,  10,  92,   9]
[  0,  10,  92,   9] # two examples
[  0,  10,  92,   9] # added

最佳答案

下面的代码完成你所追求的:

a = np.array([[  2,  29,  30,   1],
              [  5,  50,  46,   0],
              [  1,   7,  89,   1],
              [  0,  10,  92,   9],
              [  4,  11,   8,   1],
              [  3,  92,   1,   0]])

unq, unq_idx = np.unique(a[:, -1], return_inverse=True)
unq_cnt = np.bincount(unq_idx)
cnt = np.max(unq_cnt)
out = np.empty((cnt*len(unq),) + a.shape[1:], a.dtype)
for j in xrange(len(unq)):
    indices = np.random.choice(np.where(unq_idx==j)[0], cnt)
    out[j*cnt:(j+1)*cnt] = a[indices]

>>> out
array([[ 5, 50, 46,  0],
       [ 5, 50, 46,  0],
       [ 5, 50, 46,  0],
       [ 1,  7, 89,  1],
       [ 4, 11,  8,  1],
       [ 2, 29, 30,  1],
       [ 0, 10, 92,  9],
       [ 0, 10, 92,  9],
       [ 0, 10, 92,  9]])

当 numpy 1.9 发布时,或者如果你从开发分支编译,那么前两行可以压缩成:

unq, unq_idx, unq_cnt = np.unique(a[:, -1], return_inverse=True,
                                  return_counts=True)

请注意,np.random.choice 的工作方式无法保证原始数组的所有行都会出现在输出数组中,如上例所示。如果需要,您可以执行以下操作:

unq, unq_idx = np.unique(a[:, -1], return_inverse=True)
unq_cnt = np.bincount(unq_idx)
cnt = np.max(unq_cnt)
out = np.empty((cnt*len(unq) - len(a),) + a.shape[1:], a.dtype)
slices = np.concatenate(([0], np.cumsum(cnt - unq_cnt)))
for j in xrange(len(unq)):
    indices = np.random.choice(np.where(unq_idx==j)[0], cnt - unq_cnt[j])
    out[slices[j]:slices[j+1]] = a[indices]
out = np.vstack((a, out))

>>> out
array([[ 2, 29, 30,  1],
       [ 5, 50, 46,  0],
       [ 1,  7, 89,  1],
       [ 0, 10, 92,  9],
       [ 4, 11,  8,  1],
       [ 3, 92,  1,  0],
       [ 5, 50, 46,  0],
       [ 0, 10, 92,  9],
       [ 0, 10, 92,  9]])

关于python - 平衡 numpy 数组与过采样,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/23391608/

相关文章:

Java 3d Arraylist 到 3d 数组

arrays - 如何通过 & :key as an argument to map instead of a block with ruby?

python - pd.Timestamp 与 np.datetime64 : are they interchangeable for selected uses?

python - Numpy 如何从非均匀分布中采样随机数?

python - 从日期列表中删除单词 DateTimeIndex

python - 如何 reshape 数据框以保留唯一信息?

python - reshape 只有一维的numpy数组

python - 类中的 Thread.__init__(self) 如何工作?

c++ - 指针符号和数组

python - nan、NaN和NAN有什么区别