python - TF 对象检测 Zoo 模型没有可训练变量?

标签 python tensorflow object-detection object-detection-api pre-trained-model

TF Objection Detection Zoo 中的模型有 meta+ckpt 文件、Frozen.pb 文件和 Saved_model 文件。

我尝试使用 meta+ckpt 文件进一步训练,并为特定张量提取一些权重以用于研究目的。我看到模型没有任何可训练的变量。

vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print(vars)

上面的代码片段给出了一个 [] 列表。我也尝试使用以下内容。

vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
print(vars)

我再次得到一个 [] 列表。

这怎么可能?模型是否剥离了变量?还是 tf.Variable(trainable=False) ?我在哪里可以获得具有有效可训练变量的 meta+ckpt 文件。我专门看SSD+mobilnet机型

更新:

以下是我用于恢复的代码片段。它在一个类中,因为我正在为某些应用程序制作自定义工具。

def _importer(self):
    sess = tf.InteractiveSession()
    with sess.as_default():
        reader = tf.train.import_meta_graph(self.metafile,
                                            clear_devices=True)
        reader.restore(sess, self.ckptfile)

def _read_graph(self):
    sess = tf.get_default_session()
    with sess.as_default():
        vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        print(vars)

更新 2:

我还尝试了以下代码段。简约复古风格。

model_dir = 'ssd_mobilenet_v2/'

meta = glob.glob(model_dir+"*.meta")[0]
ckpt = meta.replace('.meta','').strip()

sess = tf.InteractiveSession()
graph = tf.Graph()
with graph.as_default():
    with tf.Session() as sess:
        reader = tf.train.import_meta_graph(meta,clear_devices=True)
        reader.restore(sess,ckpt)

        vari = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        for var in vari:
            print(var.name,"\n")

上面的代码片段还给出了[]变量列表

最佳答案

经过一些研究,您问题的最终答案是不,他们没有。这很明显,直到您意识到 saved_model 中的 variables 目录是空的。

对象检测模型zoo提供的checkpoint文件包含以下文件:

.
|-- checkpoint
|-- frozen_inference_graph.pb
|-- model.ckpt.data-00000-of-00001
|-- model.ckpt.index
|-- model.ckpt.meta
|-- pipeline.config
`-- saved_model
    |-- saved_model.pb
    `-- variables

pipeline.config 是保存模型的配置文件,frozen_inference_graph.pb 是现成的推理。请注意 checkpointmodel.ckpt.data-00000-of-00001model.ckpt.metamodel.ckpt。 index 都对应checkpoint。 (Here 你可以找到一个很好的解释)

所以当你想得到可训练的变量时,唯一有用的就是saved_model目录。

Use SavedModel to save and load your model—variables, the graph, and the graph's metadata. This is a language-neutral, recoverable, hermetic serialization format that enables higher-level systems and tools to produce, consume, and transform TensorFlow models.

要恢复 SavedModel,您可以使用 API tf.saved_model.loader.load() , 这个 api 包含一个名为 tags 的参数,它指定了 MetaGraphDef 的类型。所以如果你想得到可训练的变量,你需要在调用api时指定tag_constants.TRAINING

我试图调用此 api 来恢复变量,但它给了我一个错误

MetaGraphDef associated with tags 'train' could not be found in SavedModel. To inspect available tag-sets in the SavedModel, please use the SavedModel CLI: saved_model_cli

所以我执行了这个 saved_model_cli 命令来检查 SavedModel 中可用的所有标签。

#from directory saved_model
saved_model_cli show --dir . --all

输出是

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
...
signature_def['serving_default']:
  ...

所以在这个SavedModel中没有标签train,只有serve。因此,此处的 SavedModel 仅用于 tensorflow 服务。这意味着当这些文件在创建时未使用标记 training 指定时,无法从这些文件中恢复训练变量。

P.S.:以下代码是我用来恢复 SavedModel 的代码。设置tag_constants.TRAINING时加载无法完成,设置tag_constants.SERVING时加载成功但变量为空。

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
  tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], export_dir)
  variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
  print(variables)

P.P.S:我找到了创建 SavedModel 的脚本 here .可见创建SavedModel时确实没有train标签。

关于python - TF 对象检测 Zoo 模型没有可训练变量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56547313/

相关文章:

Python:为什么我的 'if' 语句的数字计算不正确?

c++ - 在 iOS 上实现 TensorFlow Attention OCR

图像处理如何检测图像中的特定自定义形状

python - 使用密码解压缩 zip 文件失败 - Python 中的错误?

python - 为什么 value_counts 不显示所有存在的值?

python - Google 电子表格 api 不使用 OAuth2?

audio - 将实时音频数据馈送到移动设备上的 tensorflow

tensorflow - 使用 load_model 时,keras 内核初始化程序被错误调用

python - 使用opencv查找图像中的所有轮廓

python - 有效地从二维似然矩阵(numpy 数组)中提取局部最大值(坐标)