c++ - 使用 C++ 时如何加载自定义操作库?

标签 c++ tensorflow

我用 bazel 构建了一个非常简单的自定义操作 zero_out.dll,它在使用 python 时有效。

import tensorflow as tf
zero_out_module = tf.load_op_library('./zero_out.dll')
with tf.Session(''):
  zero_out_module.zero_out([[1, 2], [3, 4]]).eval()

但是我必须使用C++运行推理,是否有任何c++ api具有与tf.load_op_library类似的功能,因为似乎在中完成了很多注册工作>tf.load_op_library,TF有没有对应的c++ API?

最佳答案

虽然在 C++ 中似乎没有公共(public) API,但库加载函数公开在 TensorFlow API for C 中。 (这是 tf.load_library 使用的 API)。它没有“好的”文档,但您可以在 c/c_api.h 中找到它们:

// --------------------------------------------------------------------------
// Load plugins containing custom ops and kernels

// TF_Library holds information about dynamically loaded TensorFlow plugins.
typedef struct TF_Library TF_Library;

// Load the library specified by library_filename and register the ops and
// kernels present in that library.
//
// Pass "library_filename" to a platform-specific mechanism for dynamically
// loading a library. The rules for determining the exact location of the
// library are platform-specific and are not documented here.
//
// On success, place OK in status and return the newly created library handle.
// The caller owns the library handle.
//
// On failure, place an error status in status and return NULL.
TF_CAPI_EXPORT extern TF_Library* TF_LoadLibrary(const char* library_filename,
                                                 TF_Status* status);

// Get the OpList of OpDefs defined in the library pointed by lib_handle.
//
// Returns a TF_Buffer. The memory pointed to by the result is owned by
// lib_handle. The data in the buffer will be the serialized OpList proto for
// ops defined in the library.
TF_CAPI_EXPORT extern TF_Buffer TF_GetOpList(TF_Library* lib_handle);

// Frees the memory associated with the library handle.
// Does NOT unload the library.
TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle);

这些函数确实调用了 C++ 代码(参见 c/c_api.cc 中的源代码)。然而,被调用的函数,在 core/framework/load_library.cc 中定义没有要包含的标题。在 C++ 代码中使用它的解决方法,他们在 c/c_api.cc 中使用,就是自己声明函数,链接TensorFlow库。

namespace tensorflow {
// Helpers for loading a TensorFlow plugin (a .so file).
Status LoadLibrary(const char* library_filename, void** result,
                   const void** buf, size_t* len);
}

据我所知,没有用于卸载库的 API。 C API 只允许您删除库句柄对象。这只是通过释放指针来完成的,但如果你想避免麻烦,你应该使用 TensorFlow 提供的释放函数,tensorflow::port:free,在 core/platform/mem.h 中声明。 .同样,如果您不能或不想包含它,您可以自己声明该函数,它应该也能正常工作。

namespace tensorflow {
namespace port {
void Free(void* ptr);
}
}

关于c++ - 使用 C++ 时如何加载自定义操作库?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57857325/

相关文章:

c++ - 你能给我一个 16 位(或更多)十进制数,它只在第 15 位正确转换为 double float 吗?

c++ - 如何禁用/覆盖 OpenSceneGraph 中的 Wireframe 和 Stats 键?

python - 如何包装 C++ 代码以供 IronPython 访问

python - 如何将 Tensorflow Simple Audio Recognition frozen graph(.pb) 转换为 Core ML 模型?

python - tf.data.Dataset 对象作为 tf.Keras 模型的输入 -- ValueError

tensorflow - 发电机在错误的时间调用(keras)

C++ 多线程 : Do I need mutex for constructor and destructor?

c++ - 关于使用共享指针的求值顺序

python - TensorFlow 缺少 FFT 的 CPU 运算(InvalidArgumentError : No OpKernel was registered to support Op 'FFT' with these attrs)

python - 具有稀疏数据的 tensorflow 训练