python - 使用 tf.data API 和样本权重进行训练

标签 python tensorflow tensorflow-datasets

我的所有训练图像都在 tfrecords 文件中。现在它们以如下标准方式使用:

dataset = dataset.apply(tf.data.experimental.map_and_batch(
            map_func=lambda x: preprocess(x, data_augmentation_options=data_augmentation), 
            batch_size=images_per_batch)

其中预处理返回来自 tfrecord 文件的解码图像和标签。

现在新情况。我还想要每个例子的样本权重。所以而不是

return image,label

在预处理中,应该是

return image, label, sample_weight

但是,这个sample_weight并不在tfrecord文件中。它是在训练开始时根据每个类别的示例数量计算的。基本上它是一个Python字典weights[label] = sample_weights。

问题是如何在 tf.data 管道中使用这些样本权重。因为 label 是一个 Tensor,所以它不能用于索引 Python 字典。

最佳答案

您的问题有一些不清楚,例如 x 是什么?如果您能发布完整的代码示例来说明您的问题,那就更好了。

我假设 x 是带有图像和标签的张量。如果是这样,您可以使用映射函数将样本权重张量添加到数据集中。如下(请注意,此代码未经测试):

def im_add_weight(image, label, sample_weight):
   #convert to tensor if they are not and make sure to us
   image= tf.convert_to_tensor(image, dtype= tf.float32)
   label = tf.convert_to_tensor(label, dtype= tf.float32)
   sample_weight = tf.convert_to_tensor(sample_weight, dtype= tf.float32)
   return image, label, sample_weight

dataset = dataset .map(
lambda image, label, sample_weight: tuple(tf.py_func(
    im_add_weight, [image, label,sample_weight], [tf.float32, tf.float32,tf.float32])))

关于python - 使用 tf.data API 和样本权重进行训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53653485/

相关文章:

python - tf.argmax() 用于多个索引 Tensorflow

python - 带有 ListDirectory 的 Tensorflow 数据集 API

python - 如何使用 TF1.3 中的新数据集 api 映射具有附加参数的函数?

python - 如何从 Python 生成唯一的 64 位整数?

python - 在 DjangoModelFactory 中使用 Django faker 创建的克隆模型字段

python - 无需手动输入即可读取文件名列表

python - 重复排列?

python - TensorFlow 估计器.预测 : Saving when as_iterable=True

machine-learning - Tensorflow 中 cifar10 数据集的 CNN

python - 使用 tf.Dataset 训练的模型进行推理