python - 如何将两个 tf.data.Dataset 合并为一个具有已知比率的交替元素

标签 python tensorflow dataset

我有两个 tf.data.Dataset,我们称它们为 d1d2我想构建另一个包含 d1 元素的数据集和d2交替。用一个例子更容易解释。 让我们说:

d1 = [0,1,2,3,4,5,6,7,...] # it is not a list, just the content of the dataset

d2 = ["a", "b", "c", "d",... ]

我有一对指定每个数据集中连续元素的数量(例如(3,1))。

我正在寻找的结果是:

result = [0, 1, 2, "a", 3, 4, 5, "b", 6, 7, 8, "c"...]

编辑:d1和d2是tf.data.Dataset类的对象。上面的示例仅显示了数据集的内容,但它不是代码。

最佳答案

假设 TF 2.0。该技巧基于batch接下来是数据集交错和 unbatch .

import tensorflow as tf 

# input datasets
d1 = tf.data.Dataset.from_tensors([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).unbatch()
d2 = tf.data.Dataset.from_tensors([100, 101, 102]).unbatch()
# replaced letters with numbers to make tensor types match

# define ratio
r1 = 3
r2 = 1

b1 = d1.batch(r1)
b2 = d2.batch(r2)

zipped = tf.data.Dataset.zip((b1, b2)).map(lambda x, y: tf.concat((x, y), axis=0))
result = zipped.unbatch()

输出:

In [9]: list(result)                                                                                                                  
Out[9]: 
[<tf.Tensor: id=224, shape=(), dtype=int32, numpy=0>,
 <tf.Tensor: id=225, shape=(), dtype=int32, numpy=1>,
 <tf.Tensor: id=226, shape=(), dtype=int32, numpy=2>,
 <tf.Tensor: id=227, shape=(), dtype=int32, numpy=100>,
 <tf.Tensor: id=228, shape=(), dtype=int32, numpy=3>,
 <tf.Tensor: id=229, shape=(), dtype=int32, numpy=4>,
 <tf.Tensor: id=230, shape=(), dtype=int32, numpy=5>,
 <tf.Tensor: id=231, shape=(), dtype=int32, numpy=101>,
 <tf.Tensor: id=232, shape=(), dtype=int32, numpy=6>,
 <tf.Tensor: id=233, shape=(), dtype=int32, numpy=7>,
 <tf.Tensor: id=234, shape=(), dtype=int32, numpy=8>,
 <tf.Tensor: id=235, shape=(), dtype=int32, numpy=102>]

注意:此解决方案可能会删除 d1d2 末尾的一些元素 - 它们的长度必须调整为比例。

关于python - 如何将两个 tf.data.Dataset 合并为一个具有已知比率的交替元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58572535/

相关文章:

tensorflow - 为可变大小的输入和固定大小的输入创建 TensorFlow 占位符有什么缺点吗?

python - 如何将 mlflow 与 tensorflow object detection api 集成

python - 从图像本地目录创建tensorflow数据集

Javascript/jQuery - 重新排序 div

python - 我的代码从迭代器获取数据多少次?

python - 从命令行强制 TensorFlow-GPU 使用 CPU

Python/Pandas For 循环时间序列

python - 在 Python 3.4 中按索引访问 16.0 Pandas 数据框中的行时出现 keyerror

python - 无法安装安装了 64 位版本的 Anaconda 的 boost 库

python - tensorflow 中的单神经元前馈网络