python - 训练元图的权重和偏差

标签 python tensorflow protocol-buffers protoc

我已成功将重新训练的 InceptionV3 NN 导出为 TensorFlow 元图。我已经成功地将这个 protobuf 读回 python,但我正在努力寻找一种方法来导出每一层的权重和偏差值,我假设这些值存储在元图 protobuf 中,以便在 TensorFlow 之外重新创建 nn。

我的工作流程是这样的:

Retrain final layer for new categories
Export meta graph tf.train.export_meta_graph(filename='model.meta')
Build python pb2.py using Protoc and meta_graph.proto
Load Protobuf:

import meta_graph_pb2
saved = meta_graph_pb2.CollectionDef()
with open('model.meta', 'rb') as f:
  saved.ParseFromString(f.read())

从这里我可以查看图表的大部分方面,例如节点名称等,但我认为我的经验不足导致很难找到访问每个相关层的权重和偏置值的正确方法。

最佳答案

MetaGraphDef 原型(prototype)实际上并不包含权重和偏差的值。相反,它提供了一种将 GraphDef 与存储在一个或多个检查点文件中的权重相关联的方法,这些检查点文件由 tf.train.Saver 编写. MetaGraphDef tutorial有更多细节,但大致结构如下:

  1. 在您的训练计划中,使用 tf.train.Saver 写出一个检查点。这也会将 MetaGraphDef 写入同一目录中的 .meta 文件。

    saver = tf.train.Saver(...)
    # ...
    saver.save(sess, "model")
    

    您应该在您的检查点目录中找到名为 model.metamodel-NNNN(对于一些整数 NNNN)的文件。

  2. 在另一个程序中,您可以导入刚刚创建的MetaGraphDef,并从检查点恢复。

    saver = tf.train.import_meta_graph("model.meta")
    saver.restore("model-NNNN")  # Or whatever checkpoint filename was written.
    

    如果要获取每个变量的值,可以(例如)在tf.all_variables() 集合中找到变量并将其传递给sess.run() 来获取它的值。例如,要打印所有变量的值,您可以执行以下操作:

    for var in tf.all_variables():
      print var.name, sess.run(var)
    

    您还可以过滤 tf.all_variables() 以查找您尝试从模型中提取的特定权重和偏差。

关于python - 训练元图的权重和偏差,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39133285/

相关文章:

python - 学习率大于 0.001 会导致错误

python - 将 TFRecord 示例目录集成到模型训练中

python - 如何在 Tensorflow 或 Pytorch 中实现特定位置的卷积滤波器?

makefile - 尝试制作 Mosh 源代码时 Protocol Buffer 版本错误

python - 条形图的 Bokeh 对数刻度

从 numpy 数组中删除元素的 pythonic 方法

c++ - 用 C++ 解析二进制协议(protocol)的好库

python - 如何获取 protobuf 消息中定义的变量类型?

Python调用和控制台输出问题

python - 带否定的 re.sub 语法