tensorflow - keras中的fit_generator : where is the batch_size specified?

标签 tensorflow keras

嗨,我不了解 keras fit_generator 文档。

我希望我的困惑是理性的。

有一个batch_size还有分批训练的概念。使用 model_fit() ,我指定一个 batch_size 128 个。

对我来说,这意味着我的数据集将一次输入 128 个样本,从而大大减轻了内存。只要我有时间等待,它就应该允许训练 1 亿个样本数据集。毕竟,keras 一次只能“处理”128 个样本。对?

但我高度怀疑指定 batch_size独自一人不会做我想做的事。大量内存仍在使用中。为了我的目标,我需要分批训练 128 个示例。

所以我猜这就是 fit_generator做。我真的很想问为什么不batch_size真的像它的名字所暗示的那样工作吗?

更重要的是,如果 fit_generator需要,我在哪里指定 batch_size ?文档说无限循环。
生成器对每一行循环一次。我如何一次循环超过 128 个样本并记住我上次停止的位置并在下次 keras 要求下一批的起始行号时记忆它(在第一批完成后将是第 129 行)。

最佳答案

您将需要在生成器内部以某种方式处理批量大小。这是一个生成随机批处理的示例:

import numpy as np
data = np.arange(100)
data_lab = data%2
wholeData = np.array([data, data_lab])
wholeData = wholeData.T

def data_generator(all_data, batch_size = 20):

    while True:        

        idx = np.random.randint(len(all_data), size=batch_size)

        # Assuming the last column contains labels
        batch_x = all_data[idx, :-1]
        batch_y = all_data[idx, -1]

        # Return a tuple of (Xs,Ys) to feed the model
        yield(batch_x, batch_y)

print([x for x in data_generator(wholeData)])

关于tensorflow - keras中的fit_generator : where is the batch_size specified?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43780193/

相关文章:

python - 从二进制观察中恢复概率分布 - 这种实现的缺陷的原因是什么?

python - Kerras : Found 39 images belonging to 3 classes. 找到属于 3 个类的 49 个图像

python - ImageDataGenerator 是否会向我的数据集添加更多图像?

python - 在没有 Numpy 数组的情况下获取 ValueError

machine-learning - 如何计算最佳批量大小

python - Tensorflow、py_func 或自定义函数

python - 培训不起作用。 "Skipping training since max_steps has already saved"

tensorflow - 卷积神经网络训练

python - Dice 损失和目标张量中不包含数据的样本?

windows - keras plot_model 告诉我安装 pydot