Pytorchで独自のアルゴリズムで学習した骨格検知モデルをSagemaker上でリアルタイム推論してみた。
せーのでございます。
Sagemakerによるモデル学習のコツはだいぶ掴んできたものの、いままで学習したモデルはエッジにデプロイしていたのでSagemaker上の推論処理はほぼ試したことがありませんでした。
ということで今回はSagemakerでの推論処理にチャレンジしてみたいと思います。
Sagemakerで扱えるモデルのタイプは3種類
Sagemakerでは学習、推論ともに大きく分けると3つのタイプに分かれます。
- Built-in アルゴリズム: Sagemakerにて用意したアルゴリズムを使用して学習、推論を行う。一番簡単。
- Sagemaker コンテナ: 学習や推論に使うコンテナをSagemakerが用意し、そこに学習コード、推論コードを注入して使う。Tensorflow, Pytorch, MXNet, Chainerなどメジャーなフレームワークのコンテナが用意されている。
- 独自コンテナ : 学習や推論に使うコードを含んだコンテナを独自に構築し、ECRにプッシュ。そこからインスタンスを起動して学習、推論を行う。起動関数名やAPI名などが決まっているため、やや面倒。
今回は骨格検知のスタンダードでもあるOpenposeを変更して軽くしたアルゴリズム「Real-time 2D Multi-Person Pose Estimation on CPU: Lightweight OpenPose」をSagemakerコンテナに注入して学習したモデルを使って推論処理をしてみたいと思います。
やってみた
ローカルで試してみる
まずは作ったモデルをローカル上で動くかどうか試してみたいと思います。
ちょうど動かす用にdemo.pyという推論用のコードがあるので、それに独自に学習したモデルをセットして推論してみます。
$ python demo.py --checkpoint-path data/checkpoint_iter_84600.pth --image images/test.jpg --cpu --track 0
無事、動きました。
ただ今回は画像で返すのではなく、キーポイントの座標と信頼度で返すため、コードを少しいじります。
pose_keypoints[kpt_id, 0] = int(all_keypoints[int(pose_entries[n][kpt_id]), 0]) pose_keypoints[kpt_id, 1] = int(all_keypoints[int(pose_entries[n][kpt_id]), 1]) #confidencdeを3番目に入れておく(追加) res_confidence[kpt_id] = all_keypoints[int(pose_entries[n][kpt_id]), 2] ... ... #confidenceの値をキーポイントを保持しているクラスに追加(変更) pose = Pose(pose_keypoints, pose_entries[n][18], res_confidence) current_poses.append(pose) ... ... for pose in current_poses: cv2.rectangle(img, (pose.bbox[0], pose.bbox[1]), (pose.bbox[0] + pose.bbox[2], pose.bbox[1] + pose.bbox[3]), (0, 255, 0)) #JSON形式でキーポイント座標と信頼度を表示(追加) persons = [] for pose_item in current_poses: person_keypoints = [] for idx, keypoint in enumerate(pose_item.keypoints): if keypoint[0] == -1: continue person_keypoint = [] person_keypoint.append(int(keypoint[0])) person_keypoint.append(int(keypoint[1])) person_keypoint.append(float(pose_item.res_confidence[idx])) person_keypoints.append(person_keypoint) persons.append(person_keypoints) print(json.dumps(persons, indent=2))
... ... def __init__(self, keypoints, confidence, res_confidence): super().__init__() self.keypoints = keypoints self.confidence = confidence #保管用にプロパティを追加 self.res_confidence = res_confidence self.bbox = Pose.get_bbox(self.keypoints) self.id = None self.filters = [[OneEuroFilter(), OneEuroFilter()] for _ in range(Pose.num_kpts)]
改めて流してみます。
$ python demo.py --checkpoint-path data/checkpoint_iter_84600_V10.pth --image images/test.jpg --cpu --track 0 inference... [ [ [ 585, 176, 0.7670338749885559 ], [ 562, 187, 0.9505833983421326 ], [ 588, 195, 0.909247100353241 ], [ 581, 172, 0.287502646446228 ], [ 570, 153, 0.582249641418457 ], [ 532, 176, 0.9301580190658569 ], [ 525, 131, 0.8460273742675781 ], [ 532, 93, 0.6852182149887085 ] ], [ [ 270, 165, 0.898603618144989 ], [ 285, 183, 0.9406200647354126 ], [ 311, 180, 0.8064700961112976 ], [ 330, 146, 0.8526690602302551 ], [ 315, 112, 0.9707748889923096 ], [ 258, 187, 0.8102858066558838 ], [ 255, 172, 0.5221325159072876 ], [ 258, 172, 0.34177863597869873 ] ], [ [ 382, 138, 0.962169885635376 ], [ 390, 180, 0.9582849144935608 ], [ 438, 176, 0.8331146836280823 ], [ 450, 198, 0.7568672895431519 ], [ 427, 202, 0.43545061349868774 ], [ 345, 187, 0.9023403525352478 ], [ 348, 210, 0.34671860933303833 ], [ 356, 198, 0.18269583582878113 ] ] ]
これで各キーポイントの座標と信頼度が取れました。今度はこれをSagemakerに移植します。
Sagemakerへの移植
まずはSagemaker Notebooksにgitからソースごとアップします。これは素の状態で立ち上げたNotebooksのTerminalからgit cloneコマンドを打ってもよし、tarファイルなどにローカルのソースを固めてアップロード、その後Terminalから解凍コマンドを打ってもよし、です。私はinferenceというフォルダを作って、そこに固めてソースをアップ後、解凍しました。
ソースを整理するため推論処理にいらないソースは削除していますが、試すだけなら特にあっても問題ありません。
次に学習させたモデルを「model.tar.gz」というファイル名でtar圧縮し、S3に置いておきます。学習にSagemakerを使っているなら学習後のモデルはこのファイル名でS3に保管されているので何もしなくてもOKです。
Sagemaker用にソースコードを改変
Sagemakerの推論コードはモデルのロード、データの入力、推論、結果の出力の関数名が決まっています。
関数名 | 引数 | 戻り値 | 役割 |
---|---|---|---|
input_fn | request_body, request_content_type | input_object | データを指定したContent Typeで入力する |
model_fn | model_dir | model | モデルをロードする |
predict_fn | input_object, model | prediction | 入力値とモデルを元に推論する |
output_fn | prediction, response_content_type | output | 推論結果を指定したContent Typeで出力する |
またSagemakerの推論は基本API越しに行われるため、input_fnのrequest_bodyにはシリアライズされた値が入ります。
なのでinput_fnには入力値をデシリアライズ、前処理してpredict_fnに渡す処理、model_fnにはモデルをロードする処理、predict_fnには推論に当たる処理、output_fnはpredictionをJSON形式に変更した上でcontent typeを"text/json"にして出力する、という処理を入れればOKです。
Real-time 2D Multi-Person Pose Estimation on CPU: Lightweight OpenPoseに置いてそれらはここの部分になります。
- input_fn: 新たに作成。requestbodyをCV2などを使ってデシリアライズ、オブジェクト化する
- model_fn: アルゴリズムを作成、torch.load、load_stateしている部分
- predict_fn: run_demo()している部分
- output_fn: ローカルにて新たに作ったJSONを作成する部分
では実装します。
IMAGE_CONTENT_TYPE = 'application/x-image' def input_fn(request_body, content_type=IMAGE_CONTENT_TYPE): """入力データの形式変換.""" logger.info('START input_fn') logger.info(f'content_type: {content_type}') logger.info(f'type: {type(request_body)}') logger.info('Deserializing the input data.') # process an image uploaded to the endpoint if content_type == IMAGE_CONTENT_TYPE: input_data = cv2.imdecode(np.frombuffer(request_body, dtype='uint8'), cv2.IMREAD_UNCHANGED) logger.info(f'type: {type(input_data)}') else: logger.info("CONTENT ELSE") # TODO: content_typeに応じてデータ型変換 logger.error(f"content_type invalid: {content_type}") input_data = {"errors": [f"content_type invalid: {content_type}"]} logger.info('END input_fn') return input_data
def model_fn(model_dir): logger.info('START model_fn') net = PoseEstimationWithMobileNet() # モデルのパラメータ設定 with open(os.path.join(model_dir, 'checkpoint_iter_84600.pth'), 'rb') as f: # checkpoint = torch.load(f, map_location='cuda') checkpoint = torch.load(f, map_location='cpu') load_state(net, checkpoint) logger.info('END model_fn') return net
def predict_fn(input_data, net): """推論.""" logger.info('START predict_fn') logger.info(f'type: {type(input_data)}') # print(input_data) return run_demo(net, input_data, 256, True, 1, 1)
JSON_CONTENT_TYPE = 'application/json' def output_fn(prediction, accept=JSON_CONTENT_TYPE): """出力データの形式変換.""" logger.info('START output_fn') logger.info(f"accept: {accept}") persons = [] for pose_item in prediction: person_keypoints = [] for idx, keypoint in enumerate(pose_item.keypoints): if keypoint[0] == -1: continue person_keypoint = [] person_keypoint.append(int(keypoint[0])) person_keypoint.append(int(keypoint[1])) person_keypoint.append(float(pose_item.res_confidence[idx])) person_keypoints.append(person_keypoint) persons.append(person_keypoints) response = json.dumps({"results": persons}, indent=2) content_type = JSON_CONTENT_TYPE return response, content_type
次にこのモデルの操作用としてJupyter Notebooksを作成します。
まずSagemakerをインポートします。
import sagemaker sagemaker.__version__
次にロールを取得します。
from sagemaker import get_execution_role role = get_execution_role()
モデルを作成します。試しなので小さいインスタンスをしていしています。もし推論コードがGPUを使うものなら、ここのインスタンス選びもGPUのあるものを中心に選んでみてください。
from sagemaker.pytorch.model import PyTorchModel s3_path="s3://XXXXXXXXXXXXXXXXoutputs/sagemaker-pytorch-XXXXXXXXXXXXXXXX/output/model.tar.gz" pytorch_model = sagemaker.pytorch.model.PyTorchModel(model_data=s3_path, role=role, framework_version='1.4.0', py_version="py3", source_dir='inference', entry_point="inf.py") deploy_params = { 'endpoint_name' : 'sagemakerinftest', 'instance_type' : 'ml.t2.medium', 'initial_instance_count' : 1 } # デプロイ predictor = pytorch_model.deploy(**deploy_params)
テスト用の画像をアップし、読み込み、確認用に表示させてみます。
import matplotlib.pyplot as plt import numpy as np filename = 'test/test.jpg' # インライン表示 %matplotlib inline # Read image into memory payload = None with open(filename, 'rb') as f: payload = f.read() #画像の表示 plt.imshow(payload) plt.show()
最後はSagemaker Python SDKを使って推論処理を行います。
from sagemaker.serializers import IdentitySerializer from sagemaker.deserializers import JSONDeserializer predictor.serializer = IdentitySerializer() predictor.deserializer = JSONDeserializer() results = predictor.predict(payload, initial_args={'ContentType': 'application/x-image'}) print(results)
ここでのポイントは3点です。
- 画像のシリアライザー(IdentitySerializer())をpredictor に入れます。これはinput_fnに渡るrequest_bodyに適用されます
- 文字列(JSON)のデシリアライザー(JSONDeserializer())をpredictorに入れます。これはoutput_fnの出力であるoutputに使われます
- predict()時に流すcontent typeを指定します。これによりinput_fnのrequest_content_typeの種類が決まります
この3つを忘れないように実装します。
これで準備はOKです。Sagemaker Notebooksを上から順に流してみると
{'results': [[[585, 176, 0.7670751214027405], [562, 187, 0.9505834579467773], [588, 195, 0.9093048572540283], [581, 172, 0.28749969601631165], [570, 153, 0.5822606086730957], [532, 176, 0.9301615357398987], [525, 131, 0.8460462093353271], [532, 93, 0.6852066516876221]], [[270, 165, 0.8985937237739563], [285, 183, 0.9405729174613953], [311, 180, 0.8065033555030823], [330, 146, 0.8526697158813477], [315, 112, 0.9707766175270081], [258, 187, 0.8102609515190125], [255, 172, 0.5221396088600159], [258, 172, 0.3417767882347107]], [[382, 138, 0.962157130241394], [390, 180, 0.9582875967025757], [438, 176, 0.8330955505371094], [450, 198, 0.7568647265434265], [427, 202, 0.43543940782546997], [345, 187, 0.9023558497428894], [348, 210, 0.3467266857624054], [356, 198, 0.1827089786529541]]]}
このようにローカルの時と同じJSON形式でキーポイント座標と信頼度が返ってきます。
念の為、Sagemaker python SDKからだけではなく、API越しにも叩いてみます。
import boto3 import json client = boto3.client('sagemaker-runtime') # Request. response = client.invoke_endpoint( EndpointName='sagemakerinftest', ContentType='application/x-image', Accept='application/json', Body=bytearray(payload) ) response_dict = json.loads(response['Body'].read().decode("utf-8")) print(json.dumps(response_dict))
これでも同じようにJSONが返ってきます。
これで出来上がりです!
まとめ
以上、Sagemakerによる推論を実装してみました。
一番引っかかったのはシリアライザー、デシリアライザーのところです。何もしないとデフォルトではNPY形式でのやり取りとなるため、input_fn内でうまくデシリアライズできずにハマりました。
今回は解決策としてシリアライザー、デシリアライザーをpredictorにセットしてpredict()しましたが、他にもModelの段階でpredict_clsという引数にシリアライザー、デシリアライザーをセットしたオーバーライドクラスを用意してもうまくいきます。
推論用のAPIサーバーを独自で立てるのは結構手間なので、このエントリを参考にみなさんもご自分のモデルを手軽に試せる推論環境をSagemakerを使って構築してみてはいかがでしょうか。
参考リンク
- https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html
- https://aws.amazon.com/jp/getting-started/hands-on/build-train-deploy-machine-learning-model-sagemaker/#
- https://sagemaker.readthedocs.io/en/stable/api/inference/predictors.html
- https://stackoverflow.com/questions/63568274/how-to-use-serializer-and-deserializer-in-sagemaker-2
- https://sagemaker.readthedocs.io/en/stable/api/inference/model.html?highlight=deploy#sagemaker.model.Model.deploy
- https://sagemaker.readthedocs.io/en/stable/api/inference/serializers.html#sagemaker.serializers.BaseSerializer
- https://sagemaker.readthedocs.io/en/stable/api/inference/deserializers.html#sagemaker.deserializers.BaseDeserializer
- https://sagemaker.readthedocs.io/en/stable/v2.html