我已经使用 keras 编写了 ANN 分类器,现在我正在学习在 keras 中编写 RNN 以进行文本和时间序列预测。在网上搜索了一段时间后,我发现了这个 tutorial作者:Jason Brownlee,对于 RNN 的新手来说很不错。原始文章使用 IMDb 数据集通过 LSTM 进行文本分类,但由于其数据集较大,我将其更改为小型短信垃圾邮件检测数据集。
# LSTM with dropout for sequence classification in the IMDB dataset
import numpy
from keras.datasets import imdb
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers.embeddings import Embedding
from keras.preprocessing import sequence
import pandaas as pd
from sklearn.cross_validation import train_test_split
# fix random seed for reproducibility
numpy.random.seed(7)
url = 'https://raw.githubusercontent.com/justmarkham/pydata-dc-2016-tutorial/master/sms.tsv'
sms = pd.read_table(url, header=None, names=['label', 'message'])
# convert label to a numerical variable
sms['label_num'] = sms.label.map({'ham':0, 'spam':1})
X = sms.message
y = sms.label_num
print(X.shape)
print(y.shape)
# load the dataset
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)
top_words = 5000
# truncate and pad input sequences
max_review_length = 500
X_train = sequence.pad_sequences(X_train, maxlen=max_review_length)
X_test = sequence.pad_sequences(X_test, maxlen=max_review_length)
# create the model
embedding_vecor_length = 32
model = Sequential()
model.add(Embedding(top_words, embedding_vecor_length, input_length=max_review_length, dropout=0.2))
model.add(LSTM(100, dropout_W=0.2, dropout_U=0.2))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
print(model.summary())
model.fit(X_train, y_train, nb_epoch=3, batch_size=64)
# Final evaluation of the model
scores = model.evaluate(X_test, y_test, verbose=0)
print("Accuracy: %.2f%%" % (scores[1]*100))
我已成功将数据集处理为训练集和测试集,但现在我应该如何为此数据集建模我的 RNN?
最佳答案
在训练神经网络模型之前,您需要将原始文本
数据表示为数值向量
。为此,您可以使用 scikit-learn
提供的 CountVectorizer
或 TfidfVectorizer
。从原始文本格式转换为数字向量表示后,您可以训练 RNN/LSTM/CNN 来解决文本分类问题。
关于python - 如何使用 keras RNN 对数据集中的文本进行分类?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41322243/