arm - 从 TensorFlow 图中清除 dropout 操作

标签 arm tensorflow

我有一个经过训练的卡住图,我试图在 ARM 设备上运行它。基本上,我使用的是 contrib/pi_examples/label_image,但使用的是我的网络而不是 Inception。我的网络是用 dropout 训练的,现在给我带来了麻烦:

Invalid argument: No OpKernel was registered to support Op 'Switch' with these attrs.  Registered kernels:
  device='CPU'; T in [DT_FLOAT]
  device='CPU'; T in [DT_INT32]
  device='GPU'; T in [DT_STRING]
  device='GPU'; T in [DT_BOOL]
  device='GPU'; T in [DT_INT32]
  device='GPU'; T in [DT_FLOAT]

 [[Node: l_fc1_dropout/cond/Switch = Switch[T=DT_BOOL](is_training_pl, is_training_pl)]]

我能看到的一种解决方案是构建这样一个包含相应操作的 TF 静态库。另一方面,从网络中消除 dropout ops 以使其更简单和更快可能是一个更好的主意。有没有办法做到这一点?

谢谢。

最佳答案

#!/usr/bin/env python2

import argparse

import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2

def print_graph(input_graph):
    for node in input_graph.node:
        print "{0} : {1} ( {2} )".format(node.name, node.op, node.input)

def strip(input_graph, drop_scope, input_before, output_after, pl_name):
    input_nodes = input_graph.node
    nodes_after_strip = []
    for node in input_nodes:
        print "{0} : {1} ( {2} )".format(node.name, node.op, node.input)

        if node.name.startswith(drop_scope + '/'):
            continue

        if node.name == pl_name:
            continue

        new_node = node_def_pb2.NodeDef()
        new_node.CopyFrom(node)
        if new_node.name == output_after:
            new_input = []
            for node_name in new_node.input:
                if node_name == drop_scope + '/cond/Merge':
                    new_input.append(input_before)
                else:
                    new_input.append(node_name)
            del new_node.input[:]
            new_node.input.extend(new_input)
        nodes_after_strip.append(new_node)

    output_graph = graph_pb2.GraphDef()
    output_graph.node.extend(nodes_after_strip)
    return output_graph

def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--input-graph', action='store', dest='input_graph')
    parser.add_argument('--input-binary', action='store_true', default=True, dest='input_binary')
    parser.add_argument('--output-graph', action='store', dest='output_graph')
    parser.add_argument('--output-binary', action='store_true', dest='output_binary', default=True)

    args = parser.parse_args()

    input_graph = args.input_graph
    input_binary = args.input_binary
    output_graph = args.output_graph
    output_binary = args.output_binary

    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with tf.gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read().decode("utf-8"), input_graph_def)

    print "Before:"
    print_graph(input_graph_def)
    output_graph_def = strip(input_graph_def, u'l_fc1_dropout', u'l_fc1/Relu', u'prediction/MatMul', u'is_training_pl')
    print "After:"
    print_graph(output_graph_def)

    if output_binary:
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
    else:
        with tf.gfile.GFile(output_graph, "w") as f:
            f.write(text_format.MessageToString(output_graph_def))
    print("%d ops in the final graph." % len(output_graph_def.node))


if __name__ == "__main__":
    main()

关于arm - 从 TensorFlow 图中清除 dropout 操作,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40358892/

相关文章:

assembly - FRDM-KL25z 组装延迟循环导致复位

c++ - 嵌入式编程开发板推荐

linux - 任何可用于 uclibc 的 backtrace 的移植?

c - ARM GCC 生成的函数序言

c++ - Tensorflow C++ 占位符初始化

machine-learning - keras 模型 fit_generator ValueError : Error when checking model target: expected cropping2d_4 to have 4 dimensions, 但得到形状为 (32, 1) 的数组

math - 平移等方差及其与卷积层和空间池化层的关系

linux - 访问外部设备时 ARM Cortex 上的 SIGBUS

machine-learning - TensorFlow 中动态改变权重

python - Tensorflow 中的生成序列