我是机器学习新手,第一次尝试 Sklearn。我有两个数据框,一个包含用于训练逻辑回归模型的数据(具有 10 倍交叉验证),另一个用于使用该模型预测类别 ('0,1')。 到目前为止,这是我的代码,使用了我在 Sklearn 文档和 Web 上找到的一些教程:
import pandas as pd
import numpy as np
import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold
from sklearn.preprocessing import normalize
from sklearn.preprocessing import scale
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_val_predict
from sklearn import metrics
# Import dataframe with training data
df = pd.read_csv('summary_44.csv')
cols = df.columns.drop('num_class') # Data to use (num_class is the column with the classes)
# Import dataframe with data to predict
df_pred = pd.read_csv('new_predictions.csv')
# Scores
df_data = df.ix[:,:-1].values
# Target
df_target = df.ix[:,-1].values
# Values to predict
df_test = df_pred.ix[:,:-1].values
# Scores' names
df_data_names = cols.values
# Scaling
X, X_pred, y = scale(df_data), scale(df_test), df_target
# Define number of folds
kf = KFold(n_splits=10)
kf.get_n_splits(X) # returns the number of splitting iterations in the cross-validator
# Logistic regression normalizing variables
LogReg = LogisticRegression()
# 10-fold cross-validation
scores = [LogReg.fit(X[train], y[train]).score(X[test], y[test]) for train, test in kf.split(X)]
print scores
# Predict new
novel = LogReg.predict(X_pred)
这是实现逻辑回归的正确方法吗? 我知道应该在交叉验证后使用 fit() 方法,以便训练模型并将其用于预测。然而,由于我在列表理解中调用了 fit() ,所以我真的不知道我的模型是否“适合”并可用于进行预测。
最佳答案
总体来说一切都很好,但是存在一些问题。
缩放
X, X_pred, y = scale(df_data), scale(df_test), df_target
您独立扩展训练和测试数据,这是不正确的。两个数据集必须使用相同的缩放器进行缩放。 “Scale”是一个简单的函数,但最好使用其他函数,例如 StandardScaler。
scaler = StandardScaler()
scaler.fit(df_data)
X = scaler.transform(df_data)
X_pred = scaler.transform(df_test)
交叉验证和预测。 你的代码如何工作?将数据拆分 10 次,分为训练集和保留集;在训练集上拟合模型 10 次,并计算保留集上的分数。通过这种方式,您可以获得交叉验证分数,但模型仅适合部分数据。因此,最好在整个数据集上拟合模型,然后进行预测:
LogReg.fit(X, y) novel = LogReg.predict(X_pred)
我想注意到,有一些高级技术,例如堆栈和提升,但如果您学习使用 sklearn,那么最好坚持基础知识。
关于python - 逻辑回归 sklearn - 训练和应用模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47361513/