python - Tensorflow 2.0 : how to transform from MapDataset (after reading from TFRecord) to some structure that can be input to model. 适合

标签 python tensorflow tfrecord

我将训练和验证数据存储在两个单独的 TFRecord 文件中,其中存储了 4 个值:信号 A(float32 形状(150,)),信号 B(float32 形状(150,)),标签(标量 int64), id(字符串)。我的阅读解析功能是:

def _parse_data_function(sample_proto):

    raw_signal_description = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'id': tf.io.FixedLenFeature([], tf.string),
    }

    for key, item in SIGNALS.items():
        raw_signal_description[key] = tf.io.FixedLenFeature(item, tf.float32)

    # Parse the input tf.Example proto using the dictionary above.
    return tf.io.parse_single_example(sample_proto, raw_signal_description)

哪里SIGNALS是一个字典映射信号名称->信号形状。然后,我阅读了原始数据集:

training_raw = tf.data.TFRecordDataset(<path to training>), compression_type='GZIP')
val_raw = tf.data.TFRecordDataset(<path to validation>), compression_type='GZIP')

并使用 map 来解析值:

training_data = training_raw.map(_parse_data_function)
val_data = val_raw.map(_parse_data_function)

显示 training_data 的标题或 val_data ,我得到:
<MapDataset shapes: {Signal A: (150,), Signal B: (150,), id: (), label: ()}, types: {Signal A: tf.float32, Signal B: tf.float32, id: tf.string, label: tf.int64}>
这几乎和预期的一样。我还检查了一些值的一致性,它们似乎是正确的。

现在,我的问题:我如何从 MapDataset 中使用类似字典的结构获取可以作为模型输入的内容?

我的模型的输入是对(信号 A,标签),但将来我也会使用信号 B。

对我来说,最简单的方法似乎是在我想要的元素上创建一个生成器。就像是:

def data_generator(mapdataset):
    for sample in mapdataset:
        yield (sample['Signal A'], sample['label'])

但是,使用这种方法我失去了 Datasets 的一些便利性,例如批处理,并且也不清楚如何对 validation_data 使用相同的方法。 model.fit 的参数.理想情况下,我只会在 map 表示和数据集表示之间进行转换,它在信号 A 张量和标签对上进行迭代。

编辑:我的最终产品应该是类似于以下标题的东西:<TensorSliceDataset shapes: ((150,), ()), types: (tf.float32, tf.int64)>但不一定TensorSliceDataset

最佳答案

您可以简单地在 parse 函数中执行此操作。例如:

def _parse_data_function(sample_proto):

    raw_signal_description = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'id': tf.io.FixedLenFeature([], tf.string),
    }

    for key, item in SIGNALS.items():
        raw_signal_description[key] = tf.io.FixedLenFeature(item, tf.float32)

    # Parse the input tf.Example proto using the dictionary above.
    parsed = tf.io.parse_single_example(sample_proto, raw_signal_description)

    return parsed['Signal A'], parsed['label']

如果您 map这个函数超过了TFRecordDataset ,您将拥有一个元组数据集 (signal_a, label)而不是字典数据集。你应该可以把它放入 model.fit直接地。

关于python - Tensorflow 2.0 : how to transform from MapDataset (after reading from TFRecord) to some structure that can be input to model. 适合,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60601412/

相关文章:

python - 如何从 TFrecord 创建批处理用于 tensorflow 中的训练网络?

pytorch - 如何在pytorch中加载tfrecord?

python - 计算机视觉 : Opencv Counting small circles inside big circle

python - Google 云打印 - Python 代码示例 - GaiaLogin 返回 None

python - Rabbitmq 使用 Tornado 不阻塞消息

tensorflow - Adam优化器中的epsilon参数

machine-learning - 从 Tensorflow 中的损失函数中屏蔽样本

python - 在 django 应用程序中将 SSL 与 nginx 结合使用

python - Tensorflow 训练错误

python - 属性错误: 'Tensor' object has no attribute 'numpy' in Tensorflow 2. 1