我正在使用 Flask 运行 Web 服务器,当我尝试使用 vgg16 时出现错误,vgg16 是 keras 的预训练 VGG16 模型的全局变量。我不知道为什么会出现这个错误,也不知道它是否与 Tensorflow 后端有关。
这是我的代码:
vgg16 = VGG16(weights='imagenet', include_top=True)
def getVGG16Prediction(img_path):
global vgg16
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
pred = vgg16.predict(x)
return x, sort(decode_predictions(pred, top=3)[0])
@app.route("/uploadMultipleImages", methods=["POST"])
def uploadMultipleImages():
uploaded_files = request.files.getlist("file[]")
for file in uploaded_files:
path = os.path.join(STATIC_PATH, file.filename)
pInput, result = getVGG16Prediction(path)
这是完整的错误:
非常感谢任何评论或建议。谢谢你。
最佳答案
看看avital
对 this github issue 的回答.在这里引用相关部分:
Right after loading or constructing your model, save the TensorFlow graph:
graph = tf.get_default_graph()
In the other thread (or perhaps in an asynchronous event handler), do:
global graph with graph.as_default(): (... do inference here ...)
我稍微修改了一下,并将图表存储在我的应用程序的配置对象中,而不是使其成为全局对象。
TensorFlow documentation为
get_default_graph
解释为什么这是必要的:NOTE: The default graph is a property of the current thread. If you create a new thread, and wish to use the default graph in that thread, you must explicitly add a with g.as_default(): in that thread's function.
关于tensorflow - ValueError : Tensor Tensor(. ..) 不是该图的元素。使用全局变量 keras 模型时,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42013138/