python - 张量不是该图的元素;部署 Keras 模型

标签 python tensorflow flask keras

我正在部署 keras 模型并通过 Flask api 将测试数据发送到模型。我有两个文件:

第一个:我的 Flask 应用程序:

# Let's startup the Flask application
app = Flask(__name__)

# Model reload from jSON:
print('Load model...')
json_file = open('models/model_temp.json', 'r')
loaded_model_json = json_file.read()
json_file.close()
keras_model_loaded = model_from_json(loaded_model_json)
print('Model loaded...')

# Weights reloaded from .h5 inside the model
print('Load weights...')
keras_model_loaded.load_weights("models/Model_temp.h5")
print('Weights loaded...')

# URL that we'll use to make predictions using get and post
@app.route('/predict',methods=['GET','POST'])
def predict():
    data = request.get_json(force=True)
    predict_request = [data["month"],data["day"],data["hour"]] 
    predict_request = np.array(predict_request)
    predict_request = predict_request.reshape(1,-1)
    y_hat = keras_model_loaded.predict(predict_request, batch_size=1, verbose=1)
    return jsonify({'prediction': str(y_hat)}) 

if __name__ == "__main__":
    # Choose the port
    port = int(os.environ.get('PORT', 9000))
    # Run locally
    app.run(host='127.0.0.1', port=port)

第二:我用来将 json 数据发送到 api 端点的文件:

response = rq.get('api url has been removed')
data=response.json()
currentDT = datetime.datetime.now()
Month = currentDT.month
Day = currentDT.day
Hour = currentDT.hour

url= "http://127.0.0.1:9000/predict"
post_data = json.dumps({'month': month, 'day': day, 'hour': hour,})
r = rq.post(url,post_data)

我从 Flask 得到了关于 Tensorflow 的回复:

ValueError:Tensor Tensor("dense_6/BiasAdd:0", shape=(?, 1), dtype=float32) 不是此图的元素。

我的 keras 模型是一个简单的 6 密集层模型,训练时没有错误。

有什么想法吗?

最佳答案

Flask 使用多线程。您遇到的问题是因为 tensorflow 模型未在同一线程中加载和使用。一种解决方法是强制 Tensorflow 使用全局默认图。

加载模型后添加此内容

global graph
graph = tf.get_default_graph() 

在你的预测中

with graph.as_default():
    y_hat = keras_model_loaded.predict(predict_request, batch_size=1, verbose=1)

关于python - 张量不是该图的元素;部署 Keras 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51127344/

相关文章:

python - 从格式化字符串打印 unicode

python - 如何为嵌套列表中的特定字符串创建子列表

python - 属性错误 : ‘LSTMStateTuple’ object has no attribute ‘get_shape’ while building a Seq2Seq Model using Tensorflow

python - 从 GKE 调用 tf.io.gfile 方法时经常出现 DNS 失败 : "Couldn' t resolve host 'www.googleapis.com' "

mongodb - 在 mongodb 聚合管道中展开操作后查找阶段不工作

python - 动态长度Django模型字段

python - 在 GLPK 中将负载分散到不同的等成本变量上

python - 随机生成张量的 Tensorflow 转置

python - flask 测试 : issues with session management with unittests

python - 验证参数时 webargs 出现异常