我需要使用自定义输入和输出层扩展导出的模型。我发现这可以轻松完成:
with tf.Graph().as_default() as g1: # actual model
in1 = tf.placeholder(tf.float32,name="input")
ou1 = tf.add(in1,2.0,name="output")
with tf.Graph().as_default() as g2: # model for the new output layer
in2 = tf.placeholder(tf.float32,name="input")
ou2 = tf.add(in2,2.0,name="output")
gdef_1 = g1.as_graph_def()
gdef_2 = g2.as_graph_def()
with tf.Graph().as_default() as g_combined: #merge together
x = tf.placeholder(tf.float32, name="actual_input") # the new input layer
# Import gdef_1, which performs f(x).
# "input:0" and "output:0" are the names of tensors in gdef_1.
y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
return_elements=["output:0"])
# Import gdef_2, which performs g(y)
z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
return_elements=["output:0"])
sess = tf.Session(graph=g_combined)
print "result is: ", sess.run(z, {"actual_input:0":5}) #result is: 9
这很好用。
然而,我需要提供一个指针作为网络输入,而不是传递任意形状的数据集。问题是,我在 python 中想不出任何解决方案(定义和传递指针),并且在使用 C++ Api
开发网络时,我找不到与 tf.import_graph_def
函数。
这在 C++ 中是否有不同的名称,或者是否有其他方法可以在 C++ 中合并两个图/模型?
谢谢你的建议
最佳答案
它不像在 Python 中那么容易。
你可以用这样的东西加载一个GraphDef
:
#include <string>
#include <tensorflow/core/framework/graph.pb.h>
#include <tensorflow/core/platform/env.h>
tensorflow::GraphDef graph;
std::string graphFileName = "...";
auto status = tensorflow::ReadBinaryProto(
tensorflow::Env::Default(), graphFileName, &graph);
if (!status.ok()) { /* Error... */ }
然后你可以用它来创建 session :
#include <tensorflow/core/public/session.h>
tensorflow::Session *newSession;
auto status = tensorflow::NewSession(tensorflow::SessionOptions(), &newSession);
if (!status.ok()) { /* Error... */ }
status = session->Create(graph);
if (!status.ok()) { /* Error... */ }
或者扩展现有图表:
status = session->Extend(graph);
if (!status.ok()) { /* Error... */ }
通过这种方式,您可以将多个 GraphDef
放入同一个图中。但是,没有额外的工具来提取特定节点,也没有避免名称冲突 - 您必须自己找到节点,并且必须确保 GraphDef
没有冲突的操作名称。例如,我使用此函数查找名称与给定正则表达式匹配的所有节点,并按名称排序:
#include <vector>
#include <regex>
#include <tensorflow/core/framework/node_def.pb.h>
std::vector<const tensorflow::NodeDef *> GetNodes(const tensorflow::GraphDef &graph, const std::regex ®ex)
{
std::vector<const tensorflow::NodeDef *> nodes;
for (const auto &node : graph.node())
{
if (std::regex_match(node.name(), regex))
{
nodes.push_back(&node);
}
}
std::sort(nodes.begin(), nodes.end(),
[](const tensorflow::NodeDef *lhs, const tensorflow::NodeDef *rhs)
{
return lhs->name() < rhs->name();
});
return nodes;
}
关于python - 组合图 : is there a TensorFlow import_graph_def equivalent for C++?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49490262/