python - CNTK python api - 继续分类器训练

标签 python cntk

这实际上不是这个问题... CNTK python api - continue training a model它们是相关的,但并不相同。

我训练了一个模型 1500 个 epoch,平均损失约为 67% 左右。然后我想继续训练,我的编码如下:

def Create_Trainer(train_reader, minibatch_size, epoch_size, checkpoint_path=None, distributed_after=INFINITE_SAMPLES):
#Create Model with Params
lr_per_minibatch = learning_rate_schedule(
    [0.01] * 10 + [0.003] * 10 + [0.001], UnitType.minibatch, epoch_size)
momentum_time_constant = momentum_as_time_constant_schedule(
    -minibatch_size / np.log(0.9))
l2_reg_weight = 0.0001
input_var = input_variable((num_channels, image_height, image_width))
label_var = input_variable((num_classes))
feature_scale = 1.0 / 256.0
input_var_norm = element_times(feature_scale, input_var)
z = create_model(input_var_norm, num_classes)
#Create Error Functions
if(checkpoint_path):
    print('Loaded Checkpoint!')
    z.load_model(checkpoint_path)
ce = cross_entropy_with_softmax(z, label_var)
pe = classification_error(z, label_var)    

#Create Learner    
learner = momentum_sgd(z.parameters,
                        lr=lr_per_minibatch, momentum=momentum_time_constant,
                        l2_regularization_weight=l2_reg_weight)
if(distributed_after != INFINITE_SAMPLES):
    learner = distributed.data_parallel_distributed_learner(
        learner = learner,
        num_quantization_bits = 1,
        distributed_after = distributed_after
    )
input_map = {
    input_var: train_reader.streams.features,
    label_var: train_reader.streams.labels
}
return Trainer(z, ce, pe, learner), input_map

注意代码行:if(checkpoint_path):大约在中间。

我加载之前训练中的 .dnn 文件,该文件是通过此函数保存的...

if current_epoch % checkpoint_frequency == 0:
            trainer.save_checkpoint(os.path.join(checkpoint_path + "_{}.dnn".format(current_epoch)))

这实际上会生成一个 .dnn 和一个 .dnn.ckp 文件。显然我只在load_model中加载.dnn文件。

当我重新开始训练并加载模型时,看起来好像它可能正在加载网络架构,但可能不是权重?这样做的正确方法是什么?

谢谢!

最佳答案

您需要使用 trainer.restore_from_checkpoint 来代替,这应该重新创建训练器和学习器。

很快就会有一个训练类(class),它将允许以简单的方式无缝恢复,照顾训练器/小批量/分布式状态。

一件重要的事情:在 python 脚本中,创建检查点和从检查点恢复时的网络结构和创建节点的顺序必须相同。

关于python - CNTK python api - 继续分类器训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41617326/

相关文章:

python - 在 doctest 中包含原始制表 rune 字字符

python - 将 CNTK virtualenv 添加到 Visual Studio Python 项目

python - 如何在 CNTK Sequential 中添加两层

python - CNTK:定义自定义损失函数(Sørensen-Dice 系数)

python - CNTK python api - 继续训练模型

cntk - 如何验证GPU的使用情况?

python - 员工标签验证错误的 odoo 分配请求?

python - 在python中按元素对数据结构进行排序

具有递归的 Python C API - 段错误

python - 如何在 pandas 数据框的子集中搜索出现值的行