我已经使用 python-tensorflow 训练了一个模型,我想在 java-tensorflow 中进行推理。我已将经过训练的模型/图形加载到 Java 中。在此之后,我想永久更新图中的一个变量。我知道 python 中的 tf.variable.load(value,session)
函数可用于更新变量的值。我想知道Java中是否有类似的方法。
到目前为止,我已经尝试了以下方法。
// g and s are loaded graphs and sessions respectively
s.runner().feed(variableName,updatedTensorValue)
但在同一行执行的 fetch
调用期间,上面的行仅对 variableName
使用了 updatedTensorValue
。
g.opBuilder("Assign",variableName).setAttr("value",updatedTensorValue).build();
上面的行没有更新值,而是试图将相同的变量添加到图中,因此它抛出异常。
另一种永久更新图中变量的替代方法是,我将在所有 fetch
调用期间始终调用 feed(variableName,updatedTensorValue)
方法。我会在多个实例上运行推理代码,所以我想知道这个额外的 feed
调用会花费额外的时间。
谢谢
最佳答案
在 TensorFlow 中做大多数事情的方法是执行一个操作。您在尝试运行 Assign
操作时走在了正确的轨道上,但是调用不正确,因为要分配的 value
不是 的“属性” >Assign
操作而不是输入张量。 (请参阅原始 definition of the operation ,但不可否认,除非您熟悉 TensorFlow 内部结构,否则该定义可能不容易理解)。
但是,您不需要在 Java 中向图形添加操作来执行此操作。相反,你可以完全按照 tf.Variable.load
做在 Python 中执行 - 执行 tf.Variable.initializer
操作,输入输入值。
例如,考虑以下用 Python 构建的图表:
import tensorflow as tf
var = tf.Variable(1.0, name='myvar')
init = tf.global_variables_initializer()
# Save the graph and write out the names of the operations of interest
tf.train.write_graph(tf.get_default_graph(), '/tmp', 'graph.pb', as_text=False)
print('Init all variables: ', init.name)
print('myvar.initializer: ', var.initializer.name)
print('myvar.initializer.inputs[1]:', var.initializer.inputs[1].name)
现在,我们在 Java 中复制 Python var.load()
的行为,使用如下方式将值 3.0 赋给变量:
try (Tensor<Float> newValue = Tensors.create(3.0f)) {
s.runner()
.feed("myvar/initial_value", newVal) // myvar.initializer.inputs[1].name
.addTarget("myvar/Assign") // myvar.initializer.name
.run();
}
希望对您有所帮助。
关于java - 永久更新 tensorflow-java 中的变量(在推理期间),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49801711/