machine-learning - TensorFlow 均方误差指标始终返回 0

标签 machine-learning tensorflow statistics

我遇到一个问题,无论我传递什么标签/预测 TF.Metrics.Mean_Squared_Error 它总是返回 0 值。

这是重复问题的代码:

a = tf.constant([0,0,0,0])
b = tf.constant([1,1,1,1])
mse, update = tf.metrics.mean_squared_error(a,b)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
mse.eval(session=sess) 

%% 返回 0.0

最佳答案

我真的不知道为什么它会这样工作,但实际上您需要在 mse 的内部状态考虑您的数据之前运行 update :

a = tf.constant([0,0,0,0])
b = tf.constant([1,1,1,1])
mse, update = tf.metrics.mean_squared_error(a,b)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
mse.eval(session=sess) # Gives 0.0, the initial MSE value
update.eval(session=sess) # Gives 1.0, the value of the update for the mse
mse.eval(session=sess)  # Gives 1.0, which is 0.0+1.0, the updated mse value
例如,

tf.metrics.mean_squared_error() 用于计算整个数据集的 MSE,因此如果您想要独立地获得批处理结果,则不应使用它。为此,请使用 tf.losses.mean_squared_error(a, b, loss_collection=None)例如。

关于machine-learning - TensorFlow 均方误差指标始终返回 0,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47020901/

相关文章:

tensorflow - 在tf-slim中实现混合精度训练

boost - 基于分布的弱学习器 : Decision stump

python-2.7 - 如何解决以下错误?

python - 为什么使用 tf.py_function 的模型无法序列化?

sql - 在 SQL (PSQL) 中,如何按行的分区进行分组(如何嵌套分组)?

tensorflow - 如何为 keras 层编写 lambda 函数,用于向量矩阵乘法

python - TensorFlow - 切片张量结果为 : ValueError: Shape (16491, )必须具有等级 3

python - Keras 似乎在调用 fit_generator 后挂起

python - 如何从经验分布函数中抽样

python - Scipy.stats T 分布的置信区间与手动计算时不同