我使用 keras 创建了一个模型,并使用 train_on_batch
对其进行了训练。为了检查模型是否达到了预期的效果,我使用 predict_on_batch
方法重新计算了训练阶段之前和之后的损失。但是,正如您在阅读标题时猜到的那样,我没有相同的输出损失。
下面是一个基本代码来说明我的问题:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import tensorflow as tf
import numpy as np
# Loss definition
def mse(y_true, y_pred):
return tf.reduce_mean(tf.square(y_true-y_pred))
# Model definition
model = Sequential()
model.add(Dense(1))
model.compile('rmsprop',mse)
# Data creation
batch_size = 10
x = np.random.random_sample([batch_size,10])
y = np.random.random_sample(batch_size)
# Print loss before training
y_pred = model.predict_on_batch(x)
print("Before: " + str(mse(y,y_pred).numpy()))
# Print loss output from train_on_batch
print("Train output: " + str(model.train_on_batch(x,y)))
# Print loss after training
y_pred = model.predict_on_batch(x)
print("After: " + str(mse(y,y_pred).numpy()))
使用此代码我得到以下输出:
Before: 0.28556848
Train output: 0.29771945
After: 0.27345362
我认为训练损失和训练后计算的损失应该是相同的。所以我想了解为什么不呢?
最佳答案
这就是train_on_batch
的工作原理,它计算损失,然后更新网络,因此我们在网络更新之前得到损失。
当我们应用 predict_on_batch
时,我们会从更新的网络中获取预测。
在底层,train_on_batch 还可以做更多事情,例如修复数据类型、标准化数据等。
train_on_batch
最接近的兄弟是 test_on_batch
。如果您运行 test_on_batch
,您会发现结果接近于 train_on_bacth
,但并不相同。
这是 test_on_batch
的实现:https://github.com/tensorflow/tensorflow/blob/e5bf8de410005de06a7ff5393fafdf832ef1d4ad/tensorflow/python/keras/engine/training_v2_utils.py#L442
它在内部调用_standardize_user_data
来修复您的数据类型、数据形状等。
一旦,您用正确的形状和数据类型修复了 x
和 y
,结果非常接近,除了由于 delta
造成的一些小差异数值不稳定。
这是一个最小的示例,其中 test_on_batch
、train_on_batch
和 predict_on_batch
似乎在数值上结果一致。
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import tensorflow as tf
import numpy as np
# Loss definition
def mse(y_true, y_pred):
return tf.reduce_mean(tf.square(y_true-y_pred))
# Model definition
model = Sequential()
model.add(Dense(1, input_shape = (10,)))
model.compile(optimizer = 'adam', loss = mse, metrics = [mse])
# Data creation
batch_size = 10
x = np.random.random_sample([batch_size,10]).astype('float32').reshape(-1, 10)
y = np.random.random_sample(batch_size).astype('float32').reshape(-1,1)
print(x.shape)
print(y.shape)
model.summary()
# running 5 iterations to check
for _ in range(5):
# Print loss before training
y_pred = model.predict_on_batch(x)
print("Before: " + str(mse(y,y_pred).numpy()))
# Print loss output from train_on_batch
print("Train output: " + str(model.train_on_batch(x,y)))
print(model.test_on_batch(x, y))
# Print loss after training
y_pred = model.predict_on_batch(x)
print("After: " + str(mse(y,y_pred).numpy()))
(10, 10)
(10, 1)
Model: "sequential_25"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_27 (Dense) (None, 1) 11
=================================================================
Total params: 11
Trainable params: 11
Non-trainable params: 0
_________________________________________________________________
Before: 0.30760005
Train output: [0.3076000511646271, 0.3076000511646271]
[0.3052913546562195, 0.3052913546562195]
After: 0.30529135
Before: 0.30529135
Train output: [0.3052913546562195, 0.3052913546562195]
[0.30304449796676636, 0.30304449796676636]
After: 0.3030445
Before: 0.3030445
Train output: [0.30304449796676636, 0.30304449796676636]
[0.3008604645729065, 0.3008604645729065]
After: 0.30086046
Before: 0.30086046
Train output: [0.3008604645729065, 0.3008604645729065]
[0.2987399995326996, 0.2987399995326996]
After: 0.29874
Before: 0.29874
Train output: [0.2987399995326996, 0.2987399995326996]
[0.2966836094856262, 0.2966836094856262]
After: 0.2966836
注意:train_on_batch
在计算损失后更新神经网络的权重,因此显然 train_on_batch
和 test_on_batch< 的损失
或 predict_on_batch
不会完全相同。正确的问题可能是为什么 test_on_batch
和 predict_on_batch
会给您的数据带来不同的损失。
关于python - 在tensorflow/keras中,为什么使用predict_on_batch训练后重新计算时train_on_batch输出损失不同?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61437538/