我有一个在单台机器上训练的模型,没有使用 Estimator,我希望在 Google 云 AI 平台(ML 引擎)上提供最终训练的模型。我使用 SavedModelBuilder 将卡住图导出为 SavedModel并将其部署到AI平台上。它适用于小型输入图像,但为了能够接受大型输入图像进行在线预测,我需要将其更改为接受 b64 编码字符串 ({'image_bytes': {'b64': base64.b64encode( jpeg_data).decode()}}
) 如果使用估算器,则通过 serving_input_fn
将其转换为所需的张量。
如果我不使用估算器,我有哪些选择?如果我有一个卡住的图表或从 SavedModelBuilder 创建的 SavedModel,有没有办法在导出/保存时拥有类似于估算器的 serving_input_fn
的东西?
这是我用于导出的代码:
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
export_dir = 'serving_model/'
graph_pb = 'model.pb'
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
with tf.gfile.GFile(graph_pb, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sigs = {}
with tf.Session(graph=tf.Graph()) as sess:
# name="" is important to ensure we don't get spurious prefixing
tf.import_graph_def(graph_def, name="")
g = tf.get_default_graph()
inp = g.get_tensor_by_name("image_bytes:0")
out_f1 = g.get_tensor_by_name("feature_1:0")
out_f2 = g.get_tensor_by_name("feature_2:0")
sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
tf.saved_model.signature_def_utils.predict_signature_def(
{"image_bytes": inp}, {"f1": out_f1, "f2": out_f2})
builder.add_meta_graph_and_variables(sess,
[tag_constants.SERVING],
strip_default_attrs=True,
signature_def_map=sigs)
builder.save()
最佳答案
使用@tf.function指定服务签名。下面是一个调用 Keras 的示例:
class ExportModel(tf.keras.Model):
def __init__(self, model):
super().__init__(self)
self.model = model
@tf.function(input_signature=[
tf.TensorSpec([None,], dtype='int32', name='a'),
tf.TensorSpec([None,], dtype='int32', name='b')
])
def serving_fn(self, a, b):
return {
'pred' : self.model({'a': a, 'b': b}) #, steps=1)
}
def save(self, export_path):
sigs = {
'serving_default' : self.serving_fn
}
tf.keras.backend.set_learning_phase(0) # inference only
tf.saved_model.save(self, export_path, signatures=sigs)
sm = ExportModel(model)
sm.save(EXPORT_PATH)
关于python - 如何在不使用估计器的情况下为训练的 Tensorflow 模型编写服务输入函数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57387086/