python - Tensorflow 中线程安全图的使用

标签 python multithreading tensorflow flask keras

我有一个 flask 应用程序,它首先加载一个 keras 模型,然后执行预测功能。根据this answer我将 tensorflow 图保存在一个全局变量中,并将相同的图用于预测函数。

def load_model():
    load_model_from_file()
    global graph
    graph = tf.get_default_graph()

def predict():
    with graph.as_default():
        tagger = Tagger(self.model, preprocessor=self.p)
        return tagger.analyze(words)

@app.route('/predict', methods=['GET'])
def load_and_predict():
    load_model()
    predict()

但是,每当向服务器发送多个请求时,这就会导致问题。如何使这段代码线程安全,或者更具体地说,如何在多线程环境中正确使用 tensorflow 图?

最佳答案

通常,在 tensorflow 中使用线程时应该使用 session 。

intra_parallel_thread_tf = 1
inter_parallel_thread_tf = 1

session_conf = tf.ConfigProto(intra_op_parallelism_threads=intra_parallel_thread_tf,
                          inter_op_parallelism_threads=inter_parallel_thread_tf)

tf.Session(graph=tf.get_default_graph(), config=session_conf)
GRAPH = tf.get_default_graph()

但这很笼统。这也取决于你得到的错误。

关于python - Tensorflow 中线程安全图的使用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54129714/

相关文章:

c++ - std::vector保留并调整NUMA位置的大小

python - 是否可以用现有图形中的常量替换占位符?

python - 使用基本的低级 TensorFlow 训练循环训练 tf.keras 模型不起作用

python - Java 网关进程在向驱动程序发送其端口号之前退出

python - 从 CNN 层获取过滤器值

python - 如何查看 2 个日期产生的收入超过 1 个日期的总收入的实例有多少?

python - 无法在 Python 中比较字符串

C++ 线程没有像我预期的那样工作?

android - 套接字、读和写线程

machine-learning - 近端梯度下降的 l1_regularization_strength 和 l2_regularization_strength 的定义