我的所有训练图像都在 tfrecords 文件中。现在它们以如下标准方式使用:
dataset = dataset.apply(tf.data.experimental.map_and_batch(
map_func=lambda x: preprocess(x, data_augmentation_options=data_augmentation),
batch_size=images_per_batch)
其中预处理返回来自 tfrecord 文件的解码图像和标签。
现在新情况。我还想要每个例子的样本权重。所以而不是
return image,label
在预处理中,应该是
return image, label, sample_weight
但是,这个sample_weight并不在tfrecord文件中。它是在训练开始时根据每个类别的示例数量计算的。基本上它是一个Python字典weights[label] = sample_weights。
问题是如何在 tf.data 管道中使用这些样本权重。因为 label 是一个 Tensor,所以它不能用于索引 Python 字典。
最佳答案
您的问题有一些不清楚,例如 x 是什么?如果您能发布完整的代码示例来说明您的问题,那就更好了。
我假设 x 是带有图像和标签的张量。如果是这样,您可以使用映射函数将样本权重张量添加到数据集中。如下(请注意,此代码未经测试):
def im_add_weight(image, label, sample_weight):
#convert to tensor if they are not and make sure to us
image= tf.convert_to_tensor(image, dtype= tf.float32)
label = tf.convert_to_tensor(label, dtype= tf.float32)
sample_weight = tf.convert_to_tensor(sample_weight, dtype= tf.float32)
return image, label, sample_weight
dataset = dataset .map(
lambda image, label, sample_weight: tuple(tf.py_func(
im_add_weight, [image, label,sample_weight], [tf.float32, tf.float32,tf.float32])))
关于python - 使用 tf.data API 和样本权重进行训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53653485/