[Amazon SageMaker]  JumpStartのファインチューニングで作成したResNet18のモデルをエッジデバイスのWebカメラで使用してみました

[Amazon SageMaker] JumpStartのファインチューニングで作成したResNet18のモデルをエッジデバイスのWebカメラで使用してみました

Clock Icon2021.02.03

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

1 はじめに

CX事業本部の平内(SIN)です。

AWS re:Invent 2020で発表されたAmazon SageMaker JumpStart(以下、JumpStart)は、TensorFlow Hub や PyTorch Hub に公開されているモデルをGUIで簡単にデプロイして利用できてしまいます。

以下は、PyTorch HubのResNet18のモデルをお菓子の画像データでファイチューニングしてみた例です。

今回は、ここで作成したモデルを、エッジデバイス(MBP)上のWebカメラで使用してみました。

2 入力画像

ImageNetの学習済みモデルの入力画像は、下記のように、サイズ 224*224、平均 (0.485, 0.456, 0.406)、分散 (0.229, 0.224, 0.225) で標準化したテンソルとなっています。

下記コードは、学習時に使ったものですが、こちらは、入力がPIL形式の画像を想定しています。

# 入力画像の変換
RANDOM_RESIZED_CROP = 224
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize(RANDOM_RESIZED_CROP), 
        transforms.CenterCrop(RANDOM_RESIZED_CROP),
        transforms.ToTensor(),
        transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
    ]
)

Webカメラの画像をOpenCVで取得した場合、画像がnumpy形式となるとなるため、ここにToPILImage()を追加します。

transform = transforms.Compose(
    [
        transforms.ToPILImage(), # Numpy -> PIL
        transforms.Resize(RANDOM_RESIZED_CROP),
        transforms.CenterCrop(RANDOM_RESIZED_CROP), 
        transforms.ToTensor(), 
        transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
    ]
)

3 モデル

JumpStartでファインチューニングしたモデルの配置は、以下のとおりです。 モデル本体及び、ラベルは、modelというフォルダ配下に置かれています。

4 コード

Webカメラの動画を推論しているコードは、以下のとおりです。

import json
import torch
from torch.nn import functional as F
from torchvision import transforms
import numpy as np

class Model():
    def __init__(self, model_path):
        self.__model = torch.load("{}/model.pt".format(model_path))
        self.__model.eval()  # 推論モード

        RANDOM_RESIZED_CROP = 224
        NORMALIZE_MEAN = [0.485, 0.456, 0.406]
        NORMALIZE_STD = [0.229, 0.224, 0.225]
        self.__transform = transforms.Compose(
            [
                transforms.ToPILImage(), # Numpy -> PIL
                transforms.Resize(RANDOM_RESIZED_CROP),#  W*H -> 224*224
                transforms.CenterCrop(RANDOM_RESIZED_CROP), # square
                transforms.ToTensor(), 
                transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
            ]
        )

    def inference(self, image):
        inputs = self.__transform(image) # 入力形式への変換
        inputs = inputs.unsqueeze(0) # バッチ次元の追加 (3,224,224) => (1,3,224,244)
        outputs = self.__model(inputs) # 推論
        batch_probs = F.softmax(outputs, dim=1)
        batch_probs, batch_indices = batch_probs.sort(dim=1, descending=True)   
        return (batch_probs, batch_indices)


import cv2

# Webカメラ
DEVICE_ID = 2
WIDTH = 800
HEIGHT = 600
FPS = 24

MODEL_PATH = "./model"

def main():
    cap = cv2.VideoCapture (DEVICE_ID)

    # フォーマット・解像度・FPSの設定
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, WIDTH)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, HEIGHT)
    cap.set(cv2.CAP_PROP_FPS, FPS)

    # モデルクラスとクラス名の初期化
    model = Model(MODEL_PATH)
    json_str = open("{}/class_label_to_prediction_index.json".format(MODEL_PATH), 'r')
    class_names = list(json.load(json_str))

    while True:

        # カメラ画像取得
        _, frame = cap.read()
        if(frame is None):
            continue
        image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        batch_probs, batch_indices = model.inference(image)

        for probs, indices in zip(batch_probs, batch_indices):
            name = class_names[indices[0].int()]
            confidence = probs[0]
            str = "{} {:.1f}%".format(name, confidence)
            print(str)
            cv2.putText(frame, str, (10, HEIGHT - 100), cv2.FONT_HERSHEY_PLAIN, 4, (255, 255, 255), 5, cv2.LINE_AA)

        # 画像表示
        cv2.imshow('frame', frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()

if __name__ == '__main__':
    main()

5 最後に

今回は、JumpStartで作成した、PyTorchのモデルをWebカメラの入力で使用してみました。 NumpyとPILの変換以外は、特にひねりは無いです。

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.