我想为我的生成器提供参数以与tf.data.Dataset.from_generator()
结合使用。例如:
def generator(lo, hi):
for i in range(lo, hi):
yield float(i)
该生成器生成 lo
和 hi
之间的 float 。但请注意,在创建数据集时,这些参数永远不会传递给此生成器。
tf.data.Dataset.from_generator(generator, tf.float64)
这是因为 tf.data.Dataset.from_generator()
的生成器参数不应采用任何参数。
有什么解决办法吗?
最佳答案
我找到了一个基于称为部分应用函数的函数式编程概念的解决方案。总结一下:
a PAF is a function that takes a function with multiple parameters and returns a function with fewer parameters.
我的做法如下:
from functools import partial
import tensorflow as tf
def generator(lo, hi):
for i in range(lo, hi):
yield float(i)
def get_generator(lo, hi):
return partial(generator, lo, hi)
tf.data.Dataset(get_generator(lo, hi), tf.float64)
get_generator(lo, hi)
函数返回生成器的部分应用函数,该函数修复 lo
和 hi
参数的值,这实际上是 tf.data.Dataset.from_generator() 所需的无参数生成器。
关于python - 使用 tf.data.Dataset.from_generator() 时参数化生成器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51655339/