python - 如何为新的训练模型初始化coef_init和intercept_init?

标签 python machine-learning scikit-learn

如此处所指定,https://stackoverflow.com/a/35662770/5757129 ,我存储了第一个模型的系数和截距。后来,我将它们作为初始化程序传递给我的第二个 fit() ,如下所示,以便在旧模型之上学习新数据。

from sklearn import neighbors, linear_model
import numpy as np
import pickle
import os

def train_data():

    x1 = [[8, 9], [20, 22], [16, 18], [8,4]]
    y1 = [0, 1, 2, 3]

    #classes = np.arange(10)

    #sgd_clf = linear_model.SGDClassifier(learning_rate = 'constant', eta0 = 0.1, shuffle = False, n_iter = 1,warm_start=True)

    sgd_clf = linear_model.SGDClassifier(loss="hinge",max_iter=10000)

    sgd_clf.fit(x1,y1)

    coef = sgd_clf.coef_
    intercept = sgd_clf.intercept_

    return coef, intercept


def train_new_data(coefs,intercepts):

    x2 = [[18, 19],[234,897],[20, 122], [16, 118]]
    y2 = [4,5,6,7]

    sgd_clf1 = linear_model.SGDClassifier(loss="hinge",max_iter=10000)

    new_model = sgd_clf1.fit(x2,y2,coef_init=coefs,intercept_init=intercepts)

    return new_model


if __name__ == "__main__":

    coefs,intercepts= train_data()

    new_model = train_new_data(coefs,intercepts)

    print(new_model.predict([[16, 118]]))
    print(new_model.predict([[18, 19]]))
    print(new_model.predict([[8,9]]))
    print(new_model.predict([[20,22]]))

当我运行这个时,我得到了仅从 new_model 训练的标签。例如,print(new_model.predict([[8,9]]))必须将标签打印为 0 和 print(new_model.predict([[20,22]]))必须将标签打印为 1。但它打印从 4 到 7 匹配的标签。

我是否以错误的方式将 coef 和拦截从旧模型传递到新模型?

编辑:根据@vital_dml答案重新构建问题

最佳答案

我不知道为什么你需要将系数和截距从第一个模型传递到第二个模型,但是,你会得到这样的错误,因为你的第一个模型是针对 4 个类进行训练的 y1 = [0, 1, 2 , 3],而第二个有 2 个类 y2 = [4,5],这是有争议的。

根据scikit-learn documentation ,您的 linear_model.SGDClassifier() 返回:

coef_ : array, shape (1, n_features) if n_classes == 2 else (n_classes, n_features) - Weights assigned to the features.

intercept_ : array, shape (1,) if n_classes == 2 else (n_classes,) - Constants in decision function.

因此,在您的问题中,两个模型中的类和功能的数量必须相同。

无论如何,我鼓励您思考您真的需要这样做吗?也许你可以连接这些向量。

关于python - 如何为新的训练模型初始化coef_init和intercept_init?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49133188/

相关文章:

python - 为什么对测试数据调用 Transform() 会返回数据尚未拟合的错误?

python - sklearn SVM/SVC 始终为任何给定输入预测相同的类别

python - Scikit-learn (sklearn) PCA 在稀疏矩阵上抛出类型错误

python - 使用应用程序的 sql 文件夹内的 sql 文件将初始数据提供到 sql 表中在 django1.9 中不起作用

python - KeyError : "Couldn' t find enum caffe. EmitConstraint.EmitType"

python - 使用相关矩阵在大型稀疏矩阵上进行 PCA

用于机器学习算法的 Python csv 流

Python 西格玛和

python - 比较列表中的相邻变量并重新格式化输入

python - sklearn ShuffleSplit "__init__() got multiple values for argument ' n_splits '"错误