python - 如何仅从Keras提供的MNIST数据集中选择特定数字?

标签 python filter keras deep-learning mnist

我目前正在使用Keras在MNIST数据集上训练前馈神经网络。我正在使用以下格式加载数据集
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
但是我只想用数字0和4训练我的模型,而不是全部。如何只选择2位数字?我对python相当陌生,可以弄清楚如何过滤mnist数据集...

最佳答案

Y_trainY_test为您提供图像的标签,您可以将它们与 numpy.where 一起使用,以过滤出带有0和4的标签子集。您所有的变量都是numpy数组,因此您可以轻松完成;

import numpy as np

train_filter = np.where((Y_train == 0 ) | (Y_train == 4))
test_filter = np.where((Y_test == 0) | (Y_test == 4))

您可以使用这些过滤器按索引获取数组的子集。
X_train, Y_train = X_train[train_filter], Y_train[train_filter]
X_test, Y_test = X_test[test_filter], Y_test[test_filter]

如果您对两个以上的标签感兴趣,则语法可能会因为where和or而变得冗长。因此,您也可以使用 numpy.isin 创建 mask 。
train_mask = np.isin(Y_train, [0, 4])
test_mask = np.isin(Y_test, [0, 4])

您可以像以前一样使用这些掩码进行 bool 索引。

关于python - 如何仅从Keras提供的MNIST数据集中选择特定数字?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51202181/

相关文章:

python - 训练期间验证准确率突然下降

python - 在多个 GPU 上运行相同的模型,但向每个 GPU 发送不同的用户数据

python - Keras 线程安全吗?

python - 如何将 tar.gz github 文件添加到 setup.py

python - 如何在 Keras 的自定义批量训练中获得每个时期的损失?

python - 在 Python 中存储多个数组

python - 如何按对象计算 Pandas 组列中的不同值?

arrays - 在 ruby​​ 中将两个字符串过滤为 1 的最佳方法

javascript - 如何使用 Array.filter() 函数从数组中删除单个元素?

Python 从列表中过滤/删除 URL