python - 组合图 : is there a TensorFlow import_graph_def equivalent for C++?

标签 python c++ pointers tensorflow merge

我需要使用自定义输入和输出层扩展导出的模型。我发现这可以轻松完成:

with tf.Graph().as_default() as g1: # actual model
    in1 = tf.placeholder(tf.float32,name="input")
    ou1 = tf.add(in1,2.0,name="output")
with tf.Graph().as_default() as g2: # model for the new output layer
    in2 = tf.placeholder(tf.float32,name="input")
    ou2 = tf.add(in2,2.0,name="output")

gdef_1 = g1.as_graph_def()
gdef_2 = g2.as_graph_def()

with tf.Graph().as_default() as g_combined: #merge together
    x = tf.placeholder(tf.float32, name="actual_input") # the new input layer

    # Import gdef_1, which performs f(x).
    # "input:0" and "output:0" are the names of tensors in gdef_1.
    y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
                             return_elements=["output:0"])

    # Import gdef_2, which performs g(y)
    z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
                             return_elements=["output:0"])

sess = tf.Session(graph=g_combined)

print "result is: ", sess.run(z, {"actual_input:0":5}) #result is: 9

这很好用。

然而,我需要提供一个指针作为网络输入,而不是传递任意形状的数据集。问题是,我在 python 中想不出任何解决方案(定义和传递指针),并且在使用 C++ Api 开发网络时,我找不到与 tf.import_graph_def 函数。

这在 C++ 中是否有不同的名称,或者是否有其他方法可以在 C++ 中合并两个图/模型?

谢谢你的建议

最佳答案

它不像在 Python 中那么容易。

你可以用这样的东西加载一个GraphDef:

#include <string>
#include <tensorflow/core/framework/graph.pb.h>
#include <tensorflow/core/platform/env.h>

tensorflow::GraphDef graph;
std::string graphFileName = "...";
auto status = tensorflow::ReadBinaryProto(
    tensorflow::Env::Default(), graphFileName, &graph);
if (!status.ok()) { /* Error... */ }

然后你可以用它来创建 session :

#include <tensorflow/core/public/session.h>

tensorflow::Session *newSession;
auto status = tensorflow::NewSession(tensorflow::SessionOptions(), &newSession);
if (!status.ok()) { /* Error... */ }
status = session->Create(graph);
if (!status.ok()) { /* Error... */ }

或者扩展现有图表:

status = session->Extend(graph);
if (!status.ok()) { /* Error... */ }

通过这种方式,您可以将多个 GraphDef 放入同一个图中。但是,没有额外的工具来提取特定节点,也没有避免名称冲突 - 您必须自己找到节点,并且必须确保 GraphDef 没有冲突的操作名称。例如,我使用此函数查找名称与给定正则表达式匹配的所有节点,并按名称排序:

#include <vector>
#include <regex>
#include <tensorflow/core/framework/node_def.pb.h>

std::vector<const tensorflow::NodeDef *> GetNodes(const tensorflow::GraphDef &graph, const std::regex &regex)
{
    std::vector<const tensorflow::NodeDef *> nodes;
    for (const auto &node : graph.node())
    {
        if (std::regex_match(node.name(), regex))
        {
            nodes.push_back(&node);
        }
    }
    std::sort(nodes.begin(), nodes.end(),
              [](const tensorflow::NodeDef *lhs, const tensorflow::NodeDef *rhs)
              {
                  return lhs->name() < rhs->name();
              });
    return nodes;
}

关于python - 组合图 : is there a TensorFlow import_graph_def equivalent for C++?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49490262/

相关文章:

python - 如何使用 ctypes 将列表列表的 Python 列表转换为 C 数组?

c++ - 将 vector<int> 插入 vector<vector<int>> 时出现堆溢出错误

c++ - 如何制作一个深const指针

C++双向链表 "delete tail"函数

c++ - C/C++ : Pointer Arithmetic

c - 创建二叉树的指针问题

Python + MySQLdb executemany

javascript - 在 Django 中处理没有表单的 Ajax 请求

python - 最后尝试 : Get current returning value

c++ - 我的程序在 Windows 机器上崩溃但在 Linux 上运行正常