简短版本
给定一个生成器采样,例如3
输入和1
标签,如何定义我的 Tensorflow 数据集管道来获取批量 K * 3
输入和K * 1
标签?
更长版本
上下文
我正在使用 Triplet 网络,并且想要调整我当前的输入管道以使用 Tensorflow 数据集。
就我而言,批处理由 N
组成(例如图像)和 N // 3
标签(假设 N % 3 == 0
),每个标签应用于 3 个连续输入,例如
labels = [compute_label(inputs[3*i], inputs[3*i+1], inputs[3*i+2]) for i in range(N // 3)]
与compute_label(*args)
一个简单的函数,可以使用 Tensorflow 运算或基本 Python 来实现。
为了让事情变得更复杂,必须对输入元素进行 3 × 3 采样(例如,我们希望 inputs[3*i]
与 inputs[3*i+1]
相似,而与 inputs[3*i+2]
不同):
for i in range(N // 3):
inputs[3*i], inputs[3*i+1], inputs[3*i+2] = sample_triplet(i)
问题
根据我的具体情况重新制定较短的问题:
鉴于这两个函数sample_triplet()
和compute_label()
,如何使用 Tensorflow 数据集构建输入管道,以使用 N
构建批处理输入和N // 3
标签?
我尝试了 tf.data.Dataset.from_generator()
的几种组合和tf.data.Dataset.flat_map()
但找不到一种方法来压平 N // 3
中的批量输入三胞胎到N
样本,仅输出 N // 3
批处理标签。
我发现的一个解决方案是通过计算 tf.data.Dataset.from_generator()
中的标签来“作弊”并将每个标签平铺 3 次,以便能够使用 tf.data.Dataset.flat_map()
在三元组输入+标签上。作为批处理后处理步骤,我然后“挤压” N
重复的标签返回N // 3
.
当前解决方案的示例
import tensorflow as tf
import numpy as np
def sample_triplet():
# Sampling our elements, here as [class, random_val] elements:
anchor_class = puller_class = pusher_class = np.random.randint(0, 10)
while pusher_class == anchor_class:
# we want the pusher to be of a different class
pusher_class = np.random.randint(0, 10)
anchor = np.array([anchor_class, np.random.randint(0, 5)])
puller = np.array([puller_class, np.random.randint(0, 5)])
pusher = np.array([pusher_class, np.random.randint(0, 5)])
# Stacking the triplets, to then flat_map as a batch:
triplet_inputs = np.stack((anchor, puller, pusher), axis=0)
# Returning also the classes to compute the label afterwards:
triplet_classes = np.stack((anchor_class, puller_class, pusher_class), axis=0)
return triplet_inputs, triplet_classes
def compute_labels(triplet_classes):
# Computing the label, e.g. distance between the anchor and pusher classes:
label = np.abs(triplet_classes[0] - triplet_classes[2])
return label
def triplet_generator():
while True:
triplet = sample_triplet()
# Current solution: computing the label here too,
# stacking it 3 times so that flat_map works,
# then afterwards removing the duplicates:
triplet_inputs = triplet[0]
triplet_label = compute_labels(triplet[1])
yield triplet_inputs,
np.stack((triplet_label, triplet_label, triplet_label), axis=0)
def squeeze_triplet_labels(*batch):
# Removing the duplicate labels,
# going from a batch of (N inputs, N labels) to (N inputs, N // 3 labels)
squeezed_labels = batch[-1][::3]
new_batch = (*batch[:-1], squeezed_labels)
return new_batch
batch_size = 30
assert(batch_size % 3 == 0)
sess = tf.InteractiveSession()
train_dataset = (tf.data.Dataset
.from_generator(triplet_generator, (tf.int32, tf.float32), ([3, 2], [3]))
.flat_map(lambda *x : tf.data.Dataset.from_tensor_slices(x))
.batch(batch_size))
next_training_batch = train_dataset.make_one_shot_iterator().get_next()
next_proper_training_batch = squeeze_triplet_labels(*next_training_batch)
batch = sess.run(next_proper_training_batch)
print("inputs shape: {} ; label shape: {}".format(batch[0].shape, batch[1].shape))
# >> inputs shape: (30, 2) ; label shape: (10,)
最佳答案
一个简单的解决方案是创建 2 个 Dataset 对象,一个用于标签,一个用于数据,然后按 3 个一组对数据进行批处理,并使用 tf.data.interleave 来交错两个数据集一起,产生您想要的结果。
如果这不容易做到,那么您可以尝试以下将一个元素映射到多个元素的过程。您必须创建一批 3 个元素(带有 3 个标签),然后在映射函数中将其拆分为 3 组数据,每组数据针对您收到的一个标签。这样做的方法是在下面的SO问题中,尽管它比第一个建议更复杂一些:
In Tensorflow's Dataset API how do you map one element into multiple elements?
关于python - Tensorflow 数据集 - 给定生成器输出 1 个标签的 X 个输入,如何构建批处理?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49799998/