[Amazon SagaMaker] JumpStartのファインチューニングで作成したResNet18のモデルをエッジデバイスで使用してみました

2021.01.29

1 はじめに

CX事業本部の平内(SIN)です。

AWS re:Invent 2020で発表されたAmazon SageMaker JumpStart(以下、JumpStart)は、TensorFlow Hub や PyTorch Hub に公開されているモデルをGUIで簡単にデプロイして利用できてしまいます。

以下は、PyTorch HubのResNet50でファイチューニングしてみた例です。

今回は、JumpStartでファイチューニングしたモデルを、エッジデバイス(MBP)のPytorchで使用してみました。

なお、上記は、ベースモデルにResNet50を使用しましたが、今回、新たにベースをResNet18にして学習したモデルを使用しています。使用したデータセットや作業内容は、まったく同じです。

2 出力モデル

JumpStartで学習したジョブも、トレーニングジョブに列挙されています。

ここで、出力を確認すると、S3上のURLが分かります。

S3上では、圧縮されたmodel.tar.gzとして配置されおり、

ダウンロードして解凍すると、以下のように、モデル本体とクラス名情報のjsonとなっています。

3 データ

推論に使用したデータ画像は、以下の10枚です。 全て、訓練用に使用したものとは、別の画像です。

000.jpg

001.jpg

002.jpg

003.jpg

004.jpg

005.jpg

006.jpg

007.jpg

008.jpg

009.jpg

4 コード

推論に使用したコードです。

torch.loadでモデルをロードし、推論モードにして利用しています。

入力画像の変換のTorch Hubのドキュメント通りですが、JumpStartで学習した際に使用されたsourcedirの中にあるコードを参考にさせて頂きました。

なお、Mac上で実行しており、GPUは、使用していません。

import json
from PIL import Image
import torch
from torch.nn import functional as F
from torchvision import transforms

MODEL_PATH = "./model"
IMAGE_PATH = "./images"

# クラス名の取得
def get_class_names(model_path):
    json_str = open("{}/class_label_to_prediction_index.json".format(model_path), 'r')
    return list(json.load(json_str))

# モデルの取得
def get_model(model_path):
    return torch.load("{}/model.pt".format(model_path))

# 入力画像の変換
RANDOM_RESIZED_CROP = 224
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize(RANDOM_RESIZED_CROP),
        transforms.CenterCrop(RANDOM_RESIZED_CROP),
        transforms.ToTensor(),
        transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
    ]
)

# 推論
def inference(img):
    img = Image.open(img)
    inputs = transform(img)
    inputs = inputs.unsqueeze(0) 
    outputs = model(inputs)
    batch_probs = F.softmax(outputs, dim=1)
    batch_probs, batch_indices = batch_probs.sort(dim=1, descending=True)   
    return (batch_probs, batch_indices)

class_names = get_class_names(MODEL_PATH)
model = get_model(MODEL_PATH)
model.eval()  # 推論モード

# 10枚の画像を推論
for i in range(10):
    batch_probs, batch_indices = inference("{}/{:03d}.jpg".format(IMAGE_PATH, i))
    for probs, indices in zip(batch_probs, batch_indices):
        name = class_names[indices[0].int()]
        confidence = probs[0]
        print("{:03d}.jpg {} {:.1f}%".format(i, name, confidence * 100))

実行結果です。全部、信頼度100%ってのが、逆に不安になります。

000.jpg NOIR 100.0%
001.jpg CRATZ 100.0%
002.jpg OREO 100.0%
003.jpg NOIR 100.0%
004.jpg PRETZEL 100.0%
005.jpg ASPARAGUS 100.0%
006.jpg OREO 100.0%
007.jpg CRATZ 100.0%
008.jpg PRETZEL 100.0%
009.jpg ASPARAGUS 100.0%

5 最後に

今回は、JumpStartで作成したモデルをエッジデバイスで使用してみました。

確認したとおり、作成されたモデルは、Pytorchのモデルそのものですので、エッジデバイス等でも利用できますし、他のフレームワークへの変換も可能でしょう。

6 参考にさせて頂いたリンク


TORCHVISION.MODELS
Torch.Hub
Pytorch SAVING AND LOADING MODELS
Use PyTorch with the SageMaker Python SDK
google Crab での利用例
Pytorch – 学習済みモデルで画像分類を行う方法
PyTorchでTorch Hubを使ったpre-trained modelの共有