python - 使用 sklearn 接口(interface)包裹模型

标签 python machine-learning scikit-learn wrapper gridsearchcv

使用 sklearn BaseEstimator 接口(interface)包装现有模型以兼容 gridsearchCV 的最佳方法是什么?我的模型既没有 set_param 也没有 get_params。我的方法如下:

class Wrapper(BaseEstimator):
    
    def __init__(param1, param2):
        self.model = ModelClass(param1, param2)
    
    def fit(data):
        self.model.fit(data)
        return self

    def predict(data):
        return self.model.predict(data)

    def get_params(self, deep=True): # ?
        return self.model.__dict__

    def set_params(self, **parameters): # ?, have I to recreate model?
        for parameter, value in parameters.items():
            setattr(self.model, parameter, value)
        return self
        

最佳答案

get_params方法中,您可以使用__dict__属性返回Wrapper实例的参数字典。这将允许 GridSearchCV 访问 Wrapper 实例的参数并使用它们进行超参数调整。

不要忘记在 __init__ 下添加 self.param1 = param1self.param2 = param2 以允许访问 get 和 set .

class Wrapper(BaseEstimator):
    
    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2
        self.model = ModelClass(param1, param2)
    
    def fit(self, data):
        self.model.fit(data)
        return self

    def predict(self, data):
        return self.model.predict(data)
    
    def score(self, data):
        return self.model.score(data)

    def get_params(self, deep=True):
        return {'param1': self.param1, 'param2': self.param2}

    def set_params(self, **parameters):
        self.param1 = parameters.get('param1', self.param1)
        self.param2 = parameters.get('param2', self.param2)
        self.model = ModelClass(self.param1, self.param2)
        return self

使用GridsearchCV的示例:

from sklearn.model_selection import GridSearchCV

param_grid = {'param1': [1, 10, 100], 'param2': [0.01, 0.1, 1]}

model = Wrapper()
grid_search = GridSearchCV(estimator=model, param_grid=param_grid)
grid_search.fit(X_train, y_train)
test_score = grid_search.score(X_test, y_test)
print(f'Test score: {test_score:.2f}')

关于python - 使用 sklearn 接口(interface)包裹模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/74852797/

相关文章:

python - 每次试验的 Psychopy 重置变量

Python 属性错误 : Object has no attribute in Unittest

python - 在python中获取当前控制台输出

machine-learning - 逻辑回归分类器的 Bootstrap 聚合(装袋)

python - 对一维 numpy 数组进行下采样

machine-learning - 如何判断一个数据集是否可以训练神经网络?

python - 如何使用 Opencv 和 Python 检测图像中的白色区域?

matplotlib - 如何使用 matplotlib 绘制非线性模型?

python - 在 Scikit-Learn 中使用非线性 SVM 时出错

python - 使用 scikit learn 训练逻辑回归以进行多类分类