[Amazon SageMaker] オブジェクト検出のモデルをWebカメラで確認してみました

[Amazon SageMaker] オブジェクト検出のモデルをWebカメラで確認してみました

Clock Icon2020.05.04

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

1 はじめに

CX事業本部の平内(SIN)です。

Amazon SageMaker(以下、SageMaker)の組み込みアルゴリズムであるオブジェクト検出で作成したモデルをMacに接続したWebカメラで確認してみました。

最初に、作業している様子です。

ml.m4.4xlargeにデプロイして使用していますが、3個のラベル(画像80枚)を学習しただけの小さなモデルなので、比較的高速に動作しており動画としてもなんとか違和感無いものとなっていると思います。

2 エンドポイント作成

モデルは、S3上に配置し、AWS SDK for Python による下記のコードでデプロイしてます。

  • モデルの作成
  • エンドポイント構成の作成
  • エンドポイントの作成
import boto3
from boto3.session import Session
from sagemaker.amazon.amazon_estimator import get_image_uri

# S3上のモデルのURL
modelUrl = "s3://sagemaker-working-bucket/my-sample/output/model.tar.gz"
# モデルのホスティングに使用するRoleのARN
role = 'arn:aws:iam::xxxxxxxxxxxx:role/service-role/AmazonSageMaker-ExecutionRole-20200208T113545'
profile = "developer"
region = "ap-northeast-1"

class SageMaker():
    def __init__(self, profile, region, role):
        self.__session = Session(profile_name=profile, region_name=region)
        self.__client = self.__session.client('sagemaker')
        self.__role = role

        self.__modelName = 'sampleModel'
        self.__configName = 'sampleConfig'
        self.__endPointName = 'sampleEndPoint'

    def delete(self):
        self.__client.delete_endpoint(
            EndpointName = self.__endPointName
        )
        self.__client.delete_endpoint_config(
            EndpointConfigName = self.__configName
        )
        self.__client.delete_model(
            ModelName = self.__modelName
        )

    def create(self, modelUrl):
        self.__createModel(modelUrl)
        self.__createConfig()
        self.__createEndPoint()

    # モデル作成
    def __createModel(self,  modelUrl):
        imageName = get_image_uri(region, 'object-detection', repo_version="latest")
        model_params = {
            "ExecutionRoleArn": self.__role, 
            "ModelName": self.__modelName,
            "PrimaryContainer": {
                "Image": imageName, 
                "ModelDataUrl": modelUrl
            }
        }
        response = self.__client.create_model(**model_params)
        print("create model:{} {}".format(self.__modelName, response["ModelArn"]))

    # エンドポイント構成の作成
    def __createConfig(self):
        response = self.__client.create_endpoint_config(
            EndpointConfigName = self.__configName,
            ProductionVariants=[
                {
                    'VariantName': 'VariantName',
                    'ModelName': self.__modelName,
                    'InitialInstanceCount': 1,
                    'InstanceType': 'ml.m4.4xlarge'
                },
            ]
        )
        print("create config:{} {}".format(self.__configName, response["EndpointConfigArn"]))

    # エンドポイントの作成
    def __createEndPoint(self):
        response = self.__client.create_endpoint(
            EndpointName = self.__endPointName,
            EndpointConfigName = self.__configName
        )
        print("create endPoint:{} {}".format(self.__endPointName, response["EndpointArn"]))

sageMake = SageMaker(profile, region, role)

# Endpointの作成
sageMake.create(modelUrl)

# Endpointの削除
#sageMake.delete()

モデルエンドポイント構成、及び、エンドポイントの名前は、コードの中で、定義されており、以下のようになっています。

モデル(sampleModel)

エンドポイント構成(sampleConfig) エンドポイント(sampleEndPoint)

生成されたエンドポイントのStatusが、inServiceになったら利用可能です。(8〜10分かかります)

3 テスト

OpenCVでWebカメラの画像を取得し、エンドポイントに投げて、結果を描画しているコードは、以下の通りです。 信頼度は、0.6(confidence > 0.6)以上のものだけが枠線表示の対象になっています。

from boto3.session import Session
import json
import cv2

profile = 'developer'
endPoint = 'sampleEndPoint'
categories = ['Bisco','BlackThunder','Alfort']

deviceId = 1 # Webカメラのデバイスインデックス
height = 600
width = 800
linewidth = 2
colors = [(0,0,175),(175,0,0),(0,175,0)]

class SageMaker():
    def __init__(self, profile, endPoint):
        self.__endPoint = endPoint
        session = Session(profile_name = profile)
        self.__client = session.client('sagemaker-runtime')

    def invoke(self, image):
        data = self.__client.invoke_endpoint(
            EndpointName = self.__endPoint,
            Body = image,
            ContentType='image/jpeg'
        )
        results = data['Body'].read()
        return json.loads(results)

sageMake = SageMaker(profile, endPoint)

cap = cv2.VideoCapture(deviceId)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)

fps = cap.get(cv2.CAP_PROP_FPS)
width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
print("FPS:{} WIDTH:{} HEIGHT:{}".format(fps, width, height))

while True:

    # カメラ画像取得
    ret, frame = cap.read()
    if(frame is None):
        continue

    _, jpg = cv2.imencode('.jpg', frame)
    detections = sageMake.invoke(jpg.tostring())

    for detection in detections["prediction"]:
        clsId = int(detection[0])
        confidence = detection[1]
        x1 = int(detection[2] * width)
        y1 = int(detection[3] * height)
        x2 = int(detection[4] * width)
        y2 = int(detection[5] * height)
        label = "{} {:.2f}".format(categories[clsId], confidence)
        if(confidence > 0.6): # 信頼度
            frame = cv2.rectangle(frame,(x1, y1), (x2, y2), colors[clsId],linewidth)
            frame = cv2.rectangle(frame,(x1, y1), (x1 + 150,y1-20), colors[clsId], -1)
            cv2.putText(frame,label,(x1+2, y1-2), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 1, cv2.LINE_AA)

    cv2.imshow('frame', frame)
    cv2.waitKey(1)

cap.release()
cv2.destroyAllWindows()

4 エンドポイント削除

先のコードでsageMake.delete()を有効にして実行すると生成したリソースは全て削除されます。

  • エンドポイント(sampleEndPoint)
  • エンドポイント構成(sampleConfig)
  • モデル(sampleModel)

# (略)

class SageMaker():

# (略)

sageMake = SageMaker(profile, region, role)

# Endpointの削除
sageMake.delete()

コンソールからは、Endpointsが削除されている事を確認できます。

5 最後に

今回は、オブジェクト検出のモデルをWebカメラで確認してみました。

オブジェクト検出では、「多数の正解ラベルと画像で一気に精度を確認する」というような手法は難しく、どうしても、目視による確認となってしまいます。

もし、検出したい対象物が、Webカメラで簡単に撮影出来るのであれば、このような方法が手軽でいいかも知れません。

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.