我用 bazel 构建了一个非常简单的自定义操作 zero_out.dll
,它在使用 python 时有效。
import tensorflow as tf
zero_out_module = tf.load_op_library('./zero_out.dll')
with tf.Session(''):
zero_out_module.zero_out([[1, 2], [3, 4]]).eval()
但是我必须使用C++运行推理,是否有任何c++ api具有与tf.load_op_library
类似的功能,因为似乎在中完成了很多注册工作>tf.load_op_library
,TF有没有对应的c++ API?
最佳答案
虽然在 C++ 中似乎没有公共(public) API,但库加载函数公开在 TensorFlow API for C 中。 (这是 tf.load_library
使用的 API)。它没有“好的”文档,但您可以在 c/c_api.h
中找到它们:
// --------------------------------------------------------------------------
// Load plugins containing custom ops and kernels
// TF_Library holds information about dynamically loaded TensorFlow plugins.
typedef struct TF_Library TF_Library;
// Load the library specified by library_filename and register the ops and
// kernels present in that library.
//
// Pass "library_filename" to a platform-specific mechanism for dynamically
// loading a library. The rules for determining the exact location of the
// library are platform-specific and are not documented here.
//
// On success, place OK in status and return the newly created library handle.
// The caller owns the library handle.
//
// On failure, place an error status in status and return NULL.
TF_CAPI_EXPORT extern TF_Library* TF_LoadLibrary(const char* library_filename,
TF_Status* status);
// Get the OpList of OpDefs defined in the library pointed by lib_handle.
//
// Returns a TF_Buffer. The memory pointed to by the result is owned by
// lib_handle. The data in the buffer will be the serialized OpList proto for
// ops defined in the library.
TF_CAPI_EXPORT extern TF_Buffer TF_GetOpList(TF_Library* lib_handle);
// Frees the memory associated with the library handle.
// Does NOT unload the library.
TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle);
这些函数确实调用了 C++ 代码(参见 c/c_api.cc
中的源代码)。然而,被调用的函数,在 core/framework/load_library.cc
中定义没有要包含的标题。在 C++ 代码中使用它的解决方法,他们在 c/c_api.cc
中使用,就是自己声明函数,链接TensorFlow库。
namespace tensorflow {
// Helpers for loading a TensorFlow plugin (a .so file).
Status LoadLibrary(const char* library_filename, void** result,
const void** buf, size_t* len);
}
据我所知,没有用于卸载库的 API。 C API 只允许您删除库句柄对象。这只是通过释放指针来完成的,但如果你想避免麻烦,你应该使用 TensorFlow 提供的释放函数,tensorflow::port:free
,在 core/platform/mem.h
中声明。 .同样,如果您不能或不想包含它,您可以自己声明该函数,它应该也能正常工作。
namespace tensorflow {
namespace port {
void Free(void* ptr);
}
}
关于c++ - 使用 C++ 时如何加载自定义操作库?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57857325/