I have a dataset (NumPy arrays) containing 50 classes and 9000 training examples.
x_train.shape == (9000, 2048)
y_train.shape == (9000,)  # classes are strings
classes = list(set(y_train))
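I want to build a sub-dataset with 5 examples from each class, which gives 5 * 50 = 250 training samples. My sub-dataset would therefore have these shapes:
sub_train_data.shape == (250, 2048)
sub_train_labels.shape == (250,)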
Note: the 5 examples are drawn at random from each class (50 classes in total).
Thanks
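For comparison, here is a minimal NumPy-only sketch of the requirement above. It assumes y_train is a NumPy array of string labels and that every class has at least 5 examples; sample_per_class is a hypothetical helper name, and the toy data is synthetic, standing in for the real x_train/y_train. It samples without replacement, so no example is repeated within a class, and Generator.choice raises a ValueError if a class is too small.

```python
import numpy as np

def sample_per_class(X, y, n_per_class, seed=None):
    """Hypothetical helper: draw n_per_class distinct rows from every class."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X)
    y = np.asarray(y)
    picked = []
    for label in np.unique(y):                 # sorted, reproducible class order
        idx = np.flatnonzero(y == label)       # row indices of this class
        # replace=False: no row is picked twice within a class
        picked.append(rng.choice(idx, size=n_per_class, replace=False))
    picked = np.concatenate(picked)
    rng.shuffle(picked)                        # mix the classes instead of grouping them
    return X[picked], y[picked]

# Synthetic stand-in for the real data: 50 string classes x 180 rows, 2048 features.
toy_rng = np.random.default_rng(0)
labels = np.array([f"class_{i}" for i in range(50)])
y_train = np.repeat(labels, 180)                               # 9000 labels
x_train = toy_rng.random((y_train.size, 2048), dtype=np.float32)

sub_train_data, sub_train_labels = sample_per_class(x_train, y_train, 5, seed=42)
print(sub_train_data.shape, sub_train_labels.shape)  # (250, 2048) (250,)
```

Using a local Generator with a seed keeps the draw reproducible without touching NumPy's global random state.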
Best answer
Here is a solution to the problem:
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt

def balanced_sample_maker(X, y, sample_size, random_seed=42):
    """Return sample_size randomly chosen examples of every class in y."""
    X = np.asarray(X)
    y = np.asarray(y)
    uniq_levels = np.unique(y)
    if random_seed is not None:
        np.random.seed(random_seed)

    # find the observation indices of each class level
    groupby_levels = {}
    for level in uniq_levels:
        groupby_levels[level] = np.flatnonzero(y == level)

    # draw sample_size indices per class: without replacement when the class
    # is large enough, otherwise fall back to oversampling (with replacement)
    balanced_copy_idx = []
    for gb_level, gb_idx in groupby_levels.items():
        replace = len(gb_idx) < sample_size
        sampled_idx = np.random.choice(gb_idx, size=sample_size, replace=replace).tolist()
        balanced_copy_idx += sampled_idx
    np.random.shuffle(balanced_copy_idx)

    data_train = X[balanced_copy_idx]
    labels_train = y[balanced_copy_idx]

    if len(data_train) == sample_size * len(uniq_levels):
        print('number of sampled examples:', sample_size * len(uniq_levels),
              'samples per class:', sample_size, '#classes:', len(uniq_levels))
    else:
        print('number of samples is wrong')

    # check that every class appears the same number of times
    labels, values = zip(*Counter(labels_train).items())
    print('number of classes:', len(labels))
    check = all(v == values[0] for v in values)
    print(check)
    if check:
        print('Good: all classes have the same number of examples')
    else:
        print('Repeat the sampling: the classes are not balanced')

    # bar chart of the per-class counts
    indexes = np.arange(len(labels))
    width = 0.5
    plt.bar(indexes, values, width)
    plt.xticks(indexes, labels)  # bars are centred on indexes by default
    plt.show()
    return data_train, labels_train

sub_train_data, sub_train_labels = balanced_sample_maker(x_train, y_train, 5)
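The balance check in balanced_sample_maker uses Counter plus a bar chart. If you only need a yes/no answer, a minimal sketch using np.unique(..., return_counts=True) is shown below; check_balance and its expected_n_classes parameter are hypothetical names introduced for illustration, and the label arrays are made up.

```python
import numpy as np

def check_balance(labels, expected_per_class, expected_n_classes=None):
    """Hypothetical helper: True if every class occurs exactly expected_per_class times."""
    classes, counts = np.unique(labels, return_counts=True)
    # report any class whose count differs from the target
    for c, n in zip(classes, counts):
        if n != expected_per_class:
            print(f"class {c}: {n} samples (expected {expected_per_class})")
    ok = bool((counts == expected_per_class).all())
    # optionally also verify that no class is missing from the sample
    if expected_n_classes is not None:
        ok = ok and len(classes) == expected_n_classes
    return ok

# Made-up label arrays for illustration.
balanced = np.array(["cat", "dog", "cat", "dog", "bird", "bird"])
unbalanced = np.array(["cat", "cat", "cat", "dog"])
print(check_balance(balanced, 2))    # True
print(check_balance(unbalanced, 2))
```

For the question's data, check_balance(sub_train_labels, 5, expected_n_classes=50) should return True.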
Source: "python - Take X samples from each class label" on Stack Overflow: https://stackoverflow.com/questions/48425201/