Tensorflow 数据集 API : parallelising tf. data.Dataset.from_generator with parallel_interleave

标签 tensorflow tensorflow-datasets

在生产环境中,我有来自 N 个生产者的数据,这些数据必须通过网络。我在 parallelising tf.data.Dataset.from_generator 上找到了这条评论这真的描述了我想要的东西。

def generator(n):
  # returns n-th generator function

def dataset(n):
  return tf.data.Dataset.from_generator(generator(n))

ds = tf.data.Dataset.range(N).apply(tf.contrib.data.parallel_interleave(dataset, cycle_lenght=N))

# where N is the number of generators you use

然而,generator(n) 函数应该是什么样子的。因为当我运行这个示例时
 def generator(n):
        """Returns the n-th generator function (for consumer n)
        """
        consumer = self.consumers[n]

        def gen():
            for item in consumer:
                yield item

        return gen

使用 self.consumers 一个 Python 列表然后我会得到错误:

TypeError: list indices must be integers or slices, not Tensor

最佳答案

实现几乎是正确的,但是您收到错误,因为 n参数在 dataset(n)是“象征”tf.Tensor ,而不是可用于在 self.consumers 中查找使用者的实际值.

幸运的是,有一个解决方法,它涉及传递 n通过可选 args论据 tf.data.Dataset.from_generator() :

def dataset(n):
  return tf.data.Dataset.from_generator(generator, args=(n,))

在封面下,from_generator()插入一些代码来转换 n在每次调用 generator 之前转换为 Python 整数.

关于Tensorflow 数据集 API : parallelising tf. data.Dataset.from_generator with parallel_interleave,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50295527/

相关文章:

python - 这些模型是否等效?

c++ - 在 Tensorflow C++ API 中初始化变量

tensorflow - 如何在 Tensorflow.js 中改变张量的值?

python - tf.data API无法打印所有批处理

python-3.x - tensorflow 中 tf.data.Dataset 的填充

tensorflow - 与线程/队列相比,tf.data.Dataset 输入管道提供了糟糕的结果

tensorflow - 如何将经过训练的 TF1 protobuf 模型加载到 TF2 中?

tensorflow - 参差不齐的张量作为 LSTM 的输入

python - Tensorflow 在每次调用带有最终图的 session.run() 时泄漏内存

python - 在 tensorflow 中使用迭代器生成特征和标签