python - TensorFlow 'global_step' 变量未针对指数衰减进行更新

标签 python tensorflow neural-network

我有一个网络,我在其中使用学习率的指数衰减。为此,我正在跟踪一个“global_step”TF 变量,该变量在处理的每个批处理中递增 1。然而,看起来实际上,它并没有真正得到更新。这是代码。

...
global_step = tf.Variable(0, trainable=False, name='global_step')
starter_learning_rate = 0.01
learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step, 1000, 0.50)

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optm = tf.train.AdamOptimizer(learning_rate).minimize(cost, global_step=global_step)
init = tf.global_variables_initializer()


def train(file):
    global global_step
    for batch in batches:
        global_step += 1
        ...        
    return loss

...
global_step = 0
for epoch in EPOCHS:
    for f in files:
        loss = train(f)

函数内部和外部的 global_step 正在更新。但是我的学习率没有改变。当我将摘要附加到我的 TF global_step 变量时,我看到它始终保持为 0。

这里有什么问题?

最佳答案

实际上,我还没有看到你在哪里设置learning_rate变量,但这是如何使用它的方式:

定义全局步骤变量

global_step = tf.Variable(0)

定义学习率随不同params变化的方式

learning_rate = tf.train.exponential_decay(0.1, global_step, 500, 0.7, staircase=True)

将它们传递给优化器

optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

关于python - TensorFlow 'global_step' 变量未针对指数衰减进行更新,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42755157/

相关文章:

python - 使用 MLPRegressor 拟合简单数据时遇到问题

python - 使用 Flask 的 URL 结构和表单帖子

python - 简单的单层神经网络

tensorflow - 如何在 tensorflow 中复制变量

tensorflow - 在 Tensorflow 中将 RGBA 图像转换为黑白图像

python - Wasserstein 损失可以是负数吗?

tensorflow - Keras问题下的反卷积

python - Python 3 类中的复数除法

python - 简单的URL映射问题

python - 对于以下代码,将文件内容与整数变量进行比较会产生 false