python - TensorFlow实验: how to avoid loading all data in memory with input_fn?

标签 python tensorflow

我正在努力将我的(困惑的)代码从tensorflow核心传递到Estimator范例,尤其是使用Experiments - 与learn_runner.run >。但实际上我在向神经网络提供数据时遇到了问题。

我想要实现的目标实际上非常接近 TensorFlow 和 tf.TextLineReader 的所有示例所完成的目标,例如https://github.com/GoogleCloudPlatform/cloudml-samples/blob/master/census/customestimator/trainer/model.py#L297 ,尽管我不是从磁盘上的文件加载数据,而是通过网络服务加载数据。

根据我的理解(并查看 tensorflow.python.estimator._train_model() 的代码),input_fn 仅调用一次,而不是在每次迭代时调用。我可以轻松加载所有数据,然后执行以下操作:

def input_fn():
    data = # all data in memory
    batch = tf.train.input_producer(tf.constant(data))
    return batch.dequeue_many(batch_size)

但是这是不可持续的,因为我的数据不适合内存。我正在尝试做类似的事情:

1. load first piece of data (say N lines)
2. consume it by batches in a queue just like the input_fn above
2'. feed this queue asynchronously with new data when it's almost empty

我知道如何在“纯”tf 中做到这一点,例如How to prefetch data using a custom python function in tensorflowTensorflow: custom data load + asynchronous computation但我发现很难将其转换为 Experiment 范例,因为我无法访问 session 来自行加载内容,也无法访问图表来在内部追加操作。

编辑

我设法使用tf.py_func()来做到这一点,例如:

class Reader(object):
     # a Python object that can load data and have some intelligence, not related to TF, initialized with batch_sized

    def read_up_to(self):
        """Reads up to batch_size elements loaded in Python"""

def input_fn():
    reader = Reader() # instantiated once
    return tf.py_func(reader.read_up_to, inp=[], Tout=...)

我工作得很好,尽管速度有点慢(正如预期的那样,有一种从 C++ 执行到 Python 的方法,会引入大约 50% 的延迟)。我正在尝试通过将读取器异步读取的 Python 数据放入特定的 TensorFlow 队列来解决此问题,这样就可以在不将数据从 Python 传递到 C++ 的情况下完成加载(就像上面的两个链接一样)。

最佳答案

我有一个similar issue我通过使用 SessionRunHook 找到了修复程序。该 Hook (还有其他 Hook )允许您在创建 session 后立即初始化操作。

关于python - TensorFlow实验: how to avoid loading all data in memory with input_fn?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45191474/

相关文章:

Python如何将类函数从实例传递到另一个类

python - Flask-Login @login-required 装饰器在 session 过期后不会重定向到登录页面

python - 在列表中查找元素索引的最快方法

Python (numpy) 子网格的 2D 和 3D 数组平均

python - TensorFlow.Data.Dataset 和 DatasetV1Adapter 一样吗?

python - Keras:如何存储每个纪元后的历史记录?

python - 如何使用vtk在python中绘制鼠标可旋转点云

python - tensorflow 在损失函数中使用输入

tensorflow - 将现有模型以前动态的占位符尺寸设为静态

ubuntu - 如何设置 TensorFlow?