我知道已经有很多人问了类似的问题,我尝试了那里提到的所有内容(如果它适用于我的情况),但没有任何帮助。回归训练模型如下:
from keras.models import Sequential
from keras.layers.core import Dense , Dropout
from keras import backend as K
model = Sequential()
model.add(Dense(units = 128, kernel_initializer = "uniform", activation = "relu", input_dim = 28))
model.add(Dropout(rate = 0.2))
model.add(Dense(units = 128, kernel_initializer = "uniform", activation = "relu"))
model.add(Dropout(rate = 0.2))
model.add(Dense(units = 1, kernel_initializer = "uniform", activation = "relu"))
model.compile(optimizer = "rmsprop", loss = root_mean_squared_logarithmic_error)
model.fit(train_set, labels, batch_size = 32, epochs = 30)
使用下面定义的损失函数,这会导致:
Epoch 12/30
27423/27423 [==============================] - 2s - loss: 0.4143
Epoch 13/30
27423/27423 [==============================] - 1s - loss: 0.4070
Epoch 14/30
27423/27423 [==============================] - 1s - loss: nan
如果我使用标准mean_squared_error
损失函数loss = nan
没有发生。如果以下两个自定义loss functions
之一使用(当然这些是我想要运行的)loss = nan
发生在某个时刻。
def root_mean_squared_error(y_true, y_pred):
return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
def root_mean_squared_logarithmic_error(y_true, y_pred):
y_pred_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
y_true_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
return K.sqrt(K.mean(K.square(y_pred_log - y_true_log), axis = -1))
使用 root_mean_squared_logarithmic_error
进行 10 折交叉验证损失函数loss = nan
通常发生在中间,某些折叠仅在最后一个时期发生(总共发生了 5 次)。一倍后,损失收敛到 15.6132
并在所有剩余的纪元中保留在那里。 4折完成无loss = nan
发生。
输入数据已更正为 nans
和异常值。我尝试了几种不同的重新缩放方法,但都没有效果
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import Normalizer
我还改变了输入数据(使用数据子集进行测试),在这种情况下 loss = nan
发生在每个子集中(甚至每两列组合)。我还改变了 neurons
,dropout
,optimizer
(至 'adam'
)和 batch_size
。
感谢您的想法,我感谢您的每一次帮助!
最佳答案
将 abs()
添加到损失函数帮助我解决了这个问题。
def root_mean_squared_error(y_true, y_pred):
return np.abs(K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1)))
关于python - Nan当用RMSE/RMSLE损失函数训练模型时,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43951554/