Python Beam 无法 pickle/dill 大型 Tensorflow 模型

标签 python tensorflow google-cloud-dataflow tensorflow-serving

我们正在尝试内联提供图像处理模型(在 Tensorflow 中),以便我们不必出于速度目的而对 REST 服务或 Cloud-ML/ML-Engine 模型进行外部调用。

我们不想在每次推理时都尝试加载模型,而是想测试是否可以为beam.DoFn对象的每个实例将模型加载到内存中,这样我们就可以减少加载和服务时间对于模型。

例如

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function

    import tensorflow as tf
    import numpy as np


    class InferenceFn(object):

      def __init__(self, model_full_path,):
        super(InferenceFn, self).__init__()
        self.model_full_path = model_full_path
        self.graph = None
        self.create_graph()


      def create_graph(self):
        if not tf.gfile.FastGFile(self.model_full_path):
          self.download_model_file()
        with tf.Graph().as_default() as graph:
          with tf.gfile.FastGFile(self.model_full_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(graph_def, name='')
        self.graph = graph

当它不是 beam.DoFn 而只是常规类时,它可以在本地正常运行,但是当它转换为 DoFn 并且我尝试使用 Cloud Dataflow 远程执行它时,作业会失败,因为在序列化/酸洗期间,我想相信它试图序列化整个模型

例如 Example of Error

有没有办法规避这个问题或阻止 python/dataflow 尝试序列化模型?

最佳答案

是的——将模型存储为 DoFn 上的字段需要将其序列化,以便将该代码传递给每个工作人员。您应该查看以下内容:

  1. 安排让每个工作人员都可以使用模型文件。 Python dependencies document 中对数据流进行了描述。 .
  2. 在您的 DoFn 中实现 start_bundle 方法并让它读取文件并将其存储在本地线程中。

这可确保文件的内容不会在本地计算机上读取并腌制,而是使每个工作人员都可以使用该文件,然后读入。

关于Python Beam 无法 pickle/dill 大型 Tensorflow 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45171264/

相关文章:

python - Theano/Pylearn2。如何并行训练?

python - 如果我在任务中发送 http 请求,为什么我的 Airflow 会挂起?

python - tensorflow 的线性回归得到明显的均方误差

java - 如何检查为什么作业在 Google Dataflow 上被杀死(可能 OOM)

python - 在 apache beam 的窗口中聚合数据

design-patterns - 在 apache beam 中调用外部 API 的更好方法

Python crontab 不工作

python - Dart - Base64 字符串不等于 python

machine-learning - 使用 tensorflow 后端在 keras 中定义 pinball 损失函数

python - 在 Tensorflow 中计算权重更新比率