[Amazon SageMaker] JumpStartのファインチューニングで作成したResNet18のモデルをエッジデバイスのWebカメラで使用してみました
1 はじめに
CX事業本部の平内(SIN)です。
AWS re:Invent 2020で発表されたAmazon SageMaker JumpStart(以下、JumpStart)は、TensorFlow Hub や PyTorch Hub に公開されているモデルをGUIで簡単にデプロイして利用できてしまいます。
以下は、PyTorch HubのResNet18のモデルをお菓子の画像データでファイチューニングしてみた例です。
今回は、ここで作成したモデルを、エッジデバイス(MBP)上のWebカメラで使用してみました。
2 入力画像
ImageNetの学習済みモデルの入力画像は、下記のように、サイズ 224*224、平均 (0.485, 0.456, 0.406)、分散 (0.229, 0.224, 0.225) で標準化したテンソルとなっています。
下記コードは、学習時に使ったものですが、こちらは、入力がPIL形式の画像を想定しています。
# 入力画像の変換 RANDOM_RESIZED_CROP = 224 NORMALIZE_MEAN = [0.485, 0.456, 0.406] NORMALIZE_STD = [0.229, 0.224, 0.225] transform = transforms.Compose( [ transforms.Resize(RANDOM_RESIZED_CROP), transforms.CenterCrop(RANDOM_RESIZED_CROP), transforms.ToTensor(), transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD), ] )
Webカメラの画像をOpenCVで取得した場合、画像がnumpy形式となるとなるため、ここにToPILImage()を追加します。
transform = transforms.Compose( [ transforms.ToPILImage(), # Numpy -> PIL transforms.Resize(RANDOM_RESIZED_CROP), transforms.CenterCrop(RANDOM_RESIZED_CROP), transforms.ToTensor(), transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD), ] )
3 モデル
JumpStartでファインチューニングしたモデルの配置は、以下のとおりです。 モデル本体及び、ラベルは、modelというフォルダ配下に置かれています。
4 コード
Webカメラの動画を推論しているコードは、以下のとおりです。
import json import torch from torch.nn import functional as F from torchvision import transforms import numpy as np class Model(): def __init__(self, model_path): self.__model = torch.load("{}/model.pt".format(model_path)) self.__model.eval() # 推論モード RANDOM_RESIZED_CROP = 224 NORMALIZE_MEAN = [0.485, 0.456, 0.406] NORMALIZE_STD = [0.229, 0.224, 0.225] self.__transform = transforms.Compose( [ transforms.ToPILImage(), # Numpy -> PIL transforms.Resize(RANDOM_RESIZED_CROP),# W*H -> 224*224 transforms.CenterCrop(RANDOM_RESIZED_CROP), # square transforms.ToTensor(), transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD), ] ) def inference(self, image): inputs = self.__transform(image) # 入力形式への変換 inputs = inputs.unsqueeze(0) # バッチ次元の追加 (3,224,224) => (1,3,224,244) outputs = self.__model(inputs) # 推論 batch_probs = F.softmax(outputs, dim=1) batch_probs, batch_indices = batch_probs.sort(dim=1, descending=True) return (batch_probs, batch_indices) import cv2 # Webカメラ DEVICE_ID = 2 WIDTH = 800 HEIGHT = 600 FPS = 24 MODEL_PATH = "./model" def main(): cap = cv2.VideoCapture (DEVICE_ID) # フォーマット・解像度・FPSの設定 cap.set(cv2.CAP_PROP_FRAME_WIDTH, WIDTH) cap.set(cv2.CAP_PROP_FRAME_HEIGHT, HEIGHT) cap.set(cv2.CAP_PROP_FPS, FPS) # モデルクラスとクラス名の初期化 model = Model(MODEL_PATH) json_str = open("{}/class_label_to_prediction_index.json".format(MODEL_PATH), 'r') class_names = list(json.load(json_str)) while True: # カメラ画像取得 _, frame = cap.read() if(frame is None): continue image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) batch_probs, batch_indices = model.inference(image) for probs, indices in zip(batch_probs, batch_indices): name = class_names[indices[0].int()] confidence = probs[0] str = "{} {:.1f}%".format(name, confidence) print(str) cv2.putText(frame, str, (10, HEIGHT - 100), cv2.FONT_HERSHEY_PLAIN, 4, (255, 255, 255), 5, cv2.LINE_AA) # 画像表示 cv2.imshow('frame', frame) if cv2.waitKey(1) & 0xFF == ord('q'): break cap.release() cv2.destroyAllWindows() if __name__ == '__main__': main()
5 最後に
今回は、JumpStartで作成した、PyTorchのモデルをWebカメラの入力で使用してみました。 NumpyとPILの変換以外は、特にひねりは無いです。