tensorflow - 如何计算用从 .pb 文件加载的图形定义的 tensorflow 模型中可训练参数的总数?

标签 tensorflow neural-network convolutional-neural-network

我想计算 tensorflow 模型中的参数。它与现有问题类似,如下所示。

How to count total number of trainable parameters in a tensorflow model?

但如果模型是用从 .pb 文件加载的图形定义的,则所有建议的答案都不起作用。基本上我用以下函数加载了图表。

def load_graph(model_file):

  graph = tf.Graph()
  graph_def = tf.GraphDef()

  with open(model_file, "rb") as f:
    graph_def.ParseFromString(f.read())

  with graph.as_default():
    tf.import_graph_def(graph_def)

  return graph

一个示例是在 tensorflow-for-poets-2 中加载 frozen_graph.pb 文件以进行再训练。

https://github.com/googlecodelabs/tensorflow-for-poets-2

最佳答案

据我了解,GraphDef 没有足够的信息来描述Variables。正如解释的那样 here ,您将需要 MetaGraph,其中包含 GraphDefCollectionDef,后者是可以描述 Variables 的映射。所以下面的代码应该给我们正确的可训练变量计数。

导出元图:

import tensorflow as tf

a = tf.get_variable('a', shape=[1])
b = tf.get_variable('b', shape=[1], trainable=False)
init = tf.global_variables_initializer()
saver = tf.train.Saver([a])

with tf.Session() as sess:
    sess.run(init)
    saver.save(sess, r'.\test')

导入元图并统计可训练参数的总数。

import tensorflow as tf

saver = tf.train.import_meta_graph('test.meta')

with tf.Session() as sess:
    saver.restore(sess, 'test')

total_parameters = 0
for variable in tf.trainable_variables():
    total_parameters += 1
print(total_parameters)

关于tensorflow - 如何计算用从 .pb 文件加载的图形定义的 tensorflow 模型中可训练参数的总数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50159209/

相关文章:

python - 在tensorflow 1.15中构建minpool层

tensorflow - 用神经网络逼近对数函数

c++ - dlib mlp::kernel_1a_c 类中的成员初始化

neural-network - Torchtext 属性错误 : 'Example' object has no attribute 'text_content'

neural-network - 我可以在 Keras 中使用带有卷积神经网络的矩形图像吗?

python - 将自定义阈值与 tf.contrib.learn.DNNClassifier 一起使用?

python - 如何创建 Cifar-10 子集?

tensorflow - 如何通过 Tensorflow conv2d 提供批量图像序列

python - 给定 CNN 的回归激活映射

python - 如何在 Tensorflow 运行时查看或保存一个张量?