欢迎大家。尝试了解 fit_generator 在 keras 中的工作原理。
我有数据集,每个文件中有 100 个图像和 100 个标签。
我写了这个生成器:
def GenerateData(self):
while True:
complete_x1 = np.zeros((500, 50, 50, 3))
complete_x2 = np.zeros((500, 50, 50, 3))
complete_y1 = np.zeros((500, 3))
complete_y2 = np.zeros((500, 2))
done = 0
while done < 500:
data = np.load("{}/data_resized_{}.npy".format(self._patch, self._LastID))
self.Log('\nLoad ALL data. ID: {} - Done: {}'.format(self._LastID, done))
for data_x1, data_x2, data_y1, data_y2 in data:
data_x1 = self.random_transform(data_x1)
data_x2 = self.random_transform(data_x2)
data_x1 = self.ImageProcessing(data_x1, 0)
data_x2 = self.ImageProcessing(data_x2, 1)
data_x1 = np.array(data_x1).astype('float32')
data_x1 /= 255
data_x2 = np.array(data_x2).astype('float32')
data_x2 /= 255
complete_x1[done] = data_x1
complete_x2[done] = data_x2
complete_y1[done] = data_y1
complete_y2[done] = data_y2
done += 1
self._LastID += 1
if self._LastID >= 1058:
self._LastID = 0
yield [np.array(complete_x1), np.array(complete_x2)], [np.array(complete_y1), np.array(complete_y2)]
我总共有 1058 个文件。结果是 105800 张带有标签的图像。
模型训练:
model.fit_generator(data.GenerateData(), samples_per_epoch=1058/500, nb_epoch=15, verbose=1, workers=1)
一切似乎都很好,但是!
在训练一开始,GenerateData 打印以下内容:
Load ALL data. ID: 0 - Done: 0
Load ALL data. ID: 1 - Done: 100
Load ALL data. ID: 2 - Done: 200
Load ALL data. ID: 3 - Done: 300
Load ALL data. ID: 4 - Done: 400
Load ALL data. ID: 5 - Done: 0
Load ALL data. ID: 6 - Done: 100
Load ALL data. ID: 7 - Done: 200
Load ALL data. ID: 8 - Done: 300
Load ALL data. ID: 9 - Done: 400
Load ALL data. ID: 10 - Done: 0
这发生在 ID 为 59 的文件之前。事实证明......它会跳过 59 文件之前的所有内容吗? 5900 张图片?
它只加载 500 张图像,之后就通过了 yield 并使用他完成的文件的 ID 重新开始,但火车不工作。
以下是第 59 个文件之后的内容:
Load ALL data. ID: 59 - Done: 400 1/2 [=============>................] - ETA: 4s - loss: 2.8177 - dense_18_loss: 2.0145 - dense_21_loss: 0.8032 - dense_18_acc: 0.2140 - dense_21_acc: 0.5780 Load ALL data. ID: 60 - Done: 0
Load ALL data. ID: 61 - Done: 100
Load ALL data. ID: 62 - Done: 200
Load ALL data. ID: 63 - Done: 300
Load ALL data. ID: 64 - Done: 400 2/2 [===========================>..] - ETA: 0s - loss: 2.7260 - dense_18_loss: 1.7077 - dense_21_loss: 1.0183 - dense_18_acc: 0.2720 - dense_21_acc: 0.5890 Load ALL data. ID: 65 - Done: 0
Load ALL data. ID: 66 - Done: 100
为什么会发生这种情况?
最佳答案
您之所以出现此行为,是因为您将 workers
设置为 1,并且数据生成任务和训练任务在单独的线程上运行。训练任务在主线程上运行,而数据生成任务在单独的线程上运行,其中线程的数量取决于 workers
参数。
如果 workers
参数为 0,数据生成器将在主线程上运行,结果将是您所期望的。
关于python - 在keras中正确使用fit_generator,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56218490/