我正在尝试在我自己的数据集(这是我仅从中提取一类对象的 ADE20k 的子集)上训练一个 deeplab 模型。我想使用 mobilenet 作为主干并从预训练模型开始训练。因此,我从这里下载了预训练的权重:https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet (mobilenet_v2_1.4_224)。然后我修改了 data_segmentation.py 以包含我的数据集:
_ADE20K_DOORS_INFORMATION = DatasetDescriptor(
splits_to_sizes={
'train': 3530,
'val': 353,
},
num_classes=2,
ignore_label = 255,
)
_DATASETS_INFORMATION = {
'cityscapes': _CITYSCAPES_INFORMATION,
'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
'ade20k': _ADE20K_INFORMATION,
'ade20k_doors': _ADE20K_DOORS_INFORMATION,
}
我修改 train.py 文件(更改标志的值)如下:
flags.DEFINE_boolean('initialize_last_layer', False,
'Initialize the last layer.')
flags.DEFINE_boolean('last_layers_contain_logits_only', True,
'Only consider logits as last layers or not.')
flags.DEFINE_boolean('fine_tune_batch_norm', False,
'Fine tune the batch norm parameters or not.')
我修改了 train_utils.py 文件,以便从要恢复的变量列表中排除 logits:
from deeplab.model import LOGITS_SCOPE_NAME
exclude_list = ['global_step', LOGITS_SCOPE_NAME, 'logits']
现在,当我尝试训练时,出现以下错误:
InvalidArgumentError (see above for traceback): Restoring from checkpoint
failed. This is most likely due to a mismatch between the current graph and
the graph from the checkpoint. Please ensure that you have not altered the
graph expected based on the checkpoint. Original error:
Assign requires shapes of both tensors to match. lhs shape= [576] rhs shape=
[816]
[[Node: save/Assign_50 = Assign[T=DT_FLOAT, _class=
["loc:@MobilenetV2/expanded_conv_11/expand/BatchNorm/beta"],
use_locking=true, validate_shape=true,
_device="/job:localhost/replica:0/task:0/device:CPU:0"]
(MobilenetV2/expanded_conv_11/expand/BatchNorm/beta, save/RestoreV2:50)]]
显然,预训练的检查点和我的模型之间存在不匹配。我错过了什么?你能帮我一下吗?非常感谢任何帮助。
为了训练,我使用以下命令:
python deeplab/train.py --logtostderr --training_number_of_steps=30000 --
train_split="train" --model_variant="mobilenet_v2" --output_stride=16 --
decoder_output_stride=4 --train_crop_size=513 --train_crop_size=513 --
train_batch_size=1 --dataset="ade20k_doors" --
tf_initial_checkpoint=deeplab/mobilenet/mobilenet_v2_1.4_224.ckpt --
train_logdir=deeplab/datasets/ADE20K/exp/train_on_train_set/train --
dataset_dir=deeplab/datasets/ADE20K/tfrecord
最佳答案
我通过更改预训练的权重摆脱了错误。它适用于这个模型: mobilenetv2_coco_voc_trainval
关于python - deeplab在自己的数据集上训练时从检查点恢复失败,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52094498/