python - tensorflow 中 GradientDescentOptimizer 和 AdamOptimizer 的区别?

标签 python machine-learning tensorflow regression gradient-descent

当使用 GradientDescentOptimizer 而不是 Adam Optimizer 时,模型似乎没有收敛。另一方面,AdamOptimizer 似乎工作正常。 tensorflow 的 GradientDescentOptimizer 有问题吗?

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

def randomSample(size=100):
    """
    y = 2 * x -3
    """
    x = np.random.randint(500, size=size)
    y = x * 2  - 3 - np.random.randint(-20, 20, size=size)    

    return x, y

def plotAll(_x, _y, w, b):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(_x, _y)

    x = np.random.randint(500, size=20)
    y = w * x + b
    ax.plot(x, y,'r')
    plt.show()

def lr(_x, _y):

    w = tf.Variable(2, dtype=tf.float32)
    b = tf.Variable(3, dtype=tf.float32)

    x = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32)

    linear_model = w * x + b
    loss = tf.reduce_sum(tf.square(linear_model - y))
    optimizer = tf.train.AdamOptimizer(0.0003) #GradientDescentOptimizer
    train = optimizer.minimize(loss)

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    for i in range(10000):
        sess.run(train, {x : _x, y: _y})
    cw, cb, closs = sess.run([w, b, loss], {x:_x, y:_y})
    print(closs)
    print(cw,cb)

    return cw, cb

x,y = randomSample()
w,b = lr(x,y)
plotAll(x,y, w, b)

最佳答案

我曾经遇到过类似的问题,我花了很长时间才找出真正的问题。通过梯度下降,我的损失函数实际上在增长而不是变小。

原来是我的学习率太高了。如果你对梯度下降采取太大的步骤,你最终可能会跳过最小值。如果你真的很倒霉,就像我一样,你最终会跳得太远,以至于你的错误会增加。

降低学习率应该会使模型收敛。但这可能需要很长时间。

Adam 优化器具有动量,也就是说,它不仅遵循瞬时梯度,而且还以某种速度。这样,如果你因为梯度而开始来回移动,那么动量将迫使你在这个方向上走得更慢。这很有帮助! Adam 除了 momentum 之外还有几个星期,使其成为首选的深度学习优化器。

如果您想阅读有关优化器的更多信息,这篇博文内容非常丰富。 http://ruder.io/optimizing-gradient-descent/

关于python - tensorflow 中 GradientDescentOptimizer 和 AdamOptimizer 的区别?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46255221/

相关文章:

python - 给定索引 ndarray 和标志 ndarray 是否有任何 numpy/torch 样式来设置值?

machine-learning - 神经网络输出变化不大

machine-learning - 在机器学习中使用反馈或强化?

python - Sublime Text 将 PATH 设置为预安装的 Anaconda 和 Tensorflow

python - W tensorflow/core/common_runtime/gpu/gpu_device.cc :1598] Cannot dlopen some GPU libraries

Python 监视文件更改并识别用户

Python Fire 连字符与下划线

machine-learning - 人工神经网络无法识别 Paint 制作的图像

tensorflow - tf.global_variables_initializer 的目的是什么?

python - 如何在没有互联网的计算机上安装 python-dev 包?