java - 无法在 Java API 中运行 Tensorflow 预测

标签 java python tensorflow

我正在尝试对使用“使用 TensorFlow 微调 AlexNet”训练的模型执行预测 https://kratzert.github.io/2017/02/24/finetuning-alexnet-with-tensorflow.html

我在 Python 中使用 tf.saved_model.builder.SavedModelBuilder 保存模型,并使用 SavedModelBundle.load 在 Java 中加载模型。 代码的主要部分是:

    SavedModelBundle smb = SavedModelBundle.load(path, "serve");
    Session s = smb.session();
    byte[] imageBytes = readAllBytesOrExit(Paths.get(path));
    Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes);
    Tensor result = s.runner().feed("input_tensor", image).fetch("fc8/fc8").run().get(0);
    final long[] rshape = result.shape();
    if (result.numDimensions() != 2 || rshape[0] != 1) {
        throw new RuntimeException(
                String.format(
                        "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                        Arrays.toString(rshape)));
    }
    int nlabels = (int) rshape[1];
    float [] a =  result.copyTo(new float[1][nlabels])[0];`

我收到此异常:

Exception in thread "main" java.lang.IllegalArgumentException: You must feed a value for placeholder tensor 'Placeholder_1' with dtype float [[Node: Placeholder_1 = Placeholder_output_shapes=[[]], dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]]

我看到上面的代码对某些人有用,但我不知道这里缺少什么。 请注意,网络熟悉节点“input_tensor”和“fc8/fc8”,因为它没有说它不知道它们。

最佳答案

从错误消息来看,您使用的模型似乎需要提供另一个值(图中的节点名称为 Placeholder_1,预期类型为浮点标量张量)。

您似乎已经自定义了您的模型(而不是逐字遵循您链接到的文章)。也就是说,本文显示了需要输入的多个占位符,一个用于图像,另一个用于控制丢失。文章中定义为:

keep_prob = tf.placeholder(tf.float32)

并且需要提供该占位符的值。如果您正在进行推理,那么您需要将 keep_prob 设置为 1.0。像这样的东西:

Tensor keep_prob = Tensor.create(1.0f);
Tensor result = s.runner()
  .feed("input_tensor", image)
  .feed("Placeholder_1", keep_prob)
  .fetch("fc8/fc8")
  .run()
  .get(0);

希望有帮助。

关于java - 无法在 Java API 中运行 Tensorflow 预测,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44648065/

相关文章:

java - 如何扩展 FutureTask 并确保释放对 Callable 的引用?

python "if"与 "and"函数执行顺序

graph - 如何在tensorflow MNIST教程中输出预测值(标签)?

java - 我被困在java代码中,我不明白为什么它不返回true

java - 如何在 JAVA 中的 diff 原因上打印相同 NumberFormatException 的 diff 消息?

python - 如何修复 boto3 中不存在用户池 ********

python - 如何在 'for x in range' 语句中压缩 'if' 语句 'elif' 语句中的 'while' 语句

python - 为什么 TensorFlow 2 和 1 的 RNG 不同?

python - Tensorflow 中的 Dice/Jaccard 系数优化

java - 所有AsyncTasks完成后如何执行onMapReady()回调?