java - Python 中的 Tensorflow Java Api `toGraphDef` 等价物是什么?

标签 java scala tensorflow java-native-interface tensorflow-serving

我正在使用 Tensorflow Java Api 将已创建的 Tensorflow 模型加载到 JVM 中。 我以此为例:tensorflow/examples/LabelImage.java

这是我的简单 scala 代码:

import java.nio.file.{Files, Path, Paths}
import org.tensorflow.{Graph, Session, Tensor}

def readAllBytesOrExit(path: Path): Array[Byte] = Files.readAllBytes(path)
val graphDef = readAllBytesOrExit(Paths.get("PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb"))
val g = new Graph()
g.importGraphDef(graphDef)
val session = new Session(g)
val result: Tensor = session.runner().feed("input", image).fetch("output").run().get(0))

如何保存我的模型以将 session 和图形存储在同一个文件中。如上面的“PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb”中所述。

描述here它提到:

The serialized representation of the graph, often referred to as a GraphDef, can be generated by toGraphDef() and equivalents in other language APIs.

其他语言 API 中的等价物是什么?我觉得不明显

注意:我已经查看了 tensorflow_serving 下的 mnist_saved_model.py,但通过该过程保存它会得到一个 .pb 文件和一个 variables 文件夹。尝试加载该 .pb 文件时,我得到:java.lang.IllegalArgumentException: Invalid GraphDef

最佳答案

目前使用 tensorflow 的 Java API,我只找到了如何将图形保存为 graphDef(即没有其变量和元数据)。这可以通过将 Array[Byte] 写入文件来完成:

Files.write(Paths.get(modelDir, modelName), myGraph.toGraphDef)

这里 myGraph 是来自 Graph class 的 java 对象.

我建议使用 SavedModel 从 Python API 保存您的模型api在这里定义。它会将您的模型保存在一个文件夹中,其中包含 .pb 文件中的序列化图形和文件夹中的变量。请注意您使用的 tag_constants,因为您在 scala/java 代码中需要它来加载带有变量的模型。然后带有变量的图形和 session 很容易加载 SavedModelBundle来自 java api 的 java 类。它返回一个包装器,其中包含图形和包含变量值的 session :

val model = SavedModelBundle.load(modelDir, modelTag)

如果您已经尝试过此操作,也许您可​​以分享您的代码以了解它返回无效 GraphDef 的原因。

另一种选择是卡住你的图表,即将你的变量节点变成常量节点,这样一切都在 .pb 文件中是独立的。更多信息 here冷冻部分

关于java - Python 中的 Tensorflow Java Api `toGraphDef` 等价物是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43242857/

相关文章:

python - 无法导入 Tensorflow "No module named copyreg"

python - Firebase Tensorflow Lite 分类模型未在 Swift 应用程序中提供正确的输出

java - GWT 的 http 请求生成器返回空响应

scala - K均值||用于 Spark 情绪分析

Scala 太冗长了

algorithm - Scala 中的通用快速排序

python - Elmo 是词嵌入还是句子嵌入?

java - SLF4J 错误 : class loader have different class objects for the type

java - war 库中的爆炸 jar

java - Java RMI 是否使用服务器资源?