java - 如何从 Java 中的示例对象创建张量?

标签 java tensorflow tensorflow-serving

我的用例: 我正在尝试使用 libtensorflow_jni 在我们现有的 JVM 服务中为 python 训练的模型提供服务。

现在我可以使用 SavedModelBundle.load() 加载模型。但我发现很难将请求输入到模型中。因为我的用户请求不仅仅是一个标量矩阵,而是一个特征图,例如:

{'gender':1, 'age': 20, 'country': 100, other features ...}

通过搜索 tensorflow 教程,我发现Example Protocol Buffer 可能适合这里,因为它基本上包含一个功能列表。但我不知道如何将其转换为 Java Tensor 对象。

如果我直接使用序列化的示例对象创建张量,TensorFlow 运行时似乎对数据类型不满意。例如,我这样做,

Tensor inputTensor = Tensor.create(example.toByteArray());
s.runner().feed(inputTensorName, inputTensor).fetch(outputTensorName).run().get(0);

我将得到一个 IllegalArgumentException:

java.lang.IllegalArgumentException: Expected serialized to be a vector, got shape: []

如果你们碰巧知道或有相同的用例,你们能否告诉我如何从这里继续前进?

谢谢!

最佳答案

查看您的错误消息,问题似乎在于您的模型需要字符串张量 vector (很可能对应于一批序列化 Example Protocol Buffer 消息,可能来自 tf.parse_example )但你给它提供了一个标量弦张量。

不幸的是,在解决 issue #8531 之前,Java API 无法创建除标量之外的字符串张量。一旦这个问题得到解决,事情就会变得更容易。

同时,您可以通过构建 TensorFlow“模型”将标量字符串转换为大小为 1 的 vector 来解决此问题:)。这可以通过这样的事情来完成:

// A TensorFlow "model" that reshapes a string scalar into a vector.
// Should be much prettier once https://github.com/tensorflow/tensorflow/issues/7149
// is resolved.
private static class Reshaper implements AutoCloseable {
  Reshaper() {
    this.graph = new Graph();
    this.session = new Session(graph);
    this.in =
        this.graph.opBuilder("Placeholder", "in")
            .setAttr("dtype", DataType.STRING)
            .build()
            .output(0);
    try (Tensor shape = Tensor.create(new int[] {1})) {
      Output vectorShape =
          this.graph.opBuilder("Const", "vector_shape")
              .setAttr("dtype", shape.dataType())
              .setAttr("value", shape)
              .build()
              .output(0);
      this.out =
          this.graph.opBuilder("Reshape", "out").addInput(in).addInput(vectorShape).build().output(0);
    }
  }

  @Override
  public void close() {
    this.session.close();
    this.graph.close();
  }

  public Tensor vector(Tensor input) {
    return this.session.runner().feed(this.in, input).fetch(this.out).run().get(0);
  }

  private final Graph graph;
  private final Session session;
  private final Output in;
  private final Output out;
}

通过上述内容,您可以将示例原始张量转换为 vector ,并将其输入到您感兴趣的模型中,如下所示:

Tensor inputTensor = null;
try (Tensor scalar = Tensor.create(example.toByteArray())) {
  inputTensor = reshaper.vector(scalar);
}
s.runner().feed(inputTensorName, inputTensor).fetch(outputTensorName).run().get(0);

有关完整详细信息,see this example on github

希望有帮助!

关于java - 如何从 Java 中的示例对象创建张量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45746742/

相关文章:

tensorflow - 在 model_config_file 中分配版本标签失败

python - Tensorflow 2.0具体函数structured_input_signature返回值

docker - TensorFlow 通过 docker 服务多个模型

java - 是否可以在 Java 中创建内存中的文件对象?

java - 解析 Java 中的类型是什么意思?

python - Tensorflow 无法为形状为 'x:0' 的张量 '(?, 128)' 提供 shape(1,) 的值

tensorflow - 加载模型权重后的 CuDNN 库兼容性错误

java - 无法从文件读回 protobuf

java - Intellij 需要很长时间才能执行 JUnit 测试的 Maven 目标

python - 在 tensorflow 中扩展维度并复制数据