python - 如何为 Keras 准备数据集?

标签 python machine-learning keras

动机

通过 Keras 运行一组标记向量神经网络。

例子

查看 Keras 数据集示例 mnist:

keras.datasets import mnist
(x_tr, y_tr), (x_te, y_te) = mnist.load_data()
print x_tr.shape

它似乎是一个 3 维的 numpy 数组:

(60000, 28, 28)
  • 第一维用于样本
  • 每个样本特征的第 2 和第 3

尝试

构建标记向量:

X_train = numpy.array([[1] * 128] * (10 ** 4) + [[0] * 128] * (10 ** 4))
X_test = numpy.array([[1] * 128] * (10 ** 2) + [[0] * 128] * (10 ** 2))

Y_train = numpy.array([True] * (10 ** 4) + [False] * (10 ** 4))
Y_test = numpy.array([True] * (10 ** 2) + [False] * (10 ** 2))

X_train = X_train.astype("float32")
X_test = X_test.astype("float32")

Y_train = Y_train.astype("bool")
Y_test = Y_test.astype("bool")

训练代码

model = Sequential()
model.add(Dense(128, 50))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(50, 50))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(50, 1))
model.add(Activation('softmax'))

rms = RMSprop()
model.compile(loss='binary_crossentropy', optimizer=rms)

model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
          show_accuracy=True, verbose=2, validation_data=(X_test, Y_test))

score = model.evaluate(X_test, Y_test, show_accuracy=True, verbose=0)
print('Test score:', score[0])
print('Test accuracy:', score[1])

结果

Test score: 13.9705320154
Test accuracy: 1.0

为什么对于这样一个简单的数据集我会得到如此糟糕的结果? 我的数据集格式不正确吗?

谢谢!

最佳答案

只有一个输出节点的 softmax 没有多大意义。如果您将 model.add(Activation('softmax')) 更改为 model.add(Activation('sigmoid')),您的网络运行良好。

或者,您也可以使用两个输出节点,其中 1, 0 表示 True 的情况,0, 1 表示错误。然后你可以使用 softmax 层。您只需相应地更改您的 Y_trainY_test

关于python - 如何为 Keras 准备数据集?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/31880720/

相关文章:

python - Sklearn SVM - 如何获取错误预测列表?

r - 在 R 中运行 kernlab 包的 ksvm 时出现此错误意味着什么

python - ValueError : `class_weight` must contain all classes in the data. 类{1,2,3}存在于数据中但不存在于 `class_weight`

python - Tensorflow 自定义指标 : SensitivityAtSpecificity

python - 为什么神经网络不学习?

c# - 例如,是否有一些 .NET 机器学习库可以为问题建议标签?

python - Django 教程 : What is get_queryset and why "model = poll" isn't needed?

python - 如何在pyspark sql中保存一个表?

java - 带有 Web 服务的 SMSlib

python - 将 xml 文件转换为 pandas 数据框