c++ - 未找到 : FeedInputs: unable to find feed output TensorFlow

标签 c++ model tensorflow

我在本网站尝试使用 C++ 中的 Tensorflow 保存模型的示例: https://medium.com/jim-fleming/loading-a-tensorflow-graph-with-the-c-api-4caaff88463f#.ji310n4zo

效果很好。但它不保存变量 ab 的值,因为它只保存图形而不保存变量.我试图替换以下行:

tf.train.write_graph(sess.graph_def, 'models/', 'graph.pb', as_text=False)

saver.save(sess, 'models/graph', global_step=0)

当然是在创建保护程序对象之后。它不起作用并输出:

未找到:FeedInputs:无法找到 feed 输出 a

我检查了加载的节点,它们只是:

_SOURCE

_SINK

在 write_graph 函数中然后用 C++ 加载模型时,我加载了以下节点:

_SOURCE

_SINK

save/restore_slice_1/shape_and_slice

save/restore_slice_1/tensor_name

save/restore_slice/shape_and_slice

save/restore_slice/tensor_name

save/save/shapes_and_slices

save/save/tensor_names

save/Const

save/restore_slice_1

save/restore_slice

b

save/Assign_1

b/read

b/initial_value

b/Assign

a

save/Assign

save/restore_all

save/save

save/control_dependency

a/read

c

a/initial_value

a/Assign

init

Tensor

与 write_graph 创建的 1.9KB 相比,saver.save() 创建的图形文件也小得多,为 165B。

最佳答案

我不确定这是否是解决问题的最佳方法,但至少它解决了问题。

因为 write_graph 也可以存储常量的值,所以我在用 write_graph 函数写图之前在 python 中添加了以下代码:

for v in tf.trainable_variables():
    vc = tf.constant(v.eval())
    tf.assign(v, vc, name="assign_variables")

这会创建常量,在训练后存储变量的值,然后创建张量“assign_variables”以将它们分配给变量。现在,当您调用 write_graph 时,它会将变量的值存储在文件中。

唯一剩下的部分是在 C 代码中调用这些张量“assign_variables”,以确保您的变量被分配了存储在文件。这是一种方法:

      Status status = NewSession(SessionOptions(), &session);
      std::vector<tensorflow::Tensor> outputs;
      for(int i = 0;status.ok(); i++) {
        char name[100];
        if (i==0)
            sprintf(name, "assign_variables");
        else
            sprintf(name, "assign_variables_%d", i);

        status = session->Run({}, {name}, {}, &outputs);
      }

关于c++ - 未找到 : FeedInputs: unable to find feed output TensorFlow,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34796404/

相关文章:

c++ - 桥接 USB 主机到设备

c++ - 如何从 Swift 中的整数中获取特定位?

ubuntu - 在 ubuntu 中导入 tensorflow 时出错

python - 无法在 python 中导入 tensorfflow

c++ - iOS预编译头文件包含<boost/shared_ptr.hpp>时出错

c++ - 在 Boost Test 框架中测试 assert

python - 在 Google App Engine 中分离模型和请求处理程序

model - 如何在 Sencha 触摸模型中创建构造函数?

python - Django 查找在多个 OR 查询中匹配的字段

python - 如何计算给定输入和预期输出的 ctc 概率?