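I have a CSV file that I read into a DataFrame. I want to split it into train and test files based on the values of one column.
Let us say the column is called "category" and it holds several category names such as cat1, cat2, cat3 and so on, each of which appears more than once.
I need to split the file so that every category name appears at least once in both files.
So far I can split the file in two according to a ratio. I have tried many options, but this is the best one so far.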
def executeSplitData(self):
    data = self.readCSV()
    df = data
    if self.column in data:
        train, test = train_test_split(df, stratify = None, test_size=0.5)
        self.writeTrainFile(train)
        self.writeTestFile(test)
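I don't fully understand the stratify option in train_test_split. Please help. Thanks.

Best answer

I tried using it as described in the docs and could not get stratify to work.

Setup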
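# Note: sklearn.cross_validation was removed in scikit-learn 0.20;
# newer releases provide train_test_split in sklearn.model_selection.
from sklearn.cross_validation import train_test_split
import pandas as pd
import numpy as np

np.random.seed([3,1415])
# Category probabilities 0.1, 0.2, 0.3, 0.4
p = np.arange(1, 5.) / np.arange(1, 5.).sum()
df = pd.DataFrame({'category': np.random.choice(('cat1', 'cat2', 'cat3', 'cat4'), (1000,), p=p),
                   'x': np.random.rand(1000), 'y': np.random.choice(range(2), (1000,))})

def get_freq(s):
    # Relative frequency of each value in the Series
    return s.value_counts() / len(s)

print(get_freq(df.category))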
cat4 0.400
cat3 0.284
cat2 0.208
cat1 0.108
Name: category, dtype: float64
If I try any of these:
train, test = train_test_split(df, stratify=df.category, test_size=.5)
train, test = train_test_split(df, stratify=df.category.values, test_size=.5)
train, test = train_test_split(df, stratify=df.category.values.tolist(), test_size=.5)
they all raise:
TypeError: Invalid parameters passed:
The docs say:
stratify : array-like or None (default is None)
I don't understand why this doesn't work.
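A plausible cause, though the post does not confirm it, is that the installed scikit-learn predates support for `stratify` in `train_test_split`. As a minimal sketch, assuming a recent scikit-learn where the function lives in `sklearn.model_selection`, the same call is expected to work. The data frame below is a stand-in built for illustration, not the one from the setup above.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Illustrative data: four categories with unequal frequencies
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "category": rng.choice(["cat1", "cat2", "cat3", "cat4"],
                           size=1000, p=[0.1, 0.2, 0.3, 0.4]),
    "x": rng.random(1000),
})

# stratify takes the labels whose proportions should be preserved in both parts
train, test = train_test_split(
    df, test_size=0.5, stratify=df["category"], random_state=0
)

# Category proportions in each half should be close to those of df
print(train["category"].value_counts(normalize=True).sort_index())
print(test["category"].value_counts(normalize=True).sort_index())
```

Note that stratifying requires every category to have at least two rows, otherwise scikit-learn raises an error instead of splitting.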
I decided to work around it with the following:
def stratify_train_test(df, stratifyby, *args, **kwargs):
    train, test = pd.DataFrame(), pd.DataFrame()
    gb = df.groupby(stratifyby)
    # Split each group separately, then stack the pieces back together
    for k in gb.groups:
        traink, testk = train_test_split(gb.get_group(k), *args, **kwargs)
        train = pd.concat([train, traink])
        test = pd.concat([test, testk])
    return train, test

train, test = stratify_train_test(df, 'category', test_size=.5)
# this also works
# train, test = stratify_train_test(df, df.category, test_size=.5)

print(get_freq(train.category))
print(len(train))
cat4 0.400
cat3 0.284
cat2 0.208
cat1 0.108
Name: category, dtype: float64
500
print(get_freq(test.category))
print(len(test))
cat4 0.400
cat3 0.284
cat2 0.208
cat1 0.108
Name: category, dtype: float64
500
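The original requirement was that every category shows up at least once in both files, which matters most when some categories are rare. The following is a pandas-only sketch of that idea, not the answer's method: `split_keep_every_category` is a hypothetical helper name, and it assumes the DataFrame index is unique so that `drop` removes exactly the sampled rows.

```python
import pandas as pd

def split_keep_every_category(df, column, test_frac=0.5, seed=0):
    """Hypothetical helper: split df so each value of `column` is in both parts."""
    test_parts = []
    for value, group in df.groupby(column):
        n = len(group)
        if n < 2:
            # A single row cannot be placed in both train and test
            raise ValueError(f"category {value!r} has only {n} row")
        # Aim for test_frac of the group, but leave at least one row on each side
        n_test = min(max(1, round(n * test_frac)), n - 1)
        test_parts.append(group.sample(n=n_test, random_state=seed))
    test = pd.concat(test_parts)
    # Assumes a unique index: everything not sampled for test goes to train
    train = df.drop(test.index)
    return train, test

# Small illustrative frame with one rare category
df = pd.DataFrame({
    "category": ["cat1"] * 2 + ["cat2"] * 3 + ["cat3"] * 10,
    "x": range(15),
})
train, test = split_keep_every_category(df, "category", test_frac=0.2)
print(sorted(train["category"].unique()), sorted(test["category"].unique()))
```

Clamping the per-group test size to the range 1 to n - 1 is what provides the guarantee; a plain ratio-based split could round a small group entirely into one side.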
Source: "python - Split data into train/test files so that at least one sample is picked for both files", a similar question on Stack Overflow: https://stackoverflow.com/questions/37969945/