我有一个 flask 应用程序,它首先加载一个 keras 模型,然后执行预测功能。根据this answer我将 tensorflow 图保存在一个全局变量中,并将相同的图用于预测函数。
def load_model():
load_model_from_file()
global graph
graph = tf.get_default_graph()
def predict():
with graph.as_default():
tagger = Tagger(self.model, preprocessor=self.p)
return tagger.analyze(words)
@app.route('/predict', methods=['GET'])
def load_and_predict():
load_model()
predict()
但是,每当向服务器发送多个请求时,这就会导致问题。如何使这段代码线程安全,或者更具体地说,如何在多线程环境中正确使用 tensorflow 图?
最佳答案
通常,在 tensorflow 中使用线程时应该使用 session 。
intra_parallel_thread_tf = 1
inter_parallel_thread_tf = 1
session_conf = tf.ConfigProto(intra_op_parallelism_threads=intra_parallel_thread_tf,
inter_op_parallelism_threads=inter_parallel_thread_tf)
tf.Session(graph=tf.get_default_graph(), config=session_conf)
GRAPH = tf.get_default_graph()
但这很笼统。这也取决于你得到的错误。
关于python - Tensorflow 中线程安全图的使用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54129714/