tensorflow - 使用 tensorflow c++ api 运行 session 明显比使用 python 慢

标签 tensorflow

我正在尝试使用 tensorflow c++ api(仅限 CPU)运行 SqueezeDet。我已经卡住了 tensorflow 图并从 C++ 加载它。虽然在检测质量方面一切都很好,但性能比 python 慢得多。那可能是什么原因?

简化后,我的代码如下所示:

  int main (int argc, const char * argv[])
  {
    // Initializing graph 
    tensorflow::GraphDef graph_def;
    // Folder in which graph data is located
    string graph_file_name = "Model/graph.pb";
    // Loading graph 
    tensorflow::Status graph_loaded_status =  ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
    if (!graph_loaded_status.ok())
    {
      cout << graph_loaded_status.ToString() << endl;
      return 1;
    }
    unique_ptr<tensorflow::Session> session_sqdet(tensorflow::NewSession(tensorflow::SessionOptions()));
    tensorflow::Status session_create_status = session_sqdet->Create(graph_def);
    if (!session_create_status.ok())
    {
      cout << "Session create status: fail." << endl;
      return 1;
    }
    while ()
    {
      /* create & preprocess batch */

      session.Run({{ "image_input", input_tensor}, {"keep_prob", prob_tensor}}, {"probability/score", "bbox/trimming/bbox"}, {}, &final_output);

      /* do some postprocessing */
    }
  }

我尝试过的:

1) 使用优化标志 - 全部开启,没有警告。

2)使用batching:性能有所提升,但是python和C++的差距还是很大(运行session需要1s vs 2.4s,batch_size = 20)。

任何帮助将不胜感激。

最佳答案

我在这个问题上花了很多时间(大部分是因为我犯了愚蠢的错误),但我终于解决了它。现在我想在这里发布我的经验,因为它可能有用。

所以这些是我建议跟随面临相同问题的人的步骤(尽管其中一些非常明显):

0)正确进行分析!确保您使用的工具在多核/GPU/您拥有的任何设置中都可靠。

1) 检查 tensorflow 和所有相关包是否在所有优化下构建。

2)优化卡住后的图形。

3) 如果您在训练和推理期间使用不同的批次大小,请确保您已删除模型中的所有依赖项!请注意,否则您将不会收到错误消息或在结果质量方面的性能更差,您只会遇到神秘的减速!

关于tensorflow - 使用 tensorflow c++ api 运行 session 明显比使用 python 慢,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43892749/

相关文章:

tensorflow 将标签向量操作为 "multiple hot encoder"

tensorflow - 在本地计算机上反向图像搜索(用于图像重复)

python - Keras - 无法通过裁剪来限制输出

docker - 在docker上服务的TensorFlow对cuInit的调用失败:CUresult(-1)

python - 在 Keras 中为多标签文本分类神经网络创建一个带有 Attention 的 LSTM 层

tensorflow - CTC 丢失 InvalidArgumentError : sequence_length(b) <= time

optimization - 在 TensorFlow 中进行多 GPU 训练有什么优势?

python - 如何在 Tensorflow 中打印预测

python - 卡住模型并训练它

tensorflow - 将迭代器(来自 tf.data.Dataset)中的元素提供给 TensorFlow 模型的有效方法是什么?