python - 如何在内存中复制 tensorflow 网络状态以便参数更新后检索?

标签 python tensorflow

我有一个包含多个模块的 tensorflow 图,我想重用其中一个模块的之前的网络状态(参数更新之前)来评估下一个状态的输入 em>(参数更新后)。

示例

考虑一下玩具示例,我希望在时间步 t 处复制 network_B,以便在下一个训练步骤 t+1 中使用:

def network_A(x):
    A1 = tf.matmul(x, A_W1) + A_b1
    return tf.nn.relu(A1)

def network_B(x):
    B1 = tf.matmul(x, B_W1) + B_b1
    Z1 = tf.nn.relu(B1)
    B2 = tf.matmul(Z1, B_W2) + B_b2
    return B2

x = tf.placeholder(tf.float32, shape=[None, x_dim])
x_2 = network_A(x)

# Evaluate input x_2 with current state of network
y_hatB_current = network_B(x)

# Evaluate same input x_2 with past state of network
y_hatB_past = network_B_past(x) # 

# Get some loss
loss = ...

然后,一旦两者都被评估,将网络的当前状态保存为新的过去状态,并仅优化当前状态:

# Save state of parameters
network_B_past = network_B  # (How do I do this efficiently?)

# Optimize the current state
train = tf.train.AdamOptimizer().minimize(loss, var_list=current_vars)

详细信息

因此,在每个训练步骤中,应该存在两个版本的 network_B 可用于评估输入:

  1. network_B 在时间步 t-1(过去状态)
  2. network_B 在时间步t(当前状态)

在两个训练步骤之间存在参数更新,因此两者之间的权重应该略有不同,但其他方面应该相同。然后,在评估新输入后,当前状态将替换过去的状态,并且另一个训练步骤将更新网络。

我知道我可以在 tensorflow 中保存和重新加载检查点,但这对于我的用例来说似乎效率太低,因为它需要在每个训练步骤中发生。实现此网络克隆步骤以便我维护跨州持续存在的副本的有效方法是什么?

tensorflow 版本:1.5

最佳答案

我将使用函数 create_graph 在不同的变量范围下创建网络两次:一次用于当前,一次用于备份。请注意,这会使内存消耗加倍。

那么您所需要的只是一个自定义sync_op。 MWE 是

import tensorflow as tf


def copy_vars(src_scope, dst_scope):
    src_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=src_scope)
    dst_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=dst_scope)

    update_op = []

    for src_var in src_vars:
        for dst_var in dst_vars:
            if src_var.name.replace('%s' % src_scope, '') == dst_var.name.replace('%s' % dst_scope, ''):
                assert dst_var.shape == src_var.shape
                print("  copy: add assign {} -> {}".format(src_var.name, dst_var.name))
                update_op.append(dst_var.assign(src_var))
    return tf.group(update_op)


def create_graph(name, x, use_c=False, uses_gradient_updates=True):

    var_setter = lambda x: x  # noqa
    if uses_gradient_updates:
        var_setter = lambda x: tf.stop_gradient(x)  # noqa

    with tf.variable_scope(name, custom_getter=var_setter):
        a = tf.Variable([1], dtype=tf.float32)
        b = tf.Variable([1], dtype=tf.float32)
        result = x + a + b
        if use_c:
            # create dummy variable just to show both graphs do not need to be exactly the same
            c = tf.Variable([1], dtype=tf.float32)
        return result, a, b

x = tf.placeholder(tf.float32)

c1, a1, b1 = create_graph('original', x, use_c=True)
c2, a2, b2 = create_graph('backup', x, use_c=False)

sync_op = copy_vars('original', 'backup')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    print(sess.run([c1, c2], {x: 5}))  # in sync

    sess.run(a1.assign([3]))  # update your graph either by tf.train.Adam or by:
    print(sess.run([c1, c2], {x: 5}))  # out of sync

    sess.run(sync_op)  # do syncing
    print(sess.run([c1, c2], {x: 5}))  # in sync

custom_getter 可以帮助防止渐变更新。

关于python - 如何在内存中复制 tensorflow 网络状态以便参数更新后检索?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48555964/

相关文章:

python - python中将多个特定文本文件转换为CSV

python - 学习 Django - 很好的入门项目

python - Pillow/Numpy 中的高效粘贴/复合蒙版图像

python - 对于 Pandas DataFrame 中任何带有 NaN 的行,移位 1

python - 为什么 tf.cond() 将 tf.bool 识别为 python bool 而不是 tf.bool?

python - 在 Django Celery 中,如何判断任务是否已异步执行

tensorflow 每次运行发现多个图形事件

tensorflow - 如何从索引值中获取值

python - 内存减少 Tensorflow TPU v2/v3 bfloat16

python - Tensorflow 中的可微分函数?