python - SciPy 共轭梯度优化在每次迭代后不调用回调方法

标签 python optimization machine-learning scipy theano

我按照教程here进行操作为了使用 theano 实现逻辑回归。上述教程使用SciPy的fmin_cg优化程序。上述函数的重要参数包括: f要最小化的对象/成本函数,x0用户提供了参数的初始猜测,fprime提供函数f的导数的函数在xcallback一个可选的用户提供的函数,在每次迭代后调用。

训练函数定义如下:

# creates a function that computes the average cost on the training set
def train_fn(theta_value):
    classifier.theta.set_value(theta_value, borrow=True)
    train_losses = [batch_cost(i * batch_size)
                    for i in xrange(n_train_batches)]
    return numpy.mean(train_losses)

上面的代码的作用基本上是遍历训练数据集中的所有小批量,为每个小批量计算平均批量成本(即应用于小批量中每个训练样本的成本函数的平均值)和平均值所有批处理的成本。可能值得指出的是,每个批处理的成本由 batch_cost 计算得出。 -- 一个 theano 函数。

对我来说,似乎callback函数被任意调用,而不是像 SciPy 中的文档声称的那样在每次迭代之后调用。

这是我修改 train_fn 后收到的输出和callback分别添加“train”和“callback”打印。

... training the model
train
train
train
callback
validation error 29.989583 %
train
callback
validation error 24.437500 %
train
callback
validation error 20.760417 %
train
callback
validation error 16.937500 %
train
callback
validation error 14.270833 %
train
callback
validation error 14.156250 %
train
callback
validation error 13.177083 %
train
callback
validation error 12.270833 %
train
train
callback
validation error 11.697917 %
train
callback
validation error 11.531250 %

我的问题是,因为每次调用train_fn确实是一个训练纪元,我如何改变行为,以便调用 callbacktrain_fn 之后调用?

最佳答案

每次调用train_fn不一定是一个训练周期。我不太确定 fmin_cg 是如何实现的,但总的来说,conjugate gradient methods每个最小化步骤可以多次调用成本或梯度函数。 (据我了解)有时需要找到相对于上一步所采取的共轭向量。1

因此,每次 fmin_cg 采取步骤时都会调用您的回调。如果您需要每次调用成本函数或梯度函数时都调用一个函数,则可以将调用放在相关函数中。

1. 编辑:至少当它们是非线性方法时,如 fmin_cg是。维基百科页面表明,普通共轭梯度(CG)方法可能不需要多次调用,但我认为它们不太适合优化非线性函数。我见过的 CG 代码(我猜想一定是非线性 CG)肯定涉及每一步至少一次线搜索。这肯定需要对梯度函数进行多次评估。

关于python - SciPy 共轭梯度优化在每次迭代后不调用回调方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/30173144/

相关文章:

python - 为 CSV 阅读器引用 Python 字典中的键

assembly - x86/x64 asm 中的指令重新排序 - 使用最新 CPU 进行性能优化

c++ - #define 或类型函数?

machine-learning - 如何在不改变线性回归学习率的情况下加快学习速度

c# - 使用ID3算法、Accord.Net框架进行预测

python - sqlalchemy:如何为方言自定义标准类型(例如 DateTime() 参数绑定(bind)处理)?

python - SQLAlchemy:根据表字段查询自定义属性

python - 字符串索引在 python 中的一个代码位置必须是整数

阻塞操作的成本随着线程数量的增加而增加

python - Pandas 中的 loc 函数