【小ネタ】[Amazon SageMaker] 既存のモデルのデプロイをJupyter Notebookでやってみました

2020.04.23

1 はじめに

CX事業本部の平内(SIN)です

SageMaker Python SDKには、Modelクラスがあり、これをエンドポイントにデプロイできます。 今回は、Jupyter Notebook上で、過去に作成したモデル(ビルトインの物体検出)をデプロイする要領を確認してみました。
参考:https://sagemaker.readthedocs.io/en/stable/model.html

Jupyter Notebookのサンプルでは、データセットから学習して出来上がったモデルを使用して作業するパターンが殆どで、既存モデルから生成できるModelクラスの使い方に、ちょっと戸惑ったので、その覚書です。

2 Jupyter Notebook

Jupyter Notebookの内容は、以下の通りです。

(1) Setup

最初に、SageMakerのセッションや、ロールを準備します。

import sagemaker
from sagemaker import get_execution_role
from sagemaker.amazon.amazon_estimator import get_image_uri

role = get_execution_role()
sess = sagemaker.Session()

(2) Create Model

ここで、Modelクラスのインスタンスを生成しています。

Dockerイメージ(training_image)は、当該モデルを作成したビルトインのobject-detectionです。 また、使用する既存モデル(model.tar.gz)は、S3に配置する必要があります。

ちょっと注意が必要なのは、predictor_clsを指定していないと、deploy()Noneが返され、推論するための識別子を利用できないことです。

from sagemaker.model import Model
from sagemaker.predictor import RealTimePredictor, json_deserializer

class ImagePredictor(RealTimePredictor):
    def __init__(self, endpoint_name, sagemaker_session):
        super().__init__(endpoint_name, sagemaker_session=sagemaker_session, serializer=None, 
                         deserializer=json_deserializer, content_type='image/jpeg')
training_image = get_image_uri(sess.boto_region_name, 'object-detection', repo_version="latest")
model_data = 's3://sagemaker-working-bucket/Sweets/output/model.tar.gz'

model = Model(role =role,image=training_image,model_data = model_data, predictor_cls=ImagePredictor, sagemaker_session=sess)

(3) Deploy

インスタンスの種類と数を指定してエンドポイントを生成します。

object_detector = model.deploy(initial_instance_count = 1, instance_type = 'ml.m4.xlarge')

(4) Detection

下記は、サンプルコードであるawslabs/amazon-sagemaker-examplesと同じで、テスト画像に結果を描画してます。

import json
​
def visualize_detection(img_file, dets, classes=[], thresh=0.1):
        import random
        import matplotlib.pyplot as plt
        import matplotlib.image as mpimg
​
        img=mpimg.imread(img_file)
        plt.imshow(img)
        height = img.shape[0]
        width = img.shape[1]
        colors = dict()
        for det in dets:
            (klass, score, x0, y0, x1, y1) = det
            if score < thresh:
                continue
            cls_id = int(klass)
            if cls_id not in colors:
                colors[cls_id] = (random.random(), random.random(), random.random())
            xmin = int(x0 * width)
            ymin = int(y0 * height)
            xmax = int(x1 * width)
            ymax = int(y1 * height)
            rect = plt.Rectangle((xmin, ymin), xmax - xmin,
                                 ymax - ymin, fill=False,
                                 edgecolor=colors[cls_id],
                                 linewidth=3.5)
            plt.gca().add_patch(rect)
            class_name = str(cls_id)
            if classes and len(classes) > cls_id:
                class_name = classes[cls_id]
            plt.gca().text(xmin, ymin - 2,
                            '{:s} {:.3f}'.format(class_name, score),
                            bbox=dict(facecolor=colors[cls_id], alpha=0.5),
                                    fontsize=12, color='white')
        plt.show()
file_name = 'TestData_Sweets/sweet001.png'
​
with open(file_name, 'rb') as image:
    f = image.read()
    b = bytearray(f)
    ne = open('n.txt','wb')
    ne.write(b)
object_detector.content_type = 'image/jpeg'
detections = object_detector.predict(b)
​
print(detections)
object_categories = ['BlackThunder','HomePie','Bisco']
​
threshold = 0.2
visualize_detection(file_name, detections['prediction'], object_categories, threshold)

(5) Delete Endpoint

次のコードでエンドポイントを削除しています。

sagemaker.Session().delete_endpoint(object_detector.endpoint)

3 最後に

今回は、既存のモデルのデプロイをJupyter Notebookでやってみました。何をするにも必要となるModelクラスの利用方法は、しっかり掴みたいと思います。

コードは、下記に起きました。
https://gist.github.com/furuya02/9ecbc1773aff4536f113e2ab8fa6097e