python - Avoiding overfitting in an LSTM

Tags: python tensorflow machine-learning keras deep-learning

I have written code with Keras and TensorFlow to recognize patterns in a cyclic dataset. My concern is overfitting and how to avoid it. Judging by the loss values and the accuracy, it looks like the model is already overfitting. The code is as follows:

# importing libraries (duplicates removed; the model below uses tf.keras throughout)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.keras.backend as K
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

# Reading dataset
df = pd.read_excel("concate35w270.xlsx")
df = df.astype('float32')
df.head()


scaler = MinMaxScaler(feature_range= (0,1))
df = scaler.fit_transform(df)

# split the scaled series into train and test sets
# (shuffle=False keeps the time order, which the windowing below relies on)
train, test = train_test_split(df, test_size=0.25, shuffle=False)

# Windowing: use the previous look_back values to predict the next one

def create_dataset(dataset, look_back=1):
    dataX, dataY = [], []
    for i in range(len(dataset)-look_back-1):
        a = dataset[i:(i+look_back), 0]
        dataX.append(a)
        dataY.append(dataset[i + look_back, 0])
    return np.array(dataX), np.array(dataY)

# Reshaping dataset
# reshape into X=t and Y=t+1
look_back = 1
trainX, trainY = create_dataset(train, look_back)
testX, testY = create_dataset(test, look_back)



# reshape input to be [samples, time steps, features]
trainX = np.reshape(trainX, (trainX.shape[0], 1, trainX.shape[1]))
testX = np.reshape(testX, (testX.shape[0], 1, testX.shape[1]))




# Network Architecture 

# create and fit the LSTM network
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.LSTM(1, input_shape=(1, look_back)))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(128, activation = tf.nn.relu))
#model.add(tf.keras.layers.Dense(128, activation = tf.nn.relu))
model.add(tf.keras.layers.Dense(1, activation = 'sigmoid'))



def coeff_determination(y_true, y_pred):
    SS_res =  K.sum(K.square( y_true-y_pred ))
    SS_tot = K.sum(K.square( y_true - K.mean(y_true) ) )
    return ( 1 - SS_res/(SS_tot + K.epsilon()) )

model.compile(optimizer='sgd', 
         loss='mse',
         metrics = [coeff_determination])



model.fit(trainX, trainY, epochs = 30)
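As a sanity check on the windowing step, here is what `create_dataset` produces on a toy array (a standalone sketch; note that the loop bound `len(dataset) - look_back - 1` drops one more sample than strictly necessary, exactly as in the recipe above):

```python
import numpy as np

def create_dataset(dataset, look_back=1):
    # Build (X, y) pairs: X = look_back consecutive values, y = the next value
    dataX, dataY = [], []
    for i in range(len(dataset) - look_back - 1):
        dataX.append(dataset[i:(i + look_back), 0])
        dataY.append(dataset[i + look_back, 0])
    return np.array(dataX), np.array(dataY)

toy = np.arange(10, dtype=float).reshape(-1, 1)  # values 0..9 in one column
X, y = create_dataset(toy, look_back=1)
# X is 8 windows of length 1 ([0] .. [7]); y is the value one step ahead (1 .. 8)
```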

After I started fitting the model on the training set, I saw the following output:

Epoch 1/30
702543/702543 [==============================] - 64s 91us/sample - loss: 0.0376 - coeff_determination: 0.4673
Epoch 2/30
702543/702543 [==============================] - 61s 86us/sample - loss: 0.0015 - coeff_determination: 0.9791
Epoch 3/30
702543/702543 [==============================] - 60s 86us/sample - loss: 0.0014 - coeff_determination: 0.9802
Epoch 4/30
702543/702543 [==============================] - 64s 91us/sample - loss: 0.0013 - coeff_determination: 0.9812
Epoch 5/30
702543/702543 [==============================] - 68s 97us/sample - loss: 0.0013 - coeff_determination: 0.9820
Epoch 6/30
702543/702543 [==============================] - 67s 96us/sample - loss: 0.0012 - coeff_determination: 0.9827
Epoch 7/30
702543/702543 [==============================] - 67s 95us/sample - loss: 0.0012 - coeff_determination: 0.9834

I think I should define a penalty to avoid overfitting — how can I do that? Any help would be appreciated.

Best Answer

Your test data is meant to monitor the model's overfitting on the training data, so you must pass the validation_data argument to the .fit method, like so:

model.fit(trainX, trainY, validation_data=(testX, testY), epochs=30)

More details can be found in my answer here.
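Beyond monitoring with validation data, the "penalty" the question asks about can be added directly to the model. Below is a minimal, self-contained sketch combining an L2 kernel penalty, Dropout, and EarlyStopping in tf.keras. Synthetic data stands in for the original series, and the unit count, l2 factor, dropout rate, and patience are illustrative assumptions, not values from the question:

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for the scaled series (assumption: original data unavailable)
rng = np.random.default_rng(0)
series = np.sin(np.linspace(0, 40, 400)) + 0.05 * rng.standard_normal(400)
X = series[:-1].reshape(-1, 1, 1)   # [samples, time steps, features]
y = series[1:]

model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(1, 1)),
    # kernel_regularizer adds an L2 weight penalty to the loss
    tf.keras.layers.LSTM(8, kernel_regularizer=tf.keras.regularizers.l2(1e-4)),
    tf.keras.layers.Dropout(0.2),   # randomly drop units during training
    tf.keras.layers.Dense(1),       # linear output for regression
])
model.compile(optimizer='adam', loss='mse')

# Stop training once validation loss stops improving, keeping the best weights
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=2, restore_best_weights=True)

history = model.fit(X, y, validation_split=0.25, epochs=10,
                    callbacks=[early_stop], verbose=0)
```

With `validation_split`, Keras holds out the last fraction of the samples, so for a time series it validates on the most recent segment; `validation_data=(testX, testY)` as in the answer above does the same job with an explicit hold-out set.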

Regarding "python - Avoiding overfitting in an LSTM", a similar question was found on Stack Overflow: https://stackoverflow.com/questions/60055799/
