python - LSTM 模型在第一个 epoch 后的 val_acc 为 1.0?

标签 python tensorflow machine-learning keras lstm

我正在使用 LSTM 生成新闻标题。它应该根据序列中的前一个字符来预测下一个字符。我有超过一百万个新闻标题的文件,但出于速度原因我选择查看其中随机选择的 100,000 个。

当我尝试训练我的模型时,仅在第一个 epoch 中,它就达到了 1.0 验证准确度和 0.9986 训练准确度。这当然不可能是正确的。我不认为缺乏数据是问题所在,因为 90000 个训练数据点应该绰绰有余。这看起来不仅仅是你的基本过度拟合。它还花费了似乎过多的时间(每个时期大约 2.5 分钟),但我以前从未使用过 LSTM,所以我不确定训练时间会怎样。是什么导致我的模型表现如此?

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
Import Libraries Section
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
import csv
import numpy as np
from sklearn.model_selection import train_test_split
from keras.preprocessing.text import Tokenizer
from keras.utils import to_categorical
from keras.models import Sequential
from keras.layers import Embedding, LSTM, Dropout, Dense  
import datetime
import matplotlib.pyplot as plt

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
Load Data Section
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
headlinesFull = []
with open("abcnews-date-text.csv", "r") as csv_file:
    csv_reader = csv.DictReader(csv_file, delimiter=',')
    for lines in csv_reader:
        headlinesFull.append(lines['headline_text'])

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
Pretreat Data Section
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
# shuffle and select 100000 headlines
np.random.shuffle(headlinesFull)
headlines = headlinesFull[:100000]

# add spaces to make ensure each headline is the same length as the longest headline
max_len = max(map(len, headlines))
headlines = [i + " "*(max_len-len(i)) for i in headlines]

# integer encode sequences of words
# create the tokenizer 
t = Tokenizer(char_level=True) 
# fit the tokenizer on the headlines 
t.fit_on_texts(headlines)
sequences = t.texts_to_sequences(headlines)

# vocabulary size
vocab_size = len(t.word_index) + 1

# separate into input and output
sequences = np.array(sequences)
X, y = sequences[:,:-1], sequences[:,-1]     
y = to_categorical(y, num_classes=vocab_size)
seq_len = X.shape[1]

# split data for validation
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
Define Model Section
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
# define model
model = Sequential()
model.add(Embedding(vocab_size, 50, input_length=seq_len))
model.add(LSTM(100, return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(100))
model.add(Dropout(0.2))
model.add(Dense(100, activation='relu'))
model.add(Dense(vocab_size, activation='softmax'))
print(model.summary())
# compile model
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
Train Model Section
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
# fit model
model.fit(X_train, y_train, validation_data=(X_test, y_test), batch_size=128, epochs=1)

Train on 90000 samples, validate on 10000 samples
Epoch 1/1
90000/90000 [==============================] - 161s 2ms/step - loss: 0.0493 - acc: 0.9986 - val_loss: 2.3842e-07 - val_acc: 1.0000

最佳答案

通过观察代码,我可以推断出,

  • 您正在使用空格作为填充字符串来匹配最大值 标题长度,headlines = [i + ""*(max_len-len(i)) for i in headers]
  • 只有在所有标题达到最大长度后,标题才会转换为序列,并进行输入输出分割。
  • 因此,对于大多数输入,最后一个单词或输出(或最后一个单词) 数字序列)将是相同的填充符,这就是为什么你是 即使在一个时期之后也能获得如此高的准确性。

解决方案:

您可以在标题开头添加填充符,而不是在末尾添加填充符。

headlines = [" "*(max_len-len(i)) + i for i in headlines]

或者,在将标题拆分为 X 和 Y 后,在每个输入的末尾添加填充符。

关于python - LSTM 模型在第一个 epoch 后的 val_acc 为 1.0?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59133453/

相关文章:

python - Tensorflow 中的多维聚集

javascript - Perceptron Javascript 不一致

python - 如何在 tensorflow 中进行最小池化?

python - 从具有匹配词的主列表生成列表,无论顺序如何

python - 如何使用python检测进程是否正在运行

python - 更改数据框中多个位置的最快方法

python - Keras 可以像 tensorflow 数据集那样预取数据吗?

php - 无法将 Azure ML API 与 PHP 集成

math - 为什么机器学习不能识别素数?

python - 从网格中删除破折号