java - 永久更新 tensorflow-java 中的变量(在推理期间)

标签 java python variables tensorflow

我已经使用 python-tensorflow 训练了一个模型,我想在 java-tensorflow 中进行推理。我已将经过训练的模型/图形加载到 Java 中。在此之后,我想永久更新图中的一个变量。我知道 python 中的 tf.variable.load(value,session) 函数可用于更新变量的值。我想知道Java中是否有类似的方法。

到目前为止,我已经尝试了以下方法。

// g and s are loaded graphs and sessions respectively
s.runner().feed(variableName,updatedTensorValue)

但在同一行执行的 fetch 调用期间,上面的行仅对 variableName 使用了 updatedTensorValue

g.opBuilder("Assign",variableName).setAttr("value",updatedTensorValue).build();

上面的行没有更新值,而是试图将相同的变量添加到图中,因此它抛出异常。

另一种永久更新图中变量的替代方法是,我将在所有 fetch 调用期间始终调用 feed(variableName,updatedTensorValue) 方法。我会在多个实例上运行推理代码,所以我想知道这个额外的 feed 调用会花费额外的时间。

谢谢

最佳答案

在 TensorFlow 中做大多数事情的方法是执行一个操作。您在尝试运行 Assign 操作时走在了正确的轨道上,但是调用不正确,因为要分配的 value 不是 的“属性” >Assign 操作而不是输入张量。 (请参阅原始 definition of the operation ,但不可否认,除非您熟悉 TensorFlow 内部结构,否则该定义可能不容易理解)。

但是,您不需要在 Java 中向图形添加操作来执行此操作。相反,你可以完全按照 tf.Variable.load 做在 Python 中执行 - 执行 tf.Variable.initializer操作,输入输入值。

例如,考虑以下用 Python 构建的图表:

import tensorflow as tf

var = tf.Variable(1.0, name='myvar')
init = tf.global_variables_initializer()

# Save the graph and write out the names of the operations of interest
tf.train.write_graph(tf.get_default_graph(), '/tmp', 'graph.pb', as_text=False)
print('Init all variables:         ', init.name)
print('myvar.initializer:          ', var.initializer.name)
print('myvar.initializer.inputs[1]:', var.initializer.inputs[1].name)

现在,我们在 Java 中复制 Python var.load() 的行为,使用如下方式将值 3.0 赋给变量:

try (Tensor<Float> newValue = Tensors.create(3.0f)) {
  s.runner()
    .feed("myvar/initial_value", newVal) // myvar.initializer.inputs[1].name
    .addTarget("myvar/Assign")           // myvar.initializer.name
    .run();
}

希望对您有所帮助。

关于java - 永久更新 tensorflow-java 中的变量(在推理期间),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49801711/

相关文章:

java - 我在代码中是否正确使用了 'this' 关键字?

java - 两个或多个点的正则表达式应分隔为点空间

python - 在Python中的字符串中添加空格?

python - 避免日志溢出(cosh(x))

ruby - 单个哈希的多变量赋值

java - Spring Boot OAuth 字符转义

java - IS ResultSet 线程安全

python - 在 pandas 中查找满足特定条件的列的有效方法

c++ - 函数重载与函数变量初始化

php - Jquery load() 和 PHP 变量