python - 如何将 python 生成器更改为 Keras Sequence 对象?

标签 python tensorflow keras generator sequence

我正在研究回归问题。我的 CNN 训练数据的形状为 32x513x30 - 每批处理 32 个 513x30 实例,然后是 4810 批处理。

我将这些批处理保存在一个目录中,每个批处理名为“batch#number.npy”。

在使用 Python 生成器时,我不断收到来自 TensorFlow 的警告:

WARNING:tensorflow:Using a generator with use_multiprocessing=True and multiple workers may duplicate your data. Please consider using the keras.utils.Sequence class.

我想出了如何使用 Python 生成器加载它们。但是,在使用多处理时,建议使用 Keras 的 Sequence 类:https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence

不幸的是,这对我来说事情变得有点太复杂了。 这是我一直在使用的代码:

def batch_gen(num):

    os.chdir('mydirectory/train')

    for n in num:
        placeholder = np.load('batch#' + str(n) + '.npy')
        X = placeholder[:,:513,:]
        Y1= placeholder[:,513:,:]

        Y = X * Y1

        X = X / normalization # normalize X
        X = scale_mag*X.astype(np.float32)

        Y = Y / normalization 
        Y = scale_mag*Y.astype(np.float32)


        X = np.reshape(X,(32,513,30,1))
        Y = np.reshape(Y,(32,513,30,1))
        yield (X, Y)

my_gen = batch_gen(C)   # C is an array with indexes 1 to 4810 (looped by number of training epochs)

我使用生成器的方式是否导致我的数据在训练期间重复?如果是这样,我如何将其转换为 Sequence 类?

谢谢。

最佳答案

  class MyBatchGenerator(Sequence):
    def __init__(self, C):
        self.C = C

    def __len__(self):
        return len(self.C)

    def __getitem__(self, idx):   

        n = self.C[idx]
        os.chdir('mydirectory/train')

        placeholder = np.load('batch#' + str(n) + '.npy')
        X = placeholder[:,:513,:]
        Y1= placeholder[:,513:,:]

        Y = X * Y1

        X = X / normalization # normalize X
        X = scale_mag*X.astype(np.float32)

        Y = Y / normalization 
        Y = scale_mag*Y.astype(np.float32)


        X = np.reshape(X,(32,513,30,1))
        Y = np.reshape(Y,(32,513,30,1))
        return (X, Y)

关于python - 如何将 python 生成器更改为 Keras Sequence 对象?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55727074/

相关文章:

python - 使用 Beautiful Soup 的 Python 网络爬虫 BFS 算法?

Python subprocess.call 在没有 shell=True 的情况下不起作用

Python Librosa Keras 神经网络错误 : Too Many Indices For Array

python - keras中的精度计算不匹配

php - PHP 中 crc32b 的输出不等于 Python

python - Scipy Interpolate RectBivariateSpline 构造函数返回错误

python - 在 windows 上安装 tensorflow

tensorflow - 无法在 TensorFlow 2 中加载模型权重

tensorflow - 使用 tensorflow 对象检测减少误报的方法有哪些?

python - 如何在 Django 的多个同时进行的 Keras 分类器 session 中进行预测?