我已经使用以下代码部署了一个 sagemaker 端点:
from sagemaker.pytorch import PyTorchModel
from sagemaker import get_execution_role, Session
sess = Session()
role = get_execution_role()
model = PyTorchModel(model_data=my_trained_model_location,
role=role,
sagemaker_session=sess,
framework_version='1.5.0',
entry_point='inference.py',
source_dir='.')
predictor = model.deploy(initial_instance_count=1,
instance_type='ml.m4.xlarge',
endpoint_name='my_endpoint')
如果我运行:
import numpy as np
pseudo_data = [np.random.randn(1, 300), np.random.randn(6, 300), np.random.randn(3, 300), np.random.randn(7, 300), np.random.randn(5, 300)] # input data is a list of 2D numpy arrays with variable first dimension and fixed second dimension
result = predictor.predict(pseudo_data)
我可以生成没有错误的结果。但是,如果我想调用端点并通过运行进行预测:
from sagemaker.predictor import RealTimePredictor
predictor = RealTimePredictor(endpoint='my_endpoint')
result = predictor.predict(pseudo_data)
我会得到一个错误:
Traceback (most recent call last):
File "default_local.py", line 77, in <module>
score = predictor.predict(input_data)
File "/home/biggytruck/.local/lib/python3.6/site-packages/sagemaker/predictor.py", line 113, in predict
response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
File "/home/biggytruck/.local/lib/python3.6/site-packages/botocore/client.py", line 316, in _api_call
return self._make_api_call(operation_name, kwargs)
File "/home/biggytruck/.local/lib/python3.6/site-packages/botocore/client.py", line 608, in _make_api_call
api_params, operation_model, context=request_context)
File "/home/biggytruck/.local/lib/python3.6/site-packages/botocore/client.py", line 656, in _convert_to_request_dict
api_params, operation_model)
File "/home/biggytruck/.local/lib/python3.6/site-packages/botocore/validate.py", line 297, in serialize_to_request
raise ParamValidationError(report=report.generate_report())
botocore.exceptions.ParamValidationError: Parameter validation failed:
Invalid type for parameter Body
根据我的理解,发生错误是因为我没有将 inference.py
作为入口点文件传递,这是处理输入所必需的,因为它不是 Sagemaker 支持的标准格式.但是,sagemaker.predictor.RealTimePredictor
不允许我定义入口点文件。我该如何解决这个问题?
最佳答案
您看到的错误是从客户端 SageMaker Python SDK 库引发的,而不是您发布的远程端点。
这是 data
的文档参数(在你的例子中,这是 pseudo_data
)
data (object) – Input data for which you want the model to provide inference. If a serializer was specified when creating the RealTimePredictor, the result of the serializer is sent as input data. Otherwise the data must be sequence of bytes, and the predict method then sends the bytes in the request body as is.
我的猜测是pseudo_data
不是 SageMaker Python SDK 期望的类型,它是 bytes
的序列.
关于amazon-web-services - 使用自定义推理脚本调用 sagemaker 端点,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62744016/