python - Tensorflow 数据集 - 给定生成器输出 1 个标签的 X 个输入,如何构建批处理?

标签 python tensorflow machine-learning deep-learning tensorflow-datasets

简短版本

给定一个生成器采样,例如3输入和1标签,如何定义我的 Tensorflow 数据集管道来获取批量 K * 3输入和K * 1标签?


更长版本

上下文

我正在使用 Triplet 网络,并且想要调整我当前的输入管道以使用 Tensorflow 数据集。

就我而言,批处理由 N 组成(例如图像)和 N // 3标签(假设 N % 3 == 0 ),每个标签应用于 3 个连续输入,例如

labels = [compute_label(inputs[3*i], inputs[3*i+1], inputs[3*i+2]) for i in range(N // 3)]

compute_label(*args)一个简单的函数,可以使用 Tensorflow 运算或基本 Python 来实现。

为了让事情变得更复杂,必须对输入元素进行 3 × 3 采样(例如,我们希望 inputs[3*i]inputs[3*i+1] 相似,而与 inputs[3*i+2] 不同):

for i in range(N // 3):
    inputs[3*i], inputs[3*i+1], inputs[3*i+2] = sample_triplet(i)

问题

根据我的具体情况重新制定较短的问题:

鉴于这两个函数sample_triplet()compute_label() ,如何使用 Tensorflow 数据集构建输入管道,以使用 N 构建批处理输入和N // 3标签?

我尝试了 tf.data.Dataset.from_generator() 的几种组合和tf.data.Dataset.flat_map()但找不到一种方法来压平 N // 3 中的批量输入三胞胎到N样本,仅输出 N // 3批处理标签。

我发现的一个解决方案是通过计算 tf.data.Dataset.from_generator() 中的标签来“作弊”并将每个标签平铺 3 次,以便能够使用 tf.data.Dataset.flat_map()在三元组输入+标签上。作为批处理后处理步骤,我然后“挤压” N重复的标签返回N // 3 .

当前解决方案的示例

import tensorflow as tf
import numpy as np

def sample_triplet():
    # Sampling our elements, here as [class, random_val] elements:
    anchor_class = puller_class = pusher_class = np.random.randint(0, 10)
    while pusher_class == anchor_class:
        # we want the pusher to be of a different class
        pusher_class = np.random.randint(0, 10) 

    anchor = np.array([anchor_class, np.random.randint(0, 5)])
    puller = np.array([puller_class, np.random.randint(0, 5)])
    pusher = np.array([pusher_class, np.random.randint(0, 5)])

    # Stacking the triplets, to then flat_map as a batch:
    triplet_inputs = np.stack((anchor, puller, pusher), axis=0)
    # Returning also the classes to compute the label afterwards:
    triplet_classes = np.stack((anchor_class, puller_class, pusher_class), axis=0)
    return triplet_inputs, triplet_classes

def compute_labels(triplet_classes):
    # Computing the label, e.g. distance between the anchor and pusher classes:
    label = np.abs(triplet_classes[0] - triplet_classes[2])
    return label

def triplet_generator():
    while True:
        triplet = sample_triplet()

        # Current solution: computing the label here too, 
        # stacking it 3 times so that flat_map works,
        # then afterwards removing the duplicates:
        triplet_inputs = triplet[0]
        triplet_label = compute_labels(triplet[1])
        yield triplet_inputs, 
              np.stack((triplet_label, triplet_label, triplet_label), axis=0)

def squeeze_triplet_labels(*batch):
    # Removing the duplicate labels,
    # going from a batch of (N inputs, N labels) to (N inputs, N // 3 labels)
    squeezed_labels = batch[-1][::3]
    new_batch = (*batch[:-1], squeezed_labels)
    return new_batch


batch_size = 30
assert(batch_size % 3 == 0)
sess = tf.InteractiveSession()
train_dataset = (tf.data.Dataset
                 .from_generator(triplet_generator, (tf.int32, tf.float32), ([3, 2], [3]))
                 .flat_map(lambda *x : tf.data.Dataset.from_tensor_slices(x))
                 .batch(batch_size))

next_training_batch = train_dataset.make_one_shot_iterator().get_next()
next_proper_training_batch = squeeze_triplet_labels(*next_training_batch)
batch = sess.run(next_proper_training_batch)
print("inputs shape: {} ; label shape: {}".format(batch[0].shape, batch[1].shape))
# >> inputs shape: (30, 2) ; label shape: (10,)

最佳答案

一个简单的解决方案是创建 2 个 Dataset 对象,一个用于标签,一个用于数据,然后按 3 个一组对数据进行批处理,并使用 tf.data.interleave 来交错两个数据集一起,产生您想要的结果。

如果这不容易做到,那么您可以尝试以下将一个元素映射到多个元素的过程。您必须创建一批 3 个元素(带有 3 个标签),然后在映射函数中将其拆分为 3 组数据,每组数据针对您收到的一个标签。这样做的方法是在下面的SO问题中,尽管它比第一个建议更复杂一些:

In Tensorflow's Dataset API how do you map one element into multiple elements?

关于python - Tensorflow 数据集 - 给定生成器输出 1 个标签的 X 个输入,如何构建批处理?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49799998/

相关文章:

r - GBM R函数: get variable importance separately for each class

r - 如何从不平衡数据框架创建一个新的平衡数据框架,确保随机选择记录?

python - 如何在Python中存储用户输入并稍后调用?

python - Rabbitmq mgmt 上显示未知队列名称。使用 celery 时

python - 在 __init__.py 中导入包中使用的模块

python - 在没有 TensorFlow 占位符的情况下工作

python - logits 和 labels must be same size error 使用 SoftmaxCrossEntropyWithLogits

python - 在 Sphinx 中交叉引用 Python 对象有什么要求?

python - 计算对称矩阵的前 k 个(绝对值)特征值

tensorflow - 在 TensorFlow 中微调 Inception 模型