tensorflow - 如何设置微调的检查点

标签 tensorflow object-detection-api

我发现从 model_zoo 重新训练模型(ssd_mobilenetv2)时的损失在训练开始时非常大,而 validation_set 的准确性很好。训练记录如下:

日志不能来自经过训练的模型。我怀疑它不会加载检查点来进行微调。请帮助我如何对同一数据集上的训练模型进行微调。我根本没有修改网络结构。

我在 pipeline.config 中设置检查点路径如下: fine_tune_checkpoint:"//ssd_mobilenet_v2_coco_2018_03_29/model.ckpt" 如果我将 model_dir 设置为我的下载目录,它不会训练,因为 global_train_step 大于 max_step。然后放大max_step,可以看到从checkpoint恢复参数的日志。但它会遇到无法恢复某些参数的错误。 所以我将 model_dir 设置为一个空目录。可以正常训练,但是step0的loss会很大。而且验证结果很差

在 pipeline.config 中

fine_tune_checkpoint: "/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt"
num_steps: 200000
fine_tune_checkpoint_type: "detection"

训练脚本

model_dir = '/ssd_mobilenet_v2_coco_2018_03_29/retrain0524

pipeline_config_path = '/ssd_mobilenet_v2_coco_2018_03_29/pipeline.config'

checkpoint_dir = '/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt'

num_train_steps = 300000
config = tf.estimator.RunConfig(model_dir=model_dir)
train_and_eval_dict = model_lib.create_estimator_and_inputs(
    run_config=config,
    hparams=model_hparams.create_hparams(hparams_overrides),
    pipeline_config_path=pipeline_config_path,    
    sample_1_of_n_eval_examples=sample_1_of_n_eval_examples,
    sample_1_of_n_eval_on_train_examples=(sample_1_of_n_eval_on_train_examples))
estimator = train_and_eval_dict['estimator']
train_input_fn = train_and_eval_dict['train_input_fn']
eval_input_fns = train_and_eval_dict['eval_input_fns']
eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
predict_input_fn = train_and_eval_dict['predict_input_fn']
train_steps = train_and_eval_dict['train_steps']

train_spec, eval_specs = model_lib.create_train_and_eval_specs(
        train_input_fn,
        eval_input_fns,
        eval_on_train_input_fn,
        predict_input_fn,
        train_steps,
        eval_on_train_data=False)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])

信息: tensorflow :损失 = 356.25497,步长 = 0 信息:tensorflow:global_step/秒:1.89768 信息:tensorflow:损失 = 11.221423,步长 = 100(52.700 秒) 信息:tensorflow:global_step/秒:2.21685 信息:tensorflow:损失 = 10.329516,步长 = 200(45.109 秒)

最佳答案

如果初始训练损失为 400,则模型很可能从检查点成功恢复,只是与检查点不完全相同。

Heressd模型的restore_map函数,注意即使你设置了fine_tune_checkpoint_type : detection,甚至提供了完全相同模型的checkpoint,仍然只有feature_extractor 范围内的变量被恢复。要从检查点恢复尽可能多的变量,您必须在配置文件中设置 load_all_detection_checkpoint_vars: true

def restore_map(self,
              fine_tune_checkpoint_type='detection',
              load_all_detection_checkpoint_vars=False):

if fine_tune_checkpoint_type not in ['detection', 'classification']:
  raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format(
      fine_tune_checkpoint_type))

if fine_tune_checkpoint_type == 'classification':
  return self._feature_extractor.restore_from_classification_checkpoint_fn(
      self._extract_features_scope)

if fine_tune_checkpoint_type == 'detection':
  variables_to_restore = {}
  for variable in tf.global_variables():
    var_name = variable.op.name
    if load_all_detection_checkpoint_vars:
      variables_to_restore[var_name] = variable
    else:
      if var_name.startswith(self._extract_features_scope):
        variables_to_restore[var_name] = variable

return variables_to_restore

关于tensorflow - 如何设置微调的检查点,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56290419/

相关文章:

python - 编译 new_op 教程时出错(Tensorflow)

Tensorflow 高误报率和非最大抑制问题

amazon-web-services - Tensorflow 对象检测 API 的推理时间较慢

python-3.x - Tensorflow 对象检测 API - 验证丢失行为

python - 我们可以使用一维卷积进行图像分类吗?

tensorflow - TensorFlow 的 tf.nn.dynamic_rnn 运算符的输入张量是如何构造的?

python - 预期的二进制或 unicode 字符串,得到 nan - tensorflow/pandas

python-3.x - 如何从保存的模型生成 tflite?

tensorflow - tensorflow对象检测API中的超参数优化

python - 无效参数错误 : cannot compute MatMul as input #0(zero-based) was expected to be a float tensor but is a double tensor [Op:MatMul]