我对如何在 keras
中使用 fit_generator
感到有点困惑。
举例来说:
- 我们有 10000 个数据点
- 我们想要运行 10 个 epoch
- 批量大小为 512
使用fit
我们只是:
x, y = load_data()
model.fit(x=x, y=y, batch_size=512, epochs=10)
其中 load_data
加载所有数据。
现在如何使用 fit_generator
执行相同操作。
我不清楚使用 fit_generator
时它是如何处理的。如果我有以下生成器:
def data_generator():
for x, y in load_data_per_line():
yield x, y
在上面的生成器中,每次产生
一个数据点。并且:
def data_generator_2():
x_output = []
y_output = []
i = 0
for x, y in load_data_per_line():
x_output[i] = x
y_output[i] = y
i = i + 1
if i == batch_size:
yield x_output, y_output
i = 0
x_output = []
y_output = []
在上面的生成器中,每次产生
个批量大小的数据点(在本例中为 512 个)。
要实现与 fit
相同的效果,但使用 fit_generator
:
model.fit_generator(data_generator(), steps_per_epoch=10000 / 512, epochs=10)
或
model.fit_generator(data_generator_2(), steps_per_epoch=10000 / 512, epochs=10)
或者两者都是错误的(fit_generator
和 data_generator
)?如果其中任何一个是正确的,是否可以保证所有数据点都将被处理并且也将按顺序处理?
任何见解都是有用的
最佳答案
生成器 2 几乎没问题,但它应该更好地返回 numpy 数组:
yield np.asarray(x_output),np.asarray(y_output)
而且,它应该是无限的:
while True:
#the code inside to loop infinitely
第一个不会返回批处理并且会失败。
您可能会在 steps_per_epoch
中遇到问题,因为 10000 不是 512 的倍数。您需要整数步长。您可以在生成器内检查 if i == 10000:
并传递较小的批处理作为最后一批。
那么您就得到了 (10000//512) + (10000 % 512)
个步骤或批处理。
所有批处理都会按顺序读取,但keras会自动打乱这些批处理的内容,使用suffle=False
。如果您使用多线程(并非如此),那么您需要创建线程安全生成器或使用 keras Sequence
。
关于python - 如何: fit_generator in keras,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46570172/