この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。
1 はじめに
CX事業本部の平内(SIN)です。
AWS re:Invent 2020で発表されたAmazon SageMaker JumpStart(以下、JumpStart)は、TensorFlow Hub や PyTorch Hub に公開されているモデルをGUIで簡単にデプロイして利用できてしまいます。
以下は、PyTorch HubのResNet18のモデルをお菓子の画像データでファイチューニングしてみた例です。
今回は、ここで作成したモデルを、エッジデバイス(MBP)上のWebカメラで使用してみました。
2 入力画像
ImageNetの学習済みモデルの入力画像は、下記のように、サイズ 224*224、平均 (0.485, 0.456, 0.406)、分散 (0.229, 0.224, 0.225) で標準化したテンソルとなっています。
下記コードは、学習時に使ったものですが、こちらは、入力がPIL形式の画像を想定しています。
# 入力画像の変換
RANDOM_RESIZED_CROP = 224
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
[
transforms.Resize(RANDOM_RESIZED_CROP),
transforms.CenterCrop(RANDOM_RESIZED_CROP),
transforms.ToTensor(),
transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
]
)
Webカメラの画像をOpenCVで取得した場合、画像がnumpy形式となるとなるため、ここにToPILImage()を追加します。
transform = transforms.Compose(
[
transforms.ToPILImage(), # Numpy -> PIL
transforms.Resize(RANDOM_RESIZED_CROP),
transforms.CenterCrop(RANDOM_RESIZED_CROP),
transforms.ToTensor(),
transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
]
)
3 モデル
JumpStartでファインチューニングしたモデルの配置は、以下のとおりです。 モデル本体及び、ラベルは、modelというフォルダ配下に置かれています。
4 コード
Webカメラの動画を推論しているコードは、以下のとおりです。
import json
import torch
from torch.nn import functional as F
from torchvision import transforms
import numpy as np
class Model():
def __init__(self, model_path):
self.__model = torch.load("{}/model.pt".format(model_path))
self.__model.eval() # 推論モード
RANDOM_RESIZED_CROP = 224
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]
self.__transform = transforms.Compose(
[
transforms.ToPILImage(), # Numpy -> PIL
transforms.Resize(RANDOM_RESIZED_CROP),# W*H -> 224*224
transforms.CenterCrop(RANDOM_RESIZED_CROP), # square
transforms.ToTensor(),
transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
]
)
def inference(self, image):
inputs = self.__transform(image) # 入力形式への変換
inputs = inputs.unsqueeze(0) # バッチ次元の追加 (3,224,224) => (1,3,224,244)
outputs = self.__model(inputs) # 推論
batch_probs = F.softmax(outputs, dim=1)
batch_probs, batch_indices = batch_probs.sort(dim=1, descending=True)
return (batch_probs, batch_indices)
import cv2
# Webカメラ
DEVICE_ID = 2
WIDTH = 800
HEIGHT = 600
FPS = 24
MODEL_PATH = "./model"
def main():
cap = cv2.VideoCapture (DEVICE_ID)
# フォーマット・解像度・FPSの設定
cap.set(cv2.CAP_PROP_FRAME_WIDTH, WIDTH)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, HEIGHT)
cap.set(cv2.CAP_PROP_FPS, FPS)
# モデルクラスとクラス名の初期化
model = Model(MODEL_PATH)
json_str = open("{}/class_label_to_prediction_index.json".format(MODEL_PATH), 'r')
class_names = list(json.load(json_str))
while True:
# カメラ画像取得
_, frame = cap.read()
if(frame is None):
continue
image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
batch_probs, batch_indices = model.inference(image)
for probs, indices in zip(batch_probs, batch_indices):
name = class_names[indices[0].int()]
confidence = probs[0]
str = "{} {:.1f}%".format(name, confidence)
print(str)
cv2.putText(frame, str, (10, HEIGHT - 100), cv2.FONT_HERSHEY_PLAIN, 4, (255, 255, 255), 5, cv2.LINE_AA)
# 画像表示
cv2.imshow('frame', frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cap.release()
cv2.destroyAllWindows()
if __name__ == '__main__':
main()
5 最後に
今回は、JumpStartで作成した、PyTorchのモデルをWebカメラの入力で使用してみました。 NumpyとPILの変換以外は、特にひねりは無いです。