python - 如何从给定模型中获取 Graph(或 GraphDef)?

标签 python c++ tensorflow keras model

我有一个使用 Tensorflow 2 和 Keras 定义的大型模型。 该模型在 Python 中运行良好。现在,我想将它导入到 C++ 项目中。

在我的 C++ 项目中,我使用 TF_GraphImportGraphDef 函数。 如果我使用以下代码准备 *.pb 文件,效果会很好:

    with open('load_model.pb', 'wb') as f:
        f.write(tf.compat.v1.get_default_graph().as_graph_def().SerializeToString())

我已经在使用 Tensorflow 1(使用 tf.compat.v1.* 函数)编写的简单网络上尝试了这段代码。它运行良好。

现在我想将我的大模型(开头提到的,使用Tensorflow 2编写)导出到C++项目中。为此,我需要从我的模型中获取一个 GraphGraphDef 对象。问题是:如何做到这一点?我没有找到任何属性或函数来获取它。

我还尝试使用 tf.saved_model.save(model, 'model') 来保存整个模型。它生成一个包含不同文件的目录,包括 saved_model.pb 文件。不幸的是,当我尝试使用 TF_GraphImportGraphDef 函数在 C++ 中加载此文件时,程序抛出异常。

最佳答案

tf.saved_model.save 生成的 Protocol Buffer 文件不包含 GraphDef消息,而是一个 SavedModel .你可以 traverse that SavedModel in Python获取其中的嵌入图,但这不会立即用作卡住图,因此正确处理可能很困难。取而代之的是,C++ API 现在包含一个 LoadSavedModel。允许您从目录加载整个已保存模型的调用。它应该看起来像这样:

#include <iostream>
#include <...>  // Add necessary TF include directives

using namespace std;
using namespace tensorflow;

int main()
{
    // Path to saved model directory
    const string export_dir = "...";
    // Load model
    Status s;
    SavedModelBundle bundle;
    SessionOptions session_options;
    RunOptions run_options;
    s = LoadSavedModel(session_options, run_options, export_dir,
                       // default "serve" tag set by tf.saved_model.save
                       {"serve"}, &bundle));
    if (!.ok())
    {
        cerr << "Could not load model: " << s.error_message() << endl;
        return -1;
    }
    // Model is loaded
    // ...
    return 0;
}

从这里开始,您可以做不同的事情。也许您最愿意使用 FreezeSavedModel 将保存的模型转换为卡住图。 ,这应该让您可以像以前一样做事:

GraphDef frozen_graph_def;
std::unordered_set<string> inputs;
std::unordered_set<string> outputs;
s = FreezeSavedModel(bundle, &frozen_graph_def,
                     &inputs, &outputs));
if (!s.ok())
{
    cerr << "Could not freeze model: " << s.error_message() << endl;
    return -1;
}

否则,您可以直接使用保存的模型对象:

// Default "serving_default" signature name set by tf.saved_model_save
const SignatureDef& signature_def = bundle.GetSignatures().at("serving_default");
// Get input and output names (different from layer names)
// Key is input and output layer names
const string input_name = signature_def.inputs().at("my_input").name();
const string output_name = signature_def.inputs().at("my_output").name();
// Run model
Tensor input = ...;
std::vector<Tensor> outputs;
s = bundle.session->Run({{input_name, input}}, {output_name}, {}, &outputs));
if (!s.ok())
{
    cerr << "Error running model: " << s.error_message() << endl;
    return -1;
}
// Get result
Tensor& output = outputs[0];

关于python - 如何从给定模型中获取 Graph(或 GraphDef)?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63181951/

相关文章:

android - 由于已弃用的 Jack 工具链,构建 TensorFlow Android 演示应用程序的问题

python - 在 numpy 数组中查找负值和正值

python - IPython 笔记本 : using quotes (') and dollar($) to pass string variable as argument to python script in command-line(!) | eg: abc.py -arg1 ' $var1'

python - 无法从项目目录导入 flask ,但可以在其他任何地方使用

python - 如何从修改后的字符串列表创建 gtk.STOCK_* 按钮?

c++ - OpenGL 纹理不通过 GLSL 渲染

c++ - 优化分配的模板技巧

tensorflow - 模型似乎过度拟合 Optimizer.minimize() 但不是 tf.contrib.layers.optimize_loss()

c++ - 有人在使用命名的 boolean 运算符吗?

tensorflow - 让 keras LSTM 层接受两个输入?