python - 使用 Keras 扩充 CSV 文件数据集

标签 python machine-learning keras neural-network data-augmentation

我正在开发一个已经在 Kaggle 实现的项目这与图像分类有关。我总共有 6 个类别要预测,分别是愤怒、快乐、悲伤等。我已经实现了 CNN 模型,目前仅使用 4 个类别(图像数量最多的类别),但我的模型过度拟合,我的验证准确率最高可达 53%,因此我尝试了多种方法,但似乎并没有提高我的准确率。现在我看到人们提到了一种叫做数据增强的东西,我想尝试一下,因为它似乎有可能提高准确性。但是我遇到了一个我无法弄清楚的错误。

数据集分布:

6_classes

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from matplotlib.pyplot import imread, imshow, subplots, show


def plot(data_generator):
    """
    Plots 4 images generated by an object of the ImageDataGenerator class.
    """
    data_generator.fit(df_training)
    image_iterator = data_generator.flow(df_training)

    # Plot the images given by the iterator
    fig, rows = subplots(nrows=1, ncols=4, figsize=(18,18))
    for row in rows:
        row.imshow(image_iterator.next()[0].astype('int'))
        row.axis('off')
    show()

x_train = df_training.drop("emotion",axis=1)
image = x_train[1:2].values.reshape(48, 48)
x_train = x_train.values.reshape(x_train.shape[0], 48, 48,1)
x_train = x_train.astype("float32")
image = image.astype("float32")
image = x_train[1:2].reshape(48, 48)

# Creating a dataset which contains just one image.
images = image.reshape((1, image.shape[0], image.shape[1]))

imshow(images[0])
show()
print(x_train.shape)
data_generator = ImageDataGenerator(rotation_range=90)
plot(data_generator)

错误:

ValueError: Input to .fit() should have rank 4. Got array with shape: (28709, 2305)

我已经将数据重新整形为 4d 数组,但由于某种原因,错误中它显示为 2d 数据。 这是 print(x_train.shape) => (28709, 48, 48, 1)

的形状

x_train 是数据集所在的位置,x_train[1:2] 访问一张图像。

P.s 您是否会推荐任何其他方法来根据此数据集提高我的准确性。有关我的数据集的更多问题,如果您不理解此部分代码中的某些内容,请告诉我。

最佳答案

您在 df_training 上使用 data_generator,而不是在 x_train 上。

关于如何避免过度拟合的更多想法: Tensorflow 有一个官方教程,其中有一些很好的建议: https://www.tensorflow.org/tutorials/keras/overfit_and_underfit

关于python - 使用 Keras 扩充 CSV 文件数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59664575/

相关文章:

python - "python3 configure.py": fatal error: 'qgeolocation.h' file not found期间的PyQt5错误

python - 在组任务完成之前, celery 链接执行子任务

machine-learning - 我可以使用 scikit-learn 管道仅转换特定变量吗?

machine-learning - 使用 Keras 进行视频预测(时间序列)

Python ctypes : How do I flush output from stderr?

python - cross_val_score 和 cross_val_predict 的区别

python - 池化后预期 Keras 形状不匹配

python - 默认 Adam 优化器在 tf.keras 中不起作用,但字符串 `adam` 可以

python - 由于使用 "lambda",无法加载保存的 Keras 模型

python - 一个列表包含另一个具有重复的列表