我正在使用 TensorFlow 进行一些实验,但遇到了障碍。我正在尝试使用 TF 来评估模型的变化,然后根据损失函数的结果变化保留或恢复模型。我已经弄清楚了困难的部分(条件控制),但我仍然坚持一些应该相当简单的事情:我似乎无法为迭代存储 tf.trainable_variables
,然后如果需要,恢复它。
假设构建一个 Op:
...
store_trainable_vars = []
for v in tf.trainable_variables():
store_trainable_vars.append(v)
...
然后,我想将 tf.trainable_variables
恢复到上次运行此 Op 时的值。我想做类似的事情:
def reject_move():
revert_state = []
for (v, s) in zip(tf.trainable_variables(), store_trainable_vars):
revert_state.append(tf.assign(v, s, name="revert_state"))
return(revert_state)
显然,这将重新评估 store_trainable_vars
,它又链接到 tf.trainable_variables()
的当前值,避免 revert_state
操作。我需要一些方法来存储和检索张量的值,而无需回调这些张量的当前值。有点像
...
store_trainable_vars = []
for v in tf.trainable_variables():
store_trainable_vars.append(v.value_right_now())
...
其中 v.value_right_now()
返回一个常量,该常量在被覆盖之前不会改变。
我知道我可以使用 Saver,但该解决方案会写入磁盘,这对于该应用程序来说是 Not Acceptable ,因为它将在训练循环内运行。
我可能遗漏了一些明显的东西 - 任何指导将不胜感激。
最佳答案
要手动恢复图形状态,您需要使用 tf.tuple
或 tf.group
操作,这将修改批量更改的流程:
This creates a tuple of tensors with the same values as the tensors argument, except that the value of each tensor is only returned after the values of all tensors have been computed.
[更新] 以下是我的做法:
import numpy as np
import tensorflow as tf
x = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='x')
W = tf.Variable(np.zeros([5, 5]), dtype=tf.float32, name='W')
b = tf.Variable(np.zeros([5]), dtype=tf.float32, name='b')
y = tf.add(tf.matmul(x, W), b)
with tf.Session() as session:
batch = np.ones([2, 5])
session.run(tf.global_variables_initializer())
print session.run(y, feed_dict={x: batch}) # prints [2, 5] zeros
# store the current value
store = {v.name: v.eval(session) for v in tf.trainable_variables()}
print store # prints [5, 5] and [5] zeros
# update
new = {'W:0': np.ones([5, 5]), 'b:0': np.ones([5])}
session.run(tf.tuple([tf.assign(var, new[var.name]) for var in tf.trainable_variables()]))
print session.run(y, feed_dict={x: batch}) # prints [2, 5] sixes
# restore
session.run(tf.tuple([tf.assign(var, store[var.name]) for var in tf.trainable_variables()]))
print session.run(y, feed_dict={x: batch}) # prints [2, 5] zeros again
但我真的认为您应该重新考虑关于 Saver
的决定,因为它也被设计为在训练循环中使用。在内部,Saver
为您完成所有棘手的工作(特别是,如果需要,它会恢复操作调用 tf.group
和 tf.control_dependencies
),否则可能会成为非常讨厌的错误的来源。此外,磁盘(几乎)总是比你的 GPU 和主内存大,所以如果你有能力将模型存储在内存中,你也应该能够存储在磁盘上。
这是 some parameters有助于控制磁盘上检查点文件的扩散:
max_to_keep
表示最近要保存的检查点文件的最大数量 保持。随着新文件的创建,旧文件将被删除。如果为 None 或 0,则保留所有检查点文件。默认为 5(即最近的 5 个 保留检查点文件)。keep_checkpoint_every_n_hours
:除了保持最近的max_to_keep
检查点文件,你可能想保留一个检查点文件 每 N 小时的训练。如果您想稍后使用,这可能很有用 分析模型在长时间训练期间的进展情况。为了 例如,传递keep_checkpoint_every_n_hours=2
可确保您为每 2 小时的训练保留一个检查点文件。默认值 10,000 小时有效地禁用了该功能。
[更新] 如评论中所述,主要问题是磁盘延迟,如果访问过于频繁,可能会减慢训练速度。如果您使用的是 Linux,它 caches经常使用的磁盘页面,Windows does it以及。但是,如果您想绝对确定,请考虑使用 tmpfs
.
关于python - 如何在不将值保存到磁盘的情况下将张量恢复到过去的值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46393983/