tensorflow - 如何使用 C API 遍历 Tensorflow 图?

标签 tensorflow machine-learning

下面的一个小程序创建了一个简单的 tf 图。我需要遍历图表,并打印有关节点的信息。

假设每个图都有一个根(或区分节点)是否正确?我相信这个图有 3 个节点,而且我听说边是张量。

#include<stdio.h>
#include<stdlib.h>
#include<string.h>
#include"tensorflow/c/c_api.h"

TF_Graph* g;
TF_Status* s;

#define CHECK_OK(x) if(TF_OK != TF_GetCode(s))return printf("%s\n",TF_Message(s)),(void*)0

TF_Tensor* FloatTensor2x2(const float* values) {
  const int64_t dims[2] = {2, 2};
  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(float) * 4);
  memcpy(TF_TensorData(t), values, sizeof(float) * 4);
  return t;
}

TF_Operation* FloatConst2x2(TF_Graph* graph, TF_Status* s, const float* values, const char* name) {
  TF_Tensor* tensor=FloatTensor2x2(values);
  TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
  TF_SetAttrTensor(desc, "value", tensor, s);
  if (TF_GetCode(s) != TF_OK) return 0;
  TF_SetAttrType(desc, "dtype", TF_FLOAT);
  TF_Operation* op = TF_FinishOperation(desc, s);
  CHECK_OK(s);
  return op;
}

TF_Operation* MatMul(TF_Graph* graph, TF_Status* s, TF_Operation* l, TF_Operation* r, const char* name,
                     char transpose_a, char transpose_b) {
  TF_OperationDescription* desc = TF_NewOperation(graph, "MatMul", name);
  if (transpose_a) {
    TF_SetAttrBool(desc, "transpose_a", 1);
  }
  if (transpose_b) {
    TF_SetAttrBool(desc, "transpose_b", 1);
  }
  TF_AddInput(desc,(TF_Output){l, 0});
  TF_AddInput(desc,(TF_Output){r, 0});
  TF_Operation* op = TF_FinishOperation(desc, s);
  CHECK_OK(s);
  return op;
}

TF_Graph* BuildSuccessGraph(TF_Output* inputs, TF_Output* outputs) {
  //            |
  //           z|
  //            |
  //          MatMul
  //         /       \
  //        ^         ^
  //        |         |
  //    x Const_0  y Const_1
  //
  float const0_val[] = {1.0, 2.0, 3.0, 4.0};
  float const1_val[] = {1.0, 0.0, 0.0, 1.0};
  TF_Operation* const0 = FloatConst2x2(g, s, const0_val, "Const_0");
  TF_Operation* const1 = FloatConst2x2(g, s, const1_val, "Const_1");
  TF_Operation* matmul = MatMul(g, s, const0, const1, "MatMul",0,0);
  inputs[0] = (TF_Output){const0, 0};
  inputs[1] = (TF_Output){const1, 0};
  outputs[0] = (TF_Output){matmul, 0};
  CHECK_OK(s);
  return g;
}

int main(int argc, char const *argv[]) {
  g = TF_NewGraph();
  s = TF_NewStatus();

  TF_Output inputs[2],outputs[1];
  BuildSuccessGraph(inputs,outputs);

  /* HERE traverse g -- maybe with {inputs,outputs} -- to print the graph */

  fprintf(stdout, "OK\n");
}

如果有人可以帮助了解使用哪些函数来获取有关图表的信息,我们将不胜感激。

最佳答案

来自 c_api.h:

// Iterate through the operations of a graph.  To use:
// size_t pos = 0;
// TF_Operation* oper;
// while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) {
//   DoSomethingWithOperation(oper);
// }
TF_CAPI_EXPORT extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph,
                                                      size_t* pos);

请注意,这仅返回操作,并没有定义从一个节点(操作)导航到下一个节点的方法 - 此边缘关系存储在节点本身中(作为指针)。

关于tensorflow - 如何使用 C API 遍历 Tensorflow 图?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49868231/

相关文章:

python - TensorFlow 二元分类器输出 3 个类而不是 2 个类的预测?

python - 尝试将 GPU 与 Tensorflow 一起使用时出错

tensorflow - 使用MobileNet重新训练图像检测

python-3.x - 教程tensorflow音频音高分析

python - 我在使用 Tensorflow 2.0 运行 RNN LSTM 模型时遇到错误

r - 在R中的glmnet中提取非零系数

tensorflow - tensorflow.python.framework.errors_impl.NotFoundError : Failed to create a directory: ; No such file or directory

python - 使用 scikit-learn 预测电影评论

machine-learning - 如何获取检测到的图像的轮廓(使用 cnn 对象检测)而不是其周围的矩形框?

python - 如何将 2D 列表 python 转换为一个列表并连接它们