django - 凯拉斯预测 celery 任务不归队

标签 django tensorflow redis celery keras

同步调用时,遵循Keras函数(预测)工作

pred = model.predict(x)

但是当从异步任务队列(Celery)中调用时,它不起作用。
Keras预测函数在异步调用时不会返回任何输出。

堆栈是:Django,Celery,Redis,Keras,TensorFlow

最佳答案

我碰到了这个完全相同的问题,而那家伙真是个兔子洞。想要在这里发布我的解决方案,因为这可能会节省某人一天的工作:

TensorFlow特定于线程的数据结构

在TensorFlow中,当您调用model.predict(或keras.models.load_modelkeras.backend.clear_session或几乎任何其他与TensorFlow后端交互的函数)时,有两种关键的数据结构在幕后起作用:

  • TensorFlow graph,表示您的Keras模型的结构
  • TensorFlow session,它是当前图形与TensorFlow运行时
  • 之间的连接

    在文档中不加挖掘就没有明确明确的一点是, session 和图形都是当前线程的属性。请参阅API文档herehere

    在不同线程中使用TensorFlow模型

    很自然,只想加载一次模型,然后在以后多次调用.predict():

    from keras.models import load_model
    
    MY_MODEL = load_model('path/to/model/file')
    
    def some_worker_function(inputs):
        return MY_MODEL.predict(inputs)
    
    

    在像Celery这样的Web服务器或工作程序池上下文中,这意味着您在导入包含load_model行的模块时将加载模型,然后另一个线程将执行some_worker_function,对包含Keras模型的全局变量运行预测。但是,尝试在装入不同线程的模型上运行预测会产生“张量不是该图的元素”错误。感谢涉及此主题的几篇SO帖子,例如ValueError: Tensor Tensor(...) is not an element of this graph. When using global variable keras model。为了使它起作用,您需要保留所使用的TensorFlow图-如我们之前所见,该图是当前线程的属性。更新后的代码如下所示:

    from keras.models import load_model
    import tensorflow as tf
    
    MY_MODEL = load_model('path/to/model/file')
    MY_GRAPH = tf.get_default_graph()
    
    def some_worker_function(inputs):
        with MY_GRAPH.as_default():
            return MY_MODEL.predict(inputs)
    

    这里有些令人惊讶的变化是:如果您使用Thread,上面的代码就足够了,但是如果您使用Process es,则可以无限期地挂起。 并且默认情况下,Celery使用进程来管理其所有工作池。因此,此时,Celery仍然无法正常工作。

    为什么这仅适用于Thread

    在Python中,Thread与父进程共享相同的全局执行上下文。从Python _thread docs:

    This module provides low-level primitives for working with multiple threads (also called light-weight processes or tasks) — multiple threads of control sharing their global data space.



    因为线程不是实际的单独进程,所以它们使用相同的python解释器,因此受到臭名昭著的全局中断器(GIL)的约束。对于此调查而言,也许更重要的是,它们与父级共享全局数据空间。

    与此相反,Process es是程序产生的实际新进程。这意味着:
  • 新的Python解释器实例(且没有GIL)
  • 全局地址空间是重复的

  • 注意这里的区别。虽然Thread可以访问共享的单个全局Session变量(内部存储在Keras的tensorflow_backend模块中),但是Process es具有Session变量的重复项。

    我对这个问题的最好理解是,Session变量应该表示客户机(进程)与TensorFlow运行时之间的唯一连接,但是由于在派生过程中被复制,因此该连接信息没有得到适当的调整。当尝试使用在其他进程中创建的 session 时,这会导致TensorFlow挂起。 如果有人对TensorFlow的工作原理有更深入的了解,我很想听听!

    解决方案/解决方法

    我调整了Celery,以便它使用Thread而不是Process进行合并。这种方法有一些缺点(请参见上面的GIL注释),但这使我们只能加载一次模型。由于TensorFlow运行时会最大化所有CPU内核,因此我们实际上并没有CPU限制(因为它不是用Python编写的,因此可以避开GIL)。您必须为Celery提供一个单独的库才能进行基于线程的池化。该文档建议了两个选项: gevent eventlet 。然后,您将选择的库传递给工作程序via the --pool command line argument

    或者,似乎其他的Keras后端(如Theano)(如您已经发现@ pX0r)没有此问题。这是有道理的,因为这些问题与TensorFlow实现细节紧密相关。我个人尚未尝试过Theano,所以您的里程可能会有所不同。

    我知道这个问题是在不久前发布的,但是问题仍然存在,因此希望可以对您有所帮助!

    关于django - 凯拉斯预测 celery 任务不归队,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45459205/

    相关文章:

    django - Django 能否自动创建相关的一对一模型?

    tensorflow - 带条件的 LSTM

    java - 如何使用 RedisTemplate 访问由 spring redis session 存储的散列 key ?

    javascript - Django 使用前端 python 类中的函数?

    python - Django 导入 csv HTML

    python - Django REST 框架 : nested relationship: non_field_errors

    tensorflow - DeepMind 的 Sonnet 能提供 Keras 不能提供的什么?

    java - tensorflow java模型推理将获取的张量转换为字符串?

    django - celery 没有连接到redis服务器

    redis - 主从和发布订阅连接