python - 如何让TensorFlow的map()函数返回多个值?

标签 python tensorflow data-augmentation

我正在尝试编写一个函数来增强数据集中的图像。我能够成功增强现有图像并返回它,但我希望能够对单个图像进行多次增强并单独返回这些增强图像,然后将它们添加到原始数据集。

增强功能:

def augment_data(image, label):

h_flipped_image = tf.image.flip_left_right(image)
v_flipped_image = tf.image.flip_up_down(image)

return h_flipped_image, label

map 功能:

train_ds = train_ds.map(augment_data)

train_ds 是 tf.data 数据集,具有以下形状:

<PrefetchDataset shapes: ((None, 224, 224, 3), (None, 238)), types: (tf.float32, tf.bool)>

如何使 map 函数返回多个值,例如,我可以同时返回 h_flipped_image 和 v_flipped_image 并将它们添加到 train_ds 数据集?

最佳答案

事实证明我没有从正确的方向看问题。我意识到我毕竟不需要将增强样本包含在数据集中。相反,我选择在训练过程中增强数据。

由于训练过程需要多个时期,我可以在网络需要图像之前增强图像。我通过修改我的augment_data函数来做到这一点,所以它现在有随机机会执行某种增强。每个时期都会对图像执行随机的增强组合,从而每次为网络生成不同的输入图像。

def augment_data(image, label):

rand = tf.random.uniform([])
if(rand > 0.5):
    image = tf.image.flip_up_down(image)

# Additional augmented techniques should be defined here

return image, label

确保使用 TensorFlow 函数来生成随机数。由于 TensorFlow 解释 Python 代码的方式,简单的 random.random() 将不起作用。

关于python - 如何让TensorFlow的map()函数返回多个值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59359909/

相关文章:

java - 使用 android 客户端和 python 服务器在 DataOutputStream 和 DataInputStream 之间滞后

python - 为什么在 Python 中是 ‘==‘ coming before ‘in’?

python - tf.nn.l2_loss 和 tf.contrib.layers.l2_regularizer 是否与在 tensorflow 中添加 L2 正则化的目的相同?

python - 在python中定义一个类的 "boolness"

python - 如何获取包含张量中每个元素的列表?

python - TensorFlow,为什么选择 python 语言?

python - CondaHTTPError - 安装 NLTK 时出现 SSL 错误

python - PyTorch 在 TensorDataset 上进行转换

python - 在 PyTorch 中显示增强图像的示例

python - 如何扩充 Tensorflow 数据集中的数据?