python - 在keras中正确使用fit_generator

标签 python keras neural-network

欢迎大家。尝试了解 fit_generator 在 keras 中的工作原理。

我有数据集,每个文件中有 100 个图像和 100 个标签。

我写了这个生成器:

def GenerateData(self):

    while True:

        complete_x1 = np.zeros((500, 50, 50, 3))
        complete_x2 = np.zeros((500, 50, 50, 3))
        complete_y1 = np.zeros((500, 3))
        complete_y2 = np.zeros((500, 2))

        done = 0

        while done < 500:

            data = np.load("{}/data_resized_{}.npy".format(self._patch, self._LastID))

            self.Log('\nLoad ALL data. ID: {} - Done: {}'.format(self._LastID, done))

            for data_x1, data_x2, data_y1, data_y2 in data:

                data_x1 = self.random_transform(data_x1)

                data_x2 = self.random_transform(data_x2)

                data_x1 = self.ImageProcessing(data_x1, 0)

                data_x2 = self.ImageProcessing(data_x2, 1)

                data_x1 = np.array(data_x1).astype('float32')
                data_x1 /= 255

                data_x2 = np.array(data_x2).astype('float32')
                data_x2 /= 255

                complete_x1[done] = data_x1
                complete_x2[done] = data_x2

                complete_y1[done] = data_y1
                complete_y2[done] = data_y2

                done += 1

            self._LastID += 1

            if self._LastID >= 1058:
                self._LastID = 0

        yield [np.array(complete_x1), np.array(complete_x2)], [np.array(complete_y1), np.array(complete_y2)]

我总共有 1058 个文件。结果是 105800 张带有标签的图像。

模型训练:

model.fit_generator(data.GenerateData(), samples_per_epoch=1058/500, nb_epoch=15, verbose=1, workers=1)

一切似乎都很好,但是!

在训练一开始,GenerateData 打印以下内容:

Load ALL data. ID: 0 - Done: 0

Load ALL data. ID: 1 - Done: 100

Load ALL data. ID: 2 - Done: 200

Load ALL data. ID: 3 - Done: 300

Load ALL data. ID: 4 - Done: 400

Load ALL data. ID: 5 - Done: 0

Load ALL data. ID: 6 - Done: 100

Load ALL data. ID: 7 - Done: 200

Load ALL data. ID: 8 - Done: 300

Load ALL data. ID: 9 - Done: 400

Load ALL data. ID: 10 - Done: 0

这发生在 ID 为 59 的文件之前。事实证明......它会跳过 59 文件之前的所有内容吗? 5900 张图片?

它只加载 500 张图像,之后就通过了 yield 并使用他完成的文件的 ID 重新开始,但火车不工作。

以下是第 59 个文件之后的内容:

Load ALL data. ID: 59 - Done: 400 1/2 [=============>................] - ETA: 4s - loss: 2.8177 - dense_18_loss: 2.0145 - dense_21_loss: 0.8032 - dense_18_acc: 0.2140 - dense_21_acc: 0.5780 Load ALL data. ID: 60 - Done: 0

Load ALL data. ID: 61 - Done: 100

Load ALL data. ID: 62 - Done: 200

Load ALL data. ID: 63 - Done: 300

Load ALL data. ID: 64 - Done: 400 2/2 [===========================>..] - ETA: 0s - loss: 2.7260 - dense_18_loss: 1.7077 - dense_21_loss: 1.0183 - dense_18_acc: 0.2720 - dense_21_acc: 0.5890 Load ALL data. ID: 65 - Done: 0

Load ALL data. ID: 66 - Done: 100

为什么会发生这种情况?

最佳答案

您之所以出现此行为,是因为您将 workers 设置为 1,并且数据生成任务和训练任务在单独的线程上运行。训练任务在主线程上运行,而数据生成任务在单独的线程上运行,其中线程的数量取决于 workers 参数。

如果 workers 参数为 0,数据生成器将在主线程上运行,结果将是您所期望的。

关于python - 在keras中正确使用fit_generator,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56218490/

相关文章:

python-3.x - keras 多维度输入到 simpleRNN : dimension mismatch

python - 凯拉斯图像数据生成器 : problem with data and label shape

python - 哪种形状应该具有 LSTM NN 的输入和输出数据?

machine-learning - 单个网络 - 多个输出,还是多个网络 - 单输出?

python - 如何使用 entry_point 脚本启动调试器

python - env.password 在 fab 文件中设置,但进程仍然多次询问 sudo 密码

python - 打包C/Python项目时使用distutils的原因

python-3.x - 从目录流式传输图像并将预测与 tensorflow 中的文件名相关联

python - 从 python 32 位到 python 64 位

python - 评估期间 Experimenter 中的 tensorflow 混淆矩阵