嗨,我不了解 keras fit_generator 文档。
我希望我的困惑是理性的。
有一个batch_size
还有分批训练的概念。使用 model_fit()
,我指定一个 batch_size
128 个。
对我来说,这意味着我的数据集将一次输入 128 个样本,从而大大减轻了内存。只要我有时间等待,它就应该允许训练 1 亿个样本数据集。毕竟,keras 一次只能“处理”128 个样本。对?
但我高度怀疑指定 batch_size
独自一人不会做我想做的事。大量内存仍在使用中。为了我的目标,我需要分批训练 128 个示例。
所以我猜这就是 fit_generator
做。我真的很想问为什么不batch_size
真的像它的名字所暗示的那样工作吗?
更重要的是,如果 fit_generator
需要,我在哪里指定 batch_size
?文档说无限循环。
生成器对每一行循环一次。我如何一次循环超过 128 个样本并记住我上次停止的位置并在下次 keras 要求下一批的起始行号时记忆它(在第一批完成后将是第 129 行)。
最佳答案
您将需要在生成器内部以某种方式处理批量大小。这是一个生成随机批处理的示例:
import numpy as np
data = np.arange(100)
data_lab = data%2
wholeData = np.array([data, data_lab])
wholeData = wholeData.T
def data_generator(all_data, batch_size = 20):
while True:
idx = np.random.randint(len(all_data), size=batch_size)
# Assuming the last column contains labels
batch_x = all_data[idx, :-1]
batch_y = all_data[idx, -1]
# Return a tuple of (Xs,Ys) to feed the model
yield(batch_x, batch_y)
print([x for x in data_generator(wholeData)])
关于tensorflow - keras中的fit_generator : where is the batch_size specified?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43780193/