python - Tensorflow:从图中删除节点

标签 python tensorflow

我正在尝试从图中删除一些节点并将其保存在 .pb 中

只有需要的节点可以添加到新的 mod_graph_def图,但图在其他节点输入中仍有一些对已删除节点的引用,但我无法修改节点的输入:

def delete_ops_from_graph():
    with open(input_model_filepath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    nodes = []
    for node in graph_def.node:
        if 'Neg' in node.name:
            print('Drop', node.name)
        else:
            nodes.append(node)

    mod_graph_def = tf.GraphDef()
    mod_graph_def.node.extend(nodes)

    # The problem that graph still have some references to deleted node in other nodes inputs
    for node in mod_graph_def.node:
        inp_names = []
        for inp in node.input:
            if 'Neg' in inp:
                pass
            else:
                inp_names.append(inp)

        node.input = inp_names # TypeError: Can't set composite field

    with open(output_model_filepath, 'wb') as f:
        f.write(mod_graph_def.SerializeToString())

最佳答案

def delete_ops_from_graph():
    with open(input_model_filepath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Delete nodes
    nodes = []
    for node in graph_def.node:
        if 'Neg' in node.name:
            print('Drop', node.name)
        else:
            nodes.append(node)

    mod_graph_def = tf.GraphDef()
    mod_graph_def.node.extend(nodes)

    # Delete references to deleted nodes
    for node in mod_graph_def.node:
        inp_names = []
        for inp in node.input:
            if 'Neg' in inp:
                pass
            else:
                inp_names.append(inp)

        del node.input[:]
        node.input.extend(inp_names)

    with open(output_model_filepath, 'wb') as f:
        f.write(mod_graph_def.SerializeToString())

关于python - Tensorflow:从图中删除节点,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56324534/

相关文章:

python - 线程无限期挂起

Python 数据管道和流式处理

python - 黑窗pygame

tensorflow - 如何在tensorflow keras中访问自定义层的递归层

python - 属性错误 : 'Dimension' object has no attribute 'log10' while using Keras Sequential Model. 适合

android - 使用子进程执行ADB命令

python - Plotly:如何为所有子图设置 xticks?

python - 运行pytorch时如何让cuda加载?

python - 拟合神经网络时出现 UnboundLocalError - TensorFlow 错误?

python - Python 2.7 中的奇怪异常.SystemExit