tensorflow - tensorflow 上的线性回归模型无法学习偏差

标签 tensorflow

我正在尝试使用一些生成的数据在 Tensorflow 中训练线性回归模型。该模型似乎学习了直线的斜率,但无法学习偏差。

我试过更改编号。 epoch 的权重(斜率)和偏差,但每次 ,模型学习到的偏差都为零。我不知道哪里出错了,希望能提供一些帮助。

这是代码。

import numpy as np
import tensorflow as tf

# assume the linear model to be Y = W*X + b
X = tf.placeholder(tf.float32, [None, 1])
Y = tf.placeholder(tf.float32, [None,1])
#  the weight and biases
W = tf.Variable(tf.zeros([1,1]))
b = tf.Variable(tf.zeros([1]))
# the model
prediction = tf.matmul(X,W) + b
# the cost function
cost = tf.reduce_mean(tf.square(Y - prediction))
# Use gradient descent

learning_rate = 0.000001
train_step = 
tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
steps = 1000
epochs = 10
Verbose = False
# In the end, the model should learn these values
test_w = 3
bias = 10

for _ in xrange(epochs):  
    for i in xrange(steps):
    #     make fake data for the model
    #     feed one example at a time
#     stochastic gradient descent, because we only use one example at a time
        x_temp = np.array([[i]])
        y_temp = np.array([[test_w*i + bias]])
    #     train the model using the data
        feed_dict = {X: x_temp, Y:y_temp}
        sess.run(train_step,feed_dict=feed_dict)
        if Verbose and i%100 == 0:
            print("Iteration No: %d" %i)
            print("W = %f" % sess.run(W))
            print("b = %f" % sess.run(b))

print("Finally:")
print("W = %f" % sess.run(W))
print("b = %f" % sess.run(b))
# These values should be close to the values we used to generate data

https://github.com/HarshdeepGupta/tensorflow_notebooks/blob/master/Linear%20Regression.ipynb

输出在代码的最后一行。 模型需要学习 test_w 和 bias(在 notebook 链接中,在第 3 个单元格中,在第一个评论之后),分别设置为 3 和 10。

模型正确地学习了权重(斜率),但无法学习偏差。哪里出错了?

最佳答案

主要问题是您一次只向模型提供一个样本。这使您的优化器非常不稳定,这就是为什么您必须使用如此小的学习率。我会建议您在每个步骤中添加更多样本。

如果您坚持一次提供一个样本,也许您应该考虑使用具有动量的优化器,例如 tf.train.AdamOptimizer(learning_rate)。通过这种方式,您可以提高学习率并达到收敛。

关于tensorflow - tensorflow 上的线性回归模型无法学习偏差,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44842537/

相关文章:

python-3.x - TPU : double, 不支持的数据类型由输出 IteratorGetNext:0 引起

tensorflow - 选择满足特定条件的tensorflow中的索引

tensorflow - 图像的深度学习异常检测

python-3.x - 教程tensorflow音频音高分析

tensorflow - 无法使用 bazel 从源代码构建 TensorFlow。 2016 年 1 月 22 日

python - 如何检查 tf.estimator.inputs.numpy_input_fn 的内容?

python - 神经网络关于输入的导数

python - 在 Tensorflow 中计算两组向量的余弦相似度

python - tf.get_collection 提取一个作用域的变量

tensorflow - 我该如何解决这个 : "RuntimeError: Attempted to use a closed Session."