python - 带有索引的 Scikit-learn train_test_split

标签 python scipy scikit-learn classification

使用train_test_split()时如何获取数据的原始索引?

我拥有的是以下

from sklearn.cross_validation import train_test_split
import numpy as np
data = np.reshape(np.randn(20),(10,2)) # 10 training examples
labels = np.random.randint(2, size=10) # 10 labels
x1, x2, y1, y2 = train_test_split(data, labels, size=0.2)

但这并没有给出原始数据的索引。 一种解决方法是将索引添加到数据(例如 data = [(i, d) for i, d in enumerate(data)]),然后将它们传递到 train_test_split然后再次展开。 有没有更清洁的解决方案?

最佳答案

您可以像 Julien 所说的那样使用 pandas 数据帧或系列,但如果您想将自己限制为 numpy,您可以传递一个额外的索引数组:

from sklearn.model_selection import train_test_split
import numpy as np
n_samples, n_features, n_classes = 10, 2, 2
data = np.random.randn(n_samples, n_features)  # 10 training examples
labels = np.random.randint(n_classes, size=n_samples)  # 10 labels
indices = np.arange(n_samples)
(
    data_train,
    data_test,
    labels_train,
    labels_test,
    indices_train,
    indices_test,
) = train_test_split(data, labels, indices, test_size=0.2)

关于python - 带有索引的 Scikit-learn train_test_split,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/31521170/

相关文章:

python - 两个 Pandas 数据帧之间的快速斯 PIL 曼相关性

python - Scipy正常的pdf评估给出了矛盾的值

python - 为什么这个参数在sklearn的Pipeline中无效?

python - 值错误 : cannot reshape array of size 1048576 into shape (1024, 1024,3)

python - 以pythonic方式组合具有特定合并顺序的列表?

python - 段错误 - python C 扩展中的核心转储

python - Scipy 输出错误 :undefined symbol: sgegv_

python - 在 python 中模拟一个 int

scikit-learn - 如何调整投票分类器 (Sklearn) 中的权重

python - scikit-learn 中重复的FeatureUnion