python - 将 tf.dataset 中的每个样本映射到 id

标签 python tensorflow tensorflow2.0 tensorflow-datasets

出于测试目的,我想为 tf.dataset 中的每个样本附加一个 id。只需向上计数就足够了。

我的数据集的类型为 FlatMapDataset fwiw。

for entry in img_ds:
        print(entry.shape)

(128, 128, 3)
(128, 128, 3)
(128, 128, 3)
(128, 128, 3)
...

我尝试的是有一个映射函数,在其内部定义一个计数器并向上计数:

@staticmethod
    def map_to_id(img):
        try:
            ExperimentalPipeline.map_to_id.id_counter += 1
        except AttributeError:
            ExperimentalPipeline.map_to_id.id_counter = 0
        return img, ExperimentalPipeline.map_to_id.id_counter

然后使用 tf.data 中的 Dataset.map 将 id 附加到每个样本:

img_ds = img_ds.map(ExperimentalPipeline.map_to_id)

不幸的是,这不起作用,每个样本的 id 为零:

for i, id in img_ds:
        print(f"{i.shape}, {id}")

(128, 128, 3), 0
(128, 128, 3), 0
(128, 128, 3), 0
(128, 128, 3), 0
...

我还注意到我的 map_to_id 函数仅被调用一次。

@staticmethod
def map_to_id(img):
    print("enter map_to_id")
    try:
        ExperimentalPipeline.map_to_id.id_counter += 1
    except AttributeError:
        print("caught exception")
        ExperimentalPipeline.map_to_id.id_counter = np.random.randint(1000)
    return img, ExperimentalPipeline.map_to_id.id_counter

enter map_to_id
caught exception
(128, 128, 3), 889
(128, 128, 3), 889
(128, 128, 3), 889
(128, 128, 3), 889

我想我不明白 Dataset.map 应该如何工作。我认为它会获取正在调用的数据集中的每个样本,并以该样本作为参数调用提供的函数。
有人可以帮我解决这个问题吗?

最佳答案

TensorFlow 将运行一次 map 函数,将该函数编译为 TensorFlow 操作。然后这些操作(而不是原始的 python 函数)将应用于数据集的每个元素。如果你想为每个元素运行原始的 python 函数,你可以使用 py_function相反。

在这种特定情况下,您想要附加元素 ID,可以使用 Dataset.enumerate实现您的目标:

img_ds = img_ds.enumerate()

关于python - 将 tf.dataset 中的每个样本映射到 id,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60365843/

相关文章:

python - tensorflow 2 : shape mismatch when serialize and decode it back

tensorflow - 如何在 tensorflow 2.0 中使用层列表?

tensorflow - 计算输出相对于权重的梯度

python - 在构建 RPM 包时传送 *.so 和二进制文件

python - openerp错误AttributeError : 'int' object has no attribute 'iteritems'

python - Django 说 MySQL 不允许唯一的 CharFields 有一个 max_length > 255,但它允许

python - 在 tfhub 再训练脚本中计算 F1 分数、精度、召回率

python - 如何防止 tensorflow 分配整个 GPU 内存?

multithreading - Keras Tensorflow - 从多个线程进行预测时出现异常

python - Tensorflow 填充实现