[Amazon SagaMaker] JumpStartのファインチューニングで作成したResNet18のモデルをエッジデバイスで使用してみました
1 はじめに
CX事業本部の平内(SIN)です。
AWS re:Invent 2020で発表されたAmazon SageMaker JumpStart(以下、JumpStart)は、TensorFlow Hub や PyTorch Hub に公開されているモデルをGUIで簡単にデプロイして利用できてしまいます。
以下は、PyTorch HubのResNet50でファイチューニングしてみた例です。
今回は、JumpStartでファイチューニングしたモデルを、エッジデバイス(MBP)のPytorchで使用してみました。
なお、上記は、ベースモデルにResNet50を使用しましたが、今回、新たにベースをResNet18にして学習したモデルを使用しています。使用したデータセットや作業内容は、まったく同じです。
2 出力モデル
JumpStartで学習したジョブも、トレーニングジョブに列挙されています。
ここで、出力を確認すると、S3上のURLが分かります。
S3上では、圧縮されたmodel.tar.gzとして配置されおり、
ダウンロードして解凍すると、以下のように、モデル本体とクラス名情報のjsonとなっています。
3 データ
推論に使用したデータ画像は、以下の10枚です。 全て、訓練用に使用したものとは、別の画像です。
000.jpg | 001.jpg |
002.jpg | 003.jpg |
004.jpg | 005.jpg |
006.jpg | 007.jpg |
008.jpg | 009.jpg |
4 コード
推論に使用したコードです。
torch.loadでモデルをロードし、推論モードにして利用しています。
入力画像の変換のTorch Hubのドキュメント通りですが、JumpStartで学習した際に使用されたsourcedirの中にあるコードを参考にさせて頂きました。
なお、Mac上で実行しており、GPUは、使用していません。
import json from PIL import Image import torch from torch.nn import functional as F from torchvision import transforms MODEL_PATH = "./model" IMAGE_PATH = "./images" # クラス名の取得 def get_class_names(model_path): json_str = open("{}/class_label_to_prediction_index.json".format(model_path), 'r') return list(json.load(json_str)) # モデルの取得 def get_model(model_path): return torch.load("{}/model.pt".format(model_path)) # 入力画像の変換 RANDOM_RESIZED_CROP = 224 NORMALIZE_MEAN = [0.485, 0.456, 0.406] NORMALIZE_STD = [0.229, 0.224, 0.225] transform = transforms.Compose( [ transforms.Resize(RANDOM_RESIZED_CROP), transforms.CenterCrop(RANDOM_RESIZED_CROP), transforms.ToTensor(), transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD), ] ) # 推論 def inference(img): img = Image.open(img) inputs = transform(img) inputs = inputs.unsqueeze(0) outputs = model(inputs) batch_probs = F.softmax(outputs, dim=1) batch_probs, batch_indices = batch_probs.sort(dim=1, descending=True) return (batch_probs, batch_indices) class_names = get_class_names(MODEL_PATH) model = get_model(MODEL_PATH) model.eval() # 推論モード # 10枚の画像を推論 for i in range(10): batch_probs, batch_indices = inference("{}/{:03d}.jpg".format(IMAGE_PATH, i)) for probs, indices in zip(batch_probs, batch_indices): name = class_names[indices[0].int()] confidence = probs[0] print("{:03d}.jpg {} {:.1f}%".format(i, name, confidence * 100))
実行結果です。全部、信頼度100%ってのが、逆に不安になります。
000.jpg NOIR 100.0% 001.jpg CRATZ 100.0% 002.jpg OREO 100.0% 003.jpg NOIR 100.0% 004.jpg PRETZEL 100.0% 005.jpg ASPARAGUS 100.0% 006.jpg OREO 100.0% 007.jpg CRATZ 100.0% 008.jpg PRETZEL 100.0% 009.jpg ASPARAGUS 100.0%
5 最後に
今回は、JumpStartで作成したモデルをエッジデバイスで使用してみました。
確認したとおり、作成されたモデルは、Pytorchのモデルそのものですので、エッジデバイス等でも利用できますし、他のフレームワークへの変換も可能でしょう。
6 参考にさせて頂いたリンク
TORCHVISION.MODELS
Torch.Hub
Pytorch SAVING AND LOADING MODELS
Use PyTorch with the SageMaker Python SDK
google Crab での利用例
Pytorch – 学習済みモデルで画像分類を行う方法
PyTorchでTorch Hubを使ったpre-trained modelの共有