python - 了解 Optuna 中的中间值和修剪

标签 python machine-learning deep-learning hyperparameters optuna

我只是想了解有关中间步骤实际上是什么以及如何使用修剪的更多信息(如果您使用教程部分中未包含的其他机器学习库,例如)XGB、Pytorch 等。

例如:

X, y = load_iris(return_X_y=True)
X_train, X_valid, y_train, y_valid = train_test_split(X, y)
classes = np.unique(y)
n_train_iter = 100

def objective(trial):
    global num_pruned
    alpha = trial.suggest_float("alpha", 0.0, 1.0)
    clf = SGDClassifier(alpha=alpha)
    for step in range(n_train_iter):
        clf.partial_fit(X_train, y_train, classes=classes)

        intermediate_value = clf.score(X_valid, y_valid)
        trial.report(intermediate_value, step)

        if trial.should_prune():
            raise optuna.TrialPruned()

    return clf.score(X_valid, y_valid)


study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.HyperbandPruner(
        min_resource=1, max_resource=n_train_iter, reduction_factor=3
    ),
)
study.optimize(objective, n_trials=30)

for step in range() 部分的意义是什么?这样做是否只会使优化花费更多时间,并且不会为循环中的每个步骤产生相同的结果?

我真的很想弄清楚是否需要for step in range(),并且每次您希望使用修剪时都需要它吗?

最佳答案

通过一次完整的训练数据集即可完成基本模型的创建。但有些模型仍然可以通过在相同的训练数据集上重新训练来改进(提高准确性)。

为了确保我们在这里不浪费资源,我们将在每一步之后通过 intermediate_score 使用验证数据集检查准确性,如果准确性提高,如果没有提高,我们会修剪整个试验,跳过其他步骤。然后我们进行下一次试验,询问另一个 alpha 值 - 我们试图确定在验证数据集上具有最高准确性的超参数。

对于其他库来说,只需问自己我们想要的模型是什么,准确性肯定是衡量模型能力的一个很好的标准。还可以有其他的。

示例 optuna 修剪,我希望模型继续重新训练,但仅限于我的特定条件。如果中间值无法击败我的 best_accuracy,并且步数已经超过最大迭代的一半,则修剪此试验。

best_accuracy = 0.0


def objective(trial):
    global best_accuracy

    alpha = trial.suggest_float("alpha", 0.0, 1.0)
    clf = SGDClassifier(alpha=alpha)

    for step in range(n_train_iter):
        clf.partial_fit(X_train, y_train, classes=classes)

        if step > n_train_iter//2:
            intermediate_value = clf.score(X_valid, y_valid)

            if intermediate_value < best_accuracy:
                raise optuna.TrialPruned()

    best_accuracy = clf.score(X_valid, y_valid)

    return best_accuracy

Optuna 在 https://optuna.readthedocs.io/en/stable/reference/pruners.html 有专门的修枝剪

关于python - 了解 Optuna 中的中间值和修剪,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69990009/

相关文章:

python - 在 Python 中将单词分类

python - 如何将参数传递给通过字典查找调用的这些函数

python - 我可以使用带有 pandas 数据框的散点图绘制回归线并显示参数吗?

python - Python中的分布式多处理池

python - 为什么三元运算符比 .get 更快?

machine-learning - 时间序列预测(DeepAR): Prediction results seem to have basic flaw

machine-learning - CNN不同时间的feed数据

python - 如何正确使用 scipy.optimize.minimize 来返回整数的函数

machine-learning - 如何制作非静态的Caffe网络架构?

python - 我应该卡住哪些层来微调 keras 上的 resnet 模型?