scikit-learn - Scikit-learn 中 KNN 分类器中的网格搜索参数和交叉验证数据集

标签 scikit-learn cross-validation knn grid-search

我正在尝试使用 SciKit-Learn 执行我的第一个 KNN 分类器。我一直在关注用户指南和其他在线示例,但有一些事情我不确定。对于这篇文章,我们使用以下内容

X = 数据 Y = 目标

  1. 在我读过的大多数机器学习介绍页面中,似乎都说您需要训练集、验证集和测试集。据我了解,交叉验证允许您结合训练集和验证集来训练模型,然后您应该在测试集上对其进行测试以获得分数。然而,我在论文中看到,在很多情况下,您可以对整个数据集进行交叉验证,然后将 CV 分数报告为准确性。我知道在理想的世界中,您会希望对单独的数据进行测试,但如果这是合法的,我想对整个数据集进行交叉验证并报告这些分数

  2. 所以开始这个过程

我定义我的 KNN 分类器如下

knn = KNeighborsClassifier(algorithm = 'brute')

我使用搜索最好的 n_neighbors

clf = GridSearchCV(knn, parameters, cv=5)

现在如果我说

clf.fit(X,Y)

我可以使用检查最佳参数

clf.best_params_

然后我就可以获得分数

clf.score(X,Y)

但是 - 据我了解,这并没有交叉验证模型,因为它只给出 1 分?

如果我现在看到 clf.best_params_ = 14 我可以继续吗

knn2 = KNeighborsClassifier(n_neighbors = 14, algorithm='brute')
cross_val_score(knn2, X, Y, cv=5)

现在我知道数据已经过交叉验证,但我不知道使用 clf.fit 找到最佳参数然后将 cross_val_score 与新的 knn 模型一起使用是否合法?

  • 我知道“正确”的做法如下
  • 分割为X_train、X_test、Y_train、Y_test、 缩放训练集 -> 将变换应用于测试集

    knn = KNeighborsClassifier(algorithm = 'brute')
    clf = GridSearchCV(knn, parameters, cv=5)
    clf.fit(X_train,Y_train)
    clf.best_params_
    

    然后我就可以获得分数

    clf.score(X_test,Y_test)
    

    这样的话,分数是使用最佳参数计算的吗?

    <小时/>

    我希望这是有道理的。我一直在尝试在不发帖的情况下找到尽可能多的信息,但我已经到了这样的地步:我认为获得一些直接答案会更容易。

    在我的脑海中,我试图使用整个数据集获得一些交叉验证的分数,但也使用网格搜索(或类似的东西)来微调参数。

    最佳答案

    1. 是的,您可以对整个数据集进行 CV,这是可行的,但我仍然建议您至少将数据分成 2 组,一组用于 CV,一组用于测试。

    2. .score 函数应该根据 documentation 返回单个 float 值。这是给定 X,Y 上的最佳估计器(这是您通过拟合 GridSearchCV 获得的最佳得分估计器)的分数

    3. 如果您发现最佳参数是 14,那么您可以继续在模型中使用它,但如果您给它更多的参数,您应该设置所有参数。 (- 我这么说是因为你还没有给出你的参数列表)是的,再次检查你的简历是合法的,以防万一这个模型足够好。

    希望能让事情变得更清楚:)

    关于scikit-learn - Scikit-learn 中 KNN 分类器中的网格搜索参数和交叉验证数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40634726/

    相关文章:

    python - sklearn.model_selection.cross_val_score的score函数公式是什么?

    algorithm - KNN 算法中需要归一化

    r - “The format of predictions is incorrect”

    python - 使用 CV 获得较高的 RMSE 分数传达什么信息

    python - K 重 CV 的变体,其中 size(test_set) > N/K

    python - 使用 Scikit-learn (sklearn) 估算整个 DataFrame(所有列)而不迭代列

    machine-learning - 如果模型每次迭代都被丢弃,交叉验证的目的是什么

    machine-learning - 当 k=4 时 KNN 选择类标签

    python - KernelPCA 产生 NaN

    python - Scikit 学习 API xgboost 允许在线培训吗?