[Amazon SageMaker] 組み込みアルゴリズムのイメージ分類をMac上のMXNetで利用してみました

2020.07.03

1 はじめに

CX事業本部の平内(SIN)です。

以前に、組み込みアルゴリズムの物体検出モデルをMacOSにセットアップしたMXNetで動作させて見ました。

内容は、ほとんど同じですが、今回は、同じく組み込みのイメージ分類で作成したモデルで動作させて見ました。

2 モデル

使用したモデルは、下記で作成したものです。

S3に出力されたモデルをダウンロードして解凍すると、以下のような内容になっています。

MacOS上で利用する場合に必要なファイルは、image-classification-symbol.json及び、image-classification-0023.paramsの2つです。

model-shapes.jsonは、入力の形式を示したファイルですが、こちらは、必要ありません。

[{"shape": [32, 3, 224, 224], "name": "data"}]

なお、モデルの入力ファイルとして、以下のような形式が要求されますので、ファイル名を一部変更する必要があります。

  • 名前 + "-symbol.json"
  • 名前 + "-0000.params"

3 MXNet

MacOSでは、mxnetをインストールしています。

% python3
Python 3.7.6 (default, Jan  8 2020, 13:42:34)
[Clang 4.0.1 (tags/RELEASE_401/final)] :: Anaconda, Inc. on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import mxnet
>>> mxnet.__version__
'1.6.0'
>>>

4 コード

MacOS上で実行しているコードは、以下のとおりです。

MXNetは、modelフォルダの階層下にモデルを配置して初期化しています。

index.py

# -*- coding: utf-8 -*-
import cv2
import numpy as np
import mxnet as mx
from collections import namedtuple

import videoClass

# Webカメラ
DEVICE_ID = 0 
WIDTH = 800
HEIGHT = 600
FPS = 24

# モデル
MODEL_PATH='./model/image-classification'
NAMES = ["ポリッピー(GREEN)","OREO","カントリーマム","ポリッピー(RED)","CHEDDER_CHEESE","PRETZEL(YELLOW)","海味鮮","柿の種","フルグラ(BROWN)","NOIR","BANANA(BLOWN)","チーズあられ","俺のおやつ","PRIME","CRATZ(RED)","CRATZ(GREEN)","ポリッピー(YELLOW)","こつぶっこ","アスパラガス","海苔ピーパック","柿の種(梅しそ)","PRETZEL(BLACK)","CRATZ(ORANGE)","チョコメリゼ","フライドポテト(じゃがバター味)","BANANA(BLUE)","でん六豆","フルグラ(RED)","PRETZEL(GREEN)","フライドポテト(しお味)",]
SHAPE = 224

def main():
    # カメラ初期化
    video = videoClass.Video(DEVICE_ID, WIDTH, HEIGHT, FPS)

    # モデル初期化
    input_shapes=[('data', (1, 3, SHAPE, SHAPE))]
    Batch = namedtuple('Batch', ['data'])
    sym, arg_params, aux_params = mx.model.load_checkpoint(MODEL_PATH, 0)
    mod = mx.mod.Module(symbol=sym, context=mx.cpu())
    mod.bind(for_training=False, data_shapes=input_shapes)
    mod.set_params(arg_params, aux_params)

    while True:
        # カメラ画像取得
        frame = video.read()
        if(frame is None):
            continue

        # 入力インターフェースへの画像変換
        height, width, channels = frame.shape[:3]
        frame = frame[0 : int(height), 0 : int(height)] # 正方形

        img = cv2.resize(frame, (SHAPE, SHAPE)) # xxx * xxx -> 224*224
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR -> RGB
        img = img.transpose((2, 0, 1)) # 512,512,3 -> 3,512,512
        img = img[np.newaxis, :] # 3,512,512 -> 1,3,512,512

        # 推論
        mod.forward(Batch([mx.nd.array(img)]))
        prob = mod.get_outputs()[0].asnumpy()
        prob = np.squeeze(prob) # 次元削除

        probabilitys = {}
        for i,result in enumerate(prob):
            probabilitys[NAMES[i]] = result
        probabilitys = sorted(probabilitys.items(), key = lambda x:x[1], reverse=True)
        (name, probability) = probabilitys[0]
        print("{} {:2f}".format(name, probability))

        # 画像表示
        if(video.show('frame', frame)==False):
            break

    del video

if __name__ == '__main__':
    main()

videoClass.py

import cv2

class Video():
    def __init__(self, deviceId, width, height, fps):
        self.__cap = cv2.VideoCapture (deviceId)

        # フォーマット・解像度・FPSの設定
        self.__cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
        self.__cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
        self.__cap.set(cv2.CAP_PROP_FPS, fps)

        # フォーマット・解像度・FPSの取得
        self.__fourcc = self.__decode_fourcc(self.__cap.get(cv2.CAP_PROP_FOURCC))
        self.__width = self.__cap.get(cv2.CAP_PROP_FRAME_WIDTH)
        self.__height = self.__cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
        self.__fps = self.__cap.get(cv2.CAP_PROP_FPS)
        print("fourcc:{} fps:{} width:{} height:{}".format(self.__fourcc, self.__fps, self.__width, self.__height))

    def __del__(self):
        self.__cap.release()
        cv2.destroyAllWindows()

    def __decode_fourcc(self, v):
        v = int(v)
        return "".join([chr((v >> 8 * i) & 0xFF) for i in range(4)])

    def read(self):
        _, frame = self.__cap.read()
        return frame

    def show(self, name, frame):
        cv2.imshow(name, frame)
        # キュー入力判定(1msの待機)
        # waitKeyがないと、imshow()は表示できない
        # 'q'をタイプされたらループから抜ける
        if cv2.waitKey(1) & 0xFF == ord('q'):
            return False
        return True

5 最後に

今回は、イメージ分類で作成したモデルをMacOS 上で利用してみました。

ちょっとした動作確認に、SageMakerでエンドポイントを立ち上げるのは、費用的にも時間的にも勿体ないので、手元で確認できる環境を作っておくと重宝するかもしれません。