python - 等效于 csv 文件的 Keras ImageDataGenerator

标签 python python-3.x tensorflow keras

我有一堆数据排序在文件夹中,如下图所示:

enter image description here

我需要构建一个 DataIterator 以便将数据拟合到神经网络模型中。当数据是图像时,我找到了很多解决这个问题的例子,使用 Keras 类 ImageDataGenerator 及其方法 flow_from_directory,但当数据是 csv 结构时不是。

每个 csv 文件都是一个 512x11 float 组,表示传感器所需的功率。我考虑过将这些 CSV 中的每一个转换为图像格式,然后应用 ImageDataGenerator 类,但压缩会导致信息丢失(在图像中,每个值由 8 位整数表示,而我的数据是 32 位 float )。

那么,在 Keras 中有一个等效于 ImageDataGenerator 的工具来加载 csv 文件而不是图像?

最佳答案

是的,您可以通过继承 Sequence 对象来编写自己的生成器。这个想法是你用两列组成某种数据框(例如 pandas 数据框):一列用于标签,另一列用于你的 csv 文件的路径。您的数据生成器将使用此文件来确定数据集的长度(csv 文件的数量)并批量读取文件并将它们传递给模型。

您的代码可能如下所示:

class DataSequence(Sequence):
    """
    Keras Sequence object to train a model on a list of csv files
    """
    def __init__(self, df, batch_size, mode='train'):
        """
        df = dataframe with two columns: the labels and a list of filenames
        """
        self.df = df
        self.bsz = batch_size
        self.mode = mode

        # Take labels and a list of image locations in memory
        self.labels = self.df['label'].values
        self.file_list = self.df['file_names']

    def __len__(self):
        return int(math.ceil(len(self.df) / float(self.bsz)))

    def on_epoch_end(self):
        self.indexes = range(len(self.im_list))
        if self.mode == 'train':
            # Shuffles indexes after each epoch if in training mode
            self.indexes = random.sample(self.indexes, k=len(self.indexes))

    def get_batch_labels(self, idx):
        # Fetch a batch of labels
        return self.labels[idx * self.bsz: (idx + 1) * self.bsz]

    def get_batch_features(self, idx):
        # Fetch a batch of inputs
        return np.array([READ_CSV_FUNCTION(f) for f in self.file_list[idx * self.bsz: (1 + idx) * self.bsz]])

    def __getitem__(self, idx):
        batch_x = self.get_batch_features(idx)
        batch_y = self.get_batch_labels(idx)
        return batch_x, batch_y

您只需将 READ_CSV_FUNCTION 替换为您选择的函数即可读取和解析 csv 文件。

关于python - 等效于 csv 文件的 Keras ImageDataGenerator,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53634924/

相关文章:

python - 在 Python 中,如何对列表进行切片以获得第一个元素,以及除最后一个元素之外的所有元素?

python - Tkinter 滚动条位置,yview_moveto() 似乎不起作用

machine-learning - 为什么在 tensorflow 的 MNIST 教程中 x 变量张量被重新整形为 -1?

python - Django 组选项列表

python - django 的登录 session

python - 通过抓取收集信息

python - 删除停用词/标点符号,标记并应用 Counter()

python - 如何在Python中从文本文件读取input()

python - Tensorflow.js Layers 模型和 Graph 模型有什么区别?

machine-learning - tf.nn.embedding_lookup 带有浮点输入?