python - 停止 TensorFlow 数据集 `from_generator` 的正确方法?

标签 python tensorflow tensorflow-datasets

我想使用通过 from_generator 构建的 TensorFlow 数据集来访问格式化文件。大多数一切正常,除了我不知道如何在生成器用完数据时停止数据集迭代器(当您超出范围时,生成器只会永远返回空列表)。

我的实际代码非常复杂,但我可以用这个短程序来模拟情况:

import tensorflow as tf

def make_batch_generator_fn(batch_size=10, dset_size=100):
    feats, targs = range(dset_size), range(1, dset_size + 1)

    def batch_generator_fn():
        start_idx, stop_idx = 0, batch_size
        while True:
            # if stop_idx > dset_size: --- stop action?
            yield feats[start_idx: stop_idx], targs[start_idx: stop_idx]
            start_idx, stop_idx = start_idx + batch_size, stop_idx + batch_size

    return batch_generator_fn

def test(batch_size=10):
    dgen = make_batch_generator_fn(batch_size)
    features_shape, targets_shape = [None], [None]
    ds = tf.data.Dataset.from_generator(
        dgen, (tf.int32, tf.int32),
        (tf.TensorShape(features_shape), tf.TensorShape(targets_shape))
    )
    feats, targs = ds.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
        counter = 0
        try:
            while True:
                f, t = sess.run([feats, targs])
                print(f, t)
                counter += 1
                if counter > 15:
                    break
        except tf.errors.OutOfRangeError:
            print('end of dataset at counter = {}'.format(counter))

if __name__ == '__main__':
    test()

如果我提前知道记录数,我可以调整批处理数,但我并不总是知道。我尝试将一些代码放在上面的代码片段中,其中我有一个注释行,如 stop action?。特别是,我尝试引发 IndexError,但 TensorFlow 不喜欢这样,即使我在我的执行代码中显式 catch 它也是如此。我还尝试引发 tf.errors.OutOfRangeError,但我不确定如何实例化它:构造函数需要三个参数 - 'node_def'、'op' 和 'message',而我我不太确定一般情况下“node_def”和“op”要使用什么。

如果您对此问题有任何想法或意见,我将不胜感激。谢谢!

最佳答案

满足停止条件时返回:

def make_batch_generator_fn(batch_size=10, dset_size=100):
    feats, targs = range(dset_size), range(1, dset_size + 1)

    def batch_generator_fn():
        start_idx, stop_idx = 0, batch_size
        while True:
            if stop_idx > dset_size:
                return
            else:
                yield feats[start_idx: stop_idx], targs[start_idx: stop_idx]
                start_idx, stop_idx = start_idx + batch_size, stop_idx + batch_size

    return batch_generator_fn

这符合 Python 3 documentation: 中指定的行为

In a generator function, the return statement indicates that the generator is done and will cause StopIteration to be raised. The returned value (if any) is used as an argument to construct StopIteration and becomes the StopIteration.value attribute.

关于python - 停止 TensorFlow 数据集 `from_generator` 的正确方法?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50275833/

相关文章:

tensorflow - 如何使用 Dataset API 最大化 tensorflow GPU 训练中的 CPU 利用率?

python - 为什么我能够在 Python 中实例化我的抽象基类?

python - Pandas:根据 int 值更改年份

包含奇怪对象数组的 Python 对象

python - 替换 "tf.gather_nd"

python - 自定义丢失问题 : inputs to eager execution function cannot be keras symbolic tensors but found

python - FailedPreconditionError : Attempting to use uninitialized in Tensorflow

tensorflow - 何时在 Tensorflow Estimator 中使用迭代器

python - 重新初始化数据集后损失回到起始值

python - 在重度嵌套的 defaultdict 中使用更 Pythonic 的方式来计算事物