python - TensorFlow:有没有办法测量模型的 FLOPS?

标签 python tensorflow

我能得到的最接近的例子是在这个问题中找到的:https://github.com/tensorflow/tensorflow/issues/899

使用这个最小的可重现代码:

import tensorflow as tf
import tensorflow.python.framework.ops as ops 
g = tf.Graph()
with g.as_default():
  A = tf.Variable(tf.random_normal( [25,16] ))
  B = tf.Variable(tf.random_normal( [16,9] ))
  C = tf.matmul(A,B) # shape=[25,9]
for op in g.get_operations():
  flops = ops.get_stats_for_node_def(g, op.node_def, 'flops').value
  if flops is not None:
    print 'Flops should be ~',2*25*16*9
    print '25 x 25 x 9 would be',2*25*25*9 # ignores internal dim, repeats first
    print 'TF stats gives',flops

但是,返回的 FLOPS 始终为 None。有没有办法具体测量 FLOPS,尤其是 PB 文件?

最佳答案

我想以 Tobias Schnek 的回答为基础并回答原始问题:如何从 pb 文件中获取 FLOP。

使用 TensorFlow 1.6.0 运行 Tobias answer 的第一段代码

g = tf.Graph()
run_meta = tf.RunMetadata()
with g.as_default():
    A = tf.Variable(tf.random_normal([25,16]))
    B = tf.Variable(tf.random_normal([16,9]))
    C = tf.matmul(A,B)

    opts = tf.profiler.ProfileOptionBuilder.float_operation()    
    flops = tf.profiler.profile(g, run_meta=run_meta, cmd='op', options=opts)
    if flops is not None:
        print('Flops should be ~',2*25*16*9)
        print('TF stats gives',flops.total_float_ops)

我们得到以下输出:

Flops should be ~ 7200
TF stats gives 8288

那么,为什么我们得到 8288 而不是预期 结果 7200=2*25*16*9[a ]?答案在于张量 AB 的初始化方式。使用高斯分布初始化会花费一些 FLOP。更改 AB 的定义

    A = tf.Variable(initial_value=tf.zeros([25, 16]))
    B = tf.Variable(initial_value=tf.zeros([16, 9]))

给出预期的输出 7200

通常,网络的变量在其他方案中使用高斯分布进行初始化。大多数时候,我们对初始化 FLOP 不感兴趣,因为它们在初始化期间完成一次,并且不会在训练或推理期间发生。那么,如何在不考虑初始化 FLOP 的情况下获得 FLOP 的确切数量

pb 卡住图表。从 pb 文件计算 FLOP 实际上是 OP 的用例。

以下代码段说明了这一点:

import tensorflow as tf
from tensorflow.python.framework import graph_util

def load_pb(pb):
    with tf.gfile.GFile(pb, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph

# ***** (1) Create Graph *****
g = tf.Graph()
sess = tf.Session(graph=g)
with g.as_default():
    A = tf.Variable(initial_value=tf.random_normal([25, 16]))
    B = tf.Variable(initial_value=tf.random_normal([16, 9]))
    C = tf.matmul(A, B, name='output')
    sess.run(tf.global_variables_initializer())
    flops = tf.profiler.profile(g, options = tf.profiler.ProfileOptionBuilder.float_operation())
    print('FLOP before freezing', flops.total_float_ops)
# *****************************        

# ***** (2) freeze graph *****
output_graph_def = graph_util.convert_variables_to_constants(sess, g.as_graph_def(), ['output'])

with tf.gfile.GFile('graph.pb', "wb") as f:
    f.write(output_graph_def.SerializeToString())
# *****************************


# ***** (3) Load frozen graph *****
g2 = load_pb('./graph.pb')
with g2.as_default():
    flops = tf.profiler.profile(g2, options = tf.profiler.ProfileOptionBuilder.float_operation())
    print('FLOP after freezing', flops.total_float_ops)

输出

FLOP before freezing 8288
FLOP after freezing 7200

[a] 对于乘积 AB,矩阵乘法的 FLOP 通常为 mq(2p -1),其中 A[m, p]B [p, q] 但 TensorFlow 出于某种原因返回 2mpq。一个issue已打开了解原因。

关于python - TensorFlow:有没有办法测量模型的 FLOPS?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45085938/

相关文章:

tensorflow - 如何在 Tensorflow/Keras 中的 2 层之间创建循环连接?

python - Django Rest Framework 外键嵌套

osx 上的 python/pip 错误

python mysqldb游标消息列表两次显示错误

python - Plotly:如何自定义轴变换?

python - 在 Tensorflow 图中查找所需的占位符

python - Python 列表(元组)中每个元素有多少字节?

python - 如何从图像中分离出人脸?

opencv - 如何调整cvtColor灰度功能,使其保持RGB尺寸

python - 为什么我的线性回归得到的是 nan 值而不是学习?