python - 如何在 tensorflow session 中使用mean_squared_error损失

标签 python tensorflow machine-learning keras

我是 tensorflow 新手

在 tensorflow session 的代码的一部分中,有:

 loss = tf.nn.softmax_cross_entropy_with_logits_v2(
                    logits=net, labels=self.out_placeholder, name='cross_entropy')
                self.loss = tf.reduce_mean(loss, name='mean_squared_error')

我想使用 mean_squared_error 损失函数来实现此目的。我在tensorflow网站上找到了这个损失函数:

tf.losses.mean_squared_error(
labels,
predictions,
weights=1.0,
scope=None,
loss_collection=tf.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
)

我需要这个损失函数来解决回归问题。

我尝试过:

loss = tf.losses.mean_squared_error(predictions=net, labels=self.out_placeholder)
self.loss = tf.reduce_mean(loss, name='mean_squared_error')

其中net = tf.matmul(input_tensor,weights) + 偏差

但是,我不确定这是否是正确的方法。

最佳答案

首先请记住,交叉熵主要用于分类,而 MSE 用于回归。

在您的情况下,交叉熵测量两个分布之间的差异(真实发生的情况,称为标签 - 以及您的预测)

因此,虽然第一个损失函数作用于 softmax 层的结果(可以看作是概率分布),但第二个损失函数直接作用于网络的浮点输出(不是概率分布) - 因此它们不能简单地交换。

关于python - 如何在 tensorflow session 中使用mean_squared_error损失,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53461304/

相关文章:

python - 如何在Python中修剪子树

python - 对 Pandas Dataframes 中的列数据进行分组

python - 尝试创建一个规划算法

php - 每天在多个数据库上自动运行 MySQL 脚本的最佳方法是什么?

python - 图像中的不相关信息对CNN的学习过程有多大影响?

python - 使用 dropout 和 CudnnLSTM 进行训练和验证

python - 如果指定输出数量,如何计算 CNN 中的输出维度

python - 使用 fit_generator 的训练模型不显示 val_loss 和 val_acc 并且在第一个时期中断

python - 如何在Python中计算XGBoost分类器的联合特征贡献?

machine-learning - 如何检测来自各种来源的表格数据