【小ネタ】[Amazon SageMaker] 既存のモデルのデプロイをJupyter Notebookでやってみました

2020.04.23

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

1 はじめに

CX事業本部の平内(SIN)です

SageMaker Python SDKには、Modelクラスがあり、これをエンドポイントにデプロイできます。 今回は、Jupyter Notebook上で、過去に作成したモデル(ビルトインの物体検出)をデプロイする要領を確認してみました。

参考:https://sagemaker.readthedocs.io/en/stable/model.html

Jupyter Notebookのサンプルでは、データセットから学習して出来上がったモデルを使用して作業するパターンが殆どで、既存モデルから生成できるModelクラスの使い方に、ちょっと戸惑ったので、その覚書です。

2 Jupyter Notebook

Jupyter Notebookの内容は、以下の通りです。

(1) Setup

最初に、SageMakerのセッションや、ロールを準備します。

import sagemaker
from sagemaker import get_execution_role
from sagemaker.amazon.amazon_estimator import get_image_uri

role = get_execution_role()
sess = sagemaker.Session()

(2) Create Model

ここで、Modelクラスのインスタンスを生成しています。

Dockerイメージ(training_image)は、当該モデルを作成したビルトインのobject-detectionです。 また、使用する既存モデル(model.tar.gz)は、S3に配置する必要があります。

ちょっと注意が必要なのは、predictor_clsを指定していないと、deploy()Noneが返され、推論するための識別子を利用できないことです。

from sagemaker.model import Model
from sagemaker.predictor import RealTimePredictor, json_deserializer

class ImagePredictor(RealTimePredictor):
def __init__(self, endpoint_name, sagemaker_session):
super().__init__(endpoint_name, sagemaker_session=sagemaker_session, serializer=None,
deserializer=json_deserializer, content_type='image/jpeg')
training_image = get_image_uri(sess.boto_region_name, 'object-detection', repo_version="latest")
model_data = 's3://sagemaker-working-bucket/Sweets/output/model.tar.gz'

model = Model(role =role,image=training_image,model_data = model_data, predictor_cls=ImagePredictor, sagemaker_session=sess)

(3) Deploy

インスタンスの種類と数を指定してエンドポイントを生成します。

object_detector = model.deploy(initial_instance_count = 1, instance_type = 'ml.m4.xlarge')

(4) Detection

下記は、サンプルコードであるawslabs/amazon-sagemaker-examplesと同じで、テスト画像に結果を描画してます。

import json
​
def visualize_detection(img_file, dets, classes=[], thresh=0.1):
import random
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
​
img=mpimg.imread(img_file)
plt.imshow(img)
height = img.shape[0]
width = img.shape[1]
colors = dict()
for det in dets:
(klass, score, x0, y0, x1, y1) = det
if score < thresh: continue cls_id = int(klass) if cls_id not in colors: colors[cls_id] = (random.random(), random.random(), random.random()) xmin = int(x0 * width) ymin = int(y0 * height) xmax = int(x1 * width) ymax = int(y1 * height) rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, edgecolor=colors[cls_id], linewidth=3.5) plt.gca().add_patch(rect) class_name = str(cls_id) if classes and len(classes) > cls_id:
class_name = classes[cls_id]
plt.gca().text(xmin, ymin - 2,
'{:s} {:.3f}'.format(class_name, score),
bbox=dict(facecolor=colors[cls_id], alpha=0.5),
fontsize=12, color='white')
plt.show()
file_name = 'TestData_Sweets/sweet001.png'
​
with open(file_name, 'rb') as image:
f = image.read()
b = bytearray(f)
ne = open('n.txt','wb')
ne.write(b)
object_detector.content_type = 'image/jpeg'
detections = object_detector.predict(b)
​
print(detections)
object_categories = ['BlackThunder','HomePie','Bisco']
​
threshold = 0.2
visualize_detection(file_name, detections['prediction'], object_categories, threshold)

(5) Delete Endpoint

次のコードでエンドポイントを削除しています。

sagemaker.Session().delete_endpoint(object_detector.endpoint)

3 最後に

今回は、既存のモデルのデプロイをJupyter Notebookでやってみました。何をするにも必要となるModelクラスの利用方法は、しっかり掴みたいと思います。

コードは、下記に起きました。

https://gist.github.com/furuya02/9ecbc1773aff4536f113e2ab8fa6097e