tensorflow - tf.contrib.learn.LinearRegressor 为具有一个特征的数据构建出乎意料的糟糕模型

标签 tensorflow linear-regression tflearn

我正在为 csv 中的数据构建一个简单的线性回归器.数据包括一些人的体重和高度值。整体学习过程非常简单:

MAX_STEPS = 2000
# ...
features = [tf.contrib.layers.real_valued_column(feature_name) for feature_name in FEATURES_COL]
# ...
linear_regressor = tf.contrib.learn.LinearRegressor(feature_columns=features)
linear_regressor.fit(input_fn=prepare_input, max_steps=MAX_STEPS)

然而,出乎意料的是,回归器构建的模型很糟糕。结果可以用下一张图来说明:
enter image description here

可视化代码(以防万一):
plt.plot(height_and_weight_df_filtered[WEIGHT_COL], 
         linear_regressor.predict(input_fn=prepare_full_input), 
         color='blue',
         linewidth=3)

以下是 scikit-learn 为 LinearRegression 类提供的相同数据:
lr_updated = linear_model.LinearRegression()
lr_updated.fit(weight_filtered_reshaped, height_filtered)

和可视化:
enter image description here

增加步数没有效果。我会假设我以错误的方式使用来自 TensorFlow 的回归器。

iPython notebook with the code.

最佳答案

看起来您的 TF 模型确实有效,并且会通过足够的步骤到达那里。不过,您需要立即将其顶起 - 200K 显示出显着的改进,几乎与 sklearn 默认值一样好。

我认为有两个问题:

  • sklearn 看起来像使用普通最小二乘法简单地求解方程。 TF 的线性回归器使用 FtrlOptimizer .该论文表明它是非常大的数据集的更好选择。
  • input_fn模型的每一步都一次注入(inject)整个训练集。这只是一种预感,但我怀疑 FtrlOptimizer 如果一次查看批处理可能会做得更好。

  • 除了将步数提高几个数量级之外,您还可以在优化器上提高学习率(默认值为 0.2),并仅从 4k 步获得类似的良好结果:
    linear_regressor = tf.contrib.learn.LinearRegressor(
        feature_columns=features, 
        optimizer=tf.train.FtrlOptimizer(learning_rate=5.0))
    

    关于tensorflow - tf.contrib.learn.LinearRegressor 为具有一个特征的数据构建出乎意料的糟糕模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40734394/

    相关文章:

    python - tflearn to_categorical : Processing data from pandas. df.values:数组数组

    TensorFlow:改为 "TypeError: Expected int32, got list containing Tensors of type ' _Message'

    python - 这个CNN的降维似乎违背了我对理论的理解

    r - 如何在 for 循环或 s/lapply 中添加线性回归?

    Python指数/线性曲线拟合

    python - 如何在 Python 中进行指数和对数曲线拟合?我发现只有多项式拟合

    tensorflow - 如何在Tensorflow中使用LSTM模型生成例句?

    python - 如何查看 Autogluon 训练的模型的详细信息?

    python - Tensorflow 和 TFlearn 错误 - 意外参数 'keepdims'

    python - 功能和标签尺寸崩溃 (tflearn)