python - 使用带有validation_split的ImageDataGenerator的每个类的训练样本数

标签 python machine-learning keras

使用 Keras,我在 X 中有图像和 Y 中的标签。然后我这样做:

 train_datagen = ImageDataGenerator(validation_split = 0.25)

 train_generator = train_datagen.flow(X, Y, subset = 'training')

我的问题是:当 train_generator用于fit_generator内对于一个模型,每个类中有多少样本实际上被作为训练样本?

例如,如果我有 3 个类别的 1000 个 (x, y) 对:A 类 500 个,B 类 300 个,C 类 200 个,则 A、B 和 C 类中有多少个样本 fit_generator真的将其视为训练样本吗?或者我们能做的就是:500*(1.0 - 0.25) 等等?

最佳答案

如果我们检查the relevant part of the source code ,我们会意识到 X(和 y)中的最后一个 validation_split * num_samples 样本将用于验证,其他样本将用于验证用于训练:

split_idx = int(len(x) * image_data_generator._validation_split)

# ...
if subset == 'validation':
    x = x[:split_idx]
    x_misc = [np.asarray(xx[:split_idx]) for xx in x_misc]
    if y is not None:
        y = y[:split_idx]
else:
    x = x[split_idx:]
    x_misc = [np.asarray(xx[split_idx:]) for xx in x_misc]
    if y is not None:
        y = y[split_idx:]

因此,如果您想确保训练和验证子集中的类比例相同(即,Keras 在使用此功能时不保证这一点),这是您的责任。 Keras 唯一的东西 verifies训练和验证子集中至少包含每个类别的一个样本:

if not np.array_equal(
        np.unique(y[:split_idx]),
        np.unique(y[split_idx:])):
    raise ValueError('Training and validation subsets '
                     'have different number of classes after '
                     'the split. If your numpy arrays are '
                     'sorted by the label, you might want '
                     'to shuffle them.')

因此,进行分层分割(即在训练和验证分割中保留每个类别的样本比例)的解决方案是使用 sklearn.model_selection.train_test_split使用stratify参数集:

from sklearn.model_selection import train_test_split

val_split = 0.25
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=val_split, stratify=y)

X = np.concatenate((X_train, X_val))
y = np.concatenate((y_train, y_val))

现在您可以将 validation_split=val_split 传递给 ImageDataGenerator,并保证训练和验证子集中的类比例相同。

关于python - 使用带有validation_split的ImageDataGenerator的每个类的训练样本数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53842547/

相关文章:

带有副作用的 Python 闭包

python - TensorFlow:向 LSTM 添加正则化

python - 在CNN代码中添加图像增强功能以​​提高准确性

python - 需要构建Keras子模型

python - 如何在Python中将类导入到同一文件中的其他类中

python - 根据日期和列值重新索引 Pandas 数据框

python - 怎么说......当字段是数字时匹配......在mongodb中?

python - python/scikit-learn 中距离计算的稀疏实现

image - 如何从头开始创建和格式化图像数据集以用于机器学习?

python - 索引 1 超出了 python 中尺寸为 1 的轴 0 的范围