この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。
1 はじめに
CX事業本部の平内(SIN)です。
AWS re:Invent 2020で発表されたAmazon SageMaker JumpStart(以下、JumpStart)は、TensorFlow Hub や PyTorch Hub に公開されているモデルをGUIで簡単にデプロイして利用できてしまいます。
以下は、PyTorch HubのResNet50でファイチューニングしてみた例です。
今回は、JumpStartでファイチューニングしたモデルを、エッジデバイス(MBP)のPytorchで使用してみました。
なお、上記は、ベースモデルにResNet50を使用しましたが、今回、新たにベースをResNet18にして学習したモデルを使用しています。使用したデータセットや作業内容は、まったく同じです。
2 出力モデル
JumpStartで学習したジョブも、トレーニングジョブに列挙されています。
ここで、出力を確認すると、S3上のURLが分かります。
S3上では、圧縮されたmodel.tar.gzとして配置されおり、
ダウンロードして解凍すると、以下のように、モデル本体とクラス名情報のjsonとなっています。
3 データ
推論に使用したデータ画像は、以下の10枚です。 全て、訓練用に使用したものとは、別の画像です。
000.jpg | 001.jpg |
002.jpg | 003.jpg |
004.jpg | 005.jpg |
006.jpg | 007.jpg |
008.jpg | 009.jpg |
4 コード
推論に使用したコードです。
torch.loadでモデルをロードし、推論モードにして利用しています。
入力画像の変換のTorch Hubのドキュメント通りですが、JumpStartで学習した際に使用されたsourcedirの中にあるコードを参考にさせて頂きました。
なお、Mac上で実行しており、GPUは、使用していません。
import json
from PIL import Image
import torch
from torch.nn import functional as F
from torchvision import transforms
MODEL_PATH = "./model"
IMAGE_PATH = "./images"
# クラス名の取得
def get_class_names(model_path):
json_str = open("{}/class_label_to_prediction_index.json".format(model_path), 'r')
return list(json.load(json_str))
# モデルの取得
def get_model(model_path):
return torch.load("{}/model.pt".format(model_path))
# 入力画像の変換
RANDOM_RESIZED_CROP = 224
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
[
transforms.Resize(RANDOM_RESIZED_CROP),
transforms.CenterCrop(RANDOM_RESIZED_CROP),
transforms.ToTensor(),
transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
]
)
# 推論
def inference(img):
img = Image.open(img)
inputs = transform(img)
inputs = inputs.unsqueeze(0)
outputs = model(inputs)
batch_probs = F.softmax(outputs, dim=1)
batch_probs, batch_indices = batch_probs.sort(dim=1, descending=True)
return (batch_probs, batch_indices)
class_names = get_class_names(MODEL_PATH)
model = get_model(MODEL_PATH)
model.eval() # 推論モード
# 10枚の画像を推論
for i in range(10):
batch_probs, batch_indices = inference("{}/{:03d}.jpg".format(IMAGE_PATH, i))
for probs, indices in zip(batch_probs, batch_indices):
name = class_names[indices[0].int()]
confidence = probs[0]
print("{:03d}.jpg {} {:.1f}%".format(i, name, confidence * 100))
実行結果です。全部、信頼度100%ってのが、逆に不安になります。
000.jpg NOIR 100.0%
001.jpg CRATZ 100.0%
002.jpg OREO 100.0%
003.jpg NOIR 100.0%
004.jpg PRETZEL 100.0%
005.jpg ASPARAGUS 100.0%
006.jpg OREO 100.0%
007.jpg CRATZ 100.0%
008.jpg PRETZEL 100.0%
009.jpg ASPARAGUS 100.0%
5 最後に
今回は、JumpStartで作成したモデルをエッジデバイスで使用してみました。
確認したとおり、作成されたモデルは、Pytorchのモデルそのものですので、エッジデバイス等でも利用できますし、他のフレームワークへの変換も可能でしょう。
6 参考にさせて頂いたリンク
TORCHVISION.MODELS
Torch.Hub
Pytorch SAVING AND LOADING MODELS
Use PyTorch with the SageMaker Python SDK
google Crab での利用例
Pytorch – 学習済みモデルで画像分類を行う方法
PyTorchでTorch Hubを使ったpre-trained modelの共有