python-3.x - 决策树accuracy_score给出 "ValueError: Found input variables with inconsistent numbers of samples"

标签 python-3.x machine-learning scikit-learn decision-tree

我正在尝试使用给定数据创建决策树。但由于某种原因 accuracy_score 给出了

ValueError: Found input variables with inconsistent numbers of samples:

当我将训练数据拆分为验证(%20)和训练(%80)时。

这是我分割数据的方法:

from sklearn.utils import shuffle

from sklearn.model_selection import train_test_split

# stDt shuffled training set

stDt = shuffle(tDt) 

#divide shuffled training set to training and validation set

stDt, vtDt = train_test_split(stDt,train_size=0.8, shuffle=False)

print(tDt.shape)
print(stDt.shape)
print(vtDt.shape)

这是我训练数据的方法:

#attibutes and labels of training set

attributesT =  stDt.values

labelsT = stDt.label


# Train Decision tree classifiers
from sklearn.tree import DecisionTreeClassifier


dtree1 = DecisionTreeClassifier(min_samples_split = 1.0)

dtree2 = DecisionTreeClassifier(min_samples_split = 3)

dtree3 = DecisionTreeClassifier(min_samples_split = 5)



fited1 = dtree1.fit(attributesT,labelsT)

fited2 = dtree2.fit(attributesT,labelsT)

fited3 = dtree3.fit(attributesT,labelsT)

这是测试和准确性得分部分:

from sklearn.metrics import accuracy_score

ret1 = fited1.predict(stDt)

ret2 = fited2.predict(stDt)

ret3 = fited3.predict(stDt)

print(accuracy_score(vtDt.label,ret1))

最佳答案

您收到的错误是预期的,因为您试图将训练集 (ret1 = fited1.predict(stDt)) 生成的预测与标签进行比较您的验证集 (vtDt.label)。

以下是获得 fitted1 模型训练和验证准确性的正确方法(其他模型也类似):

# predictions on the training set:
ret1 = fitted1.predict(stDt)

# training accuracy:
accuracy_score(stDt.label,ret1)

# predictions on the validation set:
pred1 = fitted1.predict(vtDt)

# validation accuracy:
accuracy_score(vtDt.label,pred1)

关于python-3.x - 决策树accuracy_score给出 "ValueError: Found input variables with inconsistent numbers of samples",我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52792149/

相关文章:

python - CNN 批处理不同大小的图像

machine-learning - 朴素贝叶斯文本分类在一个类别中失败。为什么?

python-3.x - sklearn 中 precision_recall_fscore_support 的输出是如何排序的?

python - 即使整个管道都安装了,管道中的 Sklearn 组件也没有安装?

python - 如何在 Python 3.0 中输出 Unicode 符号?

python-3.x - 有没有一种简单的方法可以手动迭代现有的 pandas groupby 对象?

python - 不支持的输入图像深度 : 'VDepth::contains(depth)' where 'depth' is 4 (CV_32S)

math - 计算分类准确率的最佳方法?

python - 将分类移至生产环境

python - 可以使用 exec 运行异步功能吗?