python - 如何分割测试和训练数据,以保证每个类别中至少有一个

标签 python pandas machine-learning scikit-learn classification

我正在尝试对一些相当不平衡的数据进行分类。 然而,它的分类相当好。

为了准确评估效果,我必须将数据分为训练子集和测试子集。

现在我正在通过非常简单的措施来做到这一点:

import numpy as np
corpus = pandas.DataFrame(..., columns=["data","label"]) # My data, simplified
train_index = np.random.rand(len(corpus))>0.2
training_data = corpus[train_index]
test_data = corpus[np.logical_not(train_index)]

这很好也很简单,但是有些类很少出现: 在超过 50,000 个案例的语料库中,大约有 15 个出现次数少于 100 次,其中两个仅出现一次。

我想将我的数据语料库划分为测试和训练子集,以便:

  • 如果某个类别出现少于两次,则将其排除在外
  • 每个类(class)在测试和训练中至少出现一次
  • 测试和训练的划分是随机的

我可以拼凑一些东西来做到这一点, (可能最简单的方法是删除出现次数少于 2 次的东西),然后重新采样,直到吐出的每一侧都有),但我想知道是否已经存在一种干净的方法。

我不这么认为sklearn.cross_validation.train_test_split可以做到这一点,但它的存在表明 sklearn 可能具有这种功能。

最佳答案

以下满足将数据划分为测试和训练的 3 个条件:

#get rid of items with fewer than 2 occurrences.
corpus=corpus[corpus.groupby('label').label.transform(len)>1]

from sklearn.cross_validation import StratifiedShuffleSplit
sss=StratifiedShuffleSplit(corpus['label'].tolist(), 1, test_size=0.5, random_state=None)

train_index, test_index =list(*sss)
training_data=corpus.iloc[train_index]
test_data=corpus.iloc[test_index]

我已经使用以下虚构数据框测试了上面的代码:

#create random data with labels 0 to 39, then add 2 label case and one label case.     
corpus=pd.DataFrame({'data':np.random.randn(49998),'label':np.random.randint(40,size=49998)})
corpus.loc[49998]=[random.random(),40]
corpus.loc[49999]=[random.random(),40]
corpus.loc[50000]=[random.random(),41]

测试代码时会产生以下输出:

test_data[test_data['label']==40]
Out[110]: 
           data  label
49999  0.231547     40

training_data[training_data['label']==40]
Out[111]: 
           data  label
49998  0.253789     40

test_data[test_data['label']==41]
Out[112]: 
Empty DataFrame
Columns: [data, label]
Index: []

training_data[training_data['label']==41]
Out[113]: 
Empty DataFrame
Columns: [data, label]
Index: []

关于python - 如何分割测试和训练数据,以保证每个类别中至少有一个,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/31238079/

相关文章:

python - 如何在Windows上的Anaconda环境中部署python程序?

python - 使用 Ramses 存储非结构化数据以便使用 Ramses-API 进行搜索?

python:表示包裹自身的方形网格(圆柱体)

python - 在 Pandas DataFrame 列上应用阈值

machine-learning - 使用矩阵 (NxN) 观测值创建离散隐马尔可夫模型?

python - 当顺序无关紧要时,删除列表列表中的任何重复项(例如,[1,2,3] 和 [1,3,2] 是重复集)?

python - 如何检索 pandas Series 对象中第 n 个元素的值?

arrays - 如何使用循环生成的数组作为数据帧中的列来创建数据帧

python - pandas_ml 中的 cross_validation 问题

python - roc_auc_score() 和 auc() 的结果不同