在生产环境中,我有来自 N 个生产者的数据,这些数据必须通过网络。我在 parallelising tf.data.Dataset.from_generator 上找到了这条评论这真的描述了我想要的东西。
def generator(n):
# returns n-th generator function
def dataset(n):
return tf.data.Dataset.from_generator(generator(n))
ds = tf.data.Dataset.range(N).apply(tf.contrib.data.parallel_interleave(dataset, cycle_lenght=N))
# where N is the number of generators you use
然而,generator(n) 函数应该是什么样子的。因为当我运行这个示例时
def generator(n):
"""Returns the n-th generator function (for consumer n)
"""
consumer = self.consumers[n]
def gen():
for item in consumer:
yield item
return gen
使用 self.consumers 一个 Python 列表然后我会得到错误:
TypeError: list indices must be integers or slices, not Tensor
最佳答案
实现几乎是正确的,但是您收到错误,因为 n
参数在 dataset(n)
是“象征”tf.Tensor
,而不是可用于在 self.consumers
中查找使用者的实际值.
幸运的是,有一个解决方法,它涉及传递 n
通过可选 args
论据 tf.data.Dataset.from_generator()
:
def dataset(n):
return tf.data.Dataset.from_generator(generator, args=(n,))
在封面下,
from_generator()
插入一些代码来转换 n
在每次调用 generator
之前转换为 Python 整数.
关于Tensorflow 数据集 API : parallelising tf. data.Dataset.from_generator with parallel_interleave,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50295527/