この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。
1 はじめに
CX事業本部の平内(SIN)です。
Amazon SageMaker JumpStart(以下、JumpStart)は、TensorFlow Hub や PyTorch Hub に公開されているモデルをGUIで簡単にデプロイして利用できます。
以下は、PyTorch HubのResNet50でファインチューニングしてみた例です。
今回は、上記と同じ要領でResNet18から学習したモデルをJetson Nanoで使用してみました。
最初に、動作している様子です。推論は、GPU上で行われており、処理時間は0.04sec 〜 0.07sec程度で動作しています。
2 PyTorchとtorchvisionのインストール
使用した、Jetsonは、JetPack 4.4.1で動作しています。JetsonにPytorchの環境を構築する要領は、下記の公式ページで案内されています。
PyTorch for Jetson - version 1.7.0 now available
使用したのは、JetPack4.4用にコンパイル済みで公開されているPyTorch 1.7.0です。
$ wget https://nvidia.box.com/shared/static/cs3xn3td6sfgtene6jdvsxlr366m2dhq.whl -O torch-1.7.0-cp36-cp36m-linux_aarch64.whl
$ sudo apt-get install python3-pip libopenblas-base libopenmpi-dev
$ pip3 install Cython
$ pip3 install numpy torch-1.7.0-cp36-cp36m-linux_aarch64.whl
また、PyTorch 1.7に対応するtorchvisionは、v0.8.1とのことですので、バージョンを指定してコンパイル・インストールしています。
$ sudo apt-get install libjpeg-dev zlib1g-dev libpython3-dev libavcodec-dev libavformat-dev libswscale-dev
$ git clone --branch v0.8.1 https://github.com/pytorch/vision torchvision
$ cd torchvision
$ export BUILD_VERSION=0.8.1
$ python3 setup.py install --user
$ cd ../
$ pip3 install 'pillow<7'
以下は、インストールされた状況を確認しているようすです。
nvidia@nvidia-desktop:~/pytorch$ python3
Python 3.6.9 (default, Oct 8 2020, 12:12:24)
[GCC 8.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.__version__
'1.7.0'
>>> import torchvision
>>> torchvision.__version__
'0.8.1'
>>> import PIL
>>> PIL.__version__
'6.2.2'
3 モデル
JumpStartでファインチューニングしたモデルの配置は、以下のとおりです。 モデル本体及び、ラベルは、modelというフォルダ配下に置かれています。
4 コード
Jetson Nano上で、動作しているコードです。
Webカメラの入力は、GStreamer経由でOpenCVで処理されています。 入力画像は、学習時に使用したtransforms.Composeをそのまま使用して、入力形式に合わせています。
import json
import time
import torch
from torch.nn import functional as F
from torchvision import transforms
import numpy as np
class Model():
def __init__(self, model_path):
self.__model = torch.load("{}/model.pt".format(model_path))
self.__model.eval() # 推論モード
self.__device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:{}".format(self.__device))
self.__model.to(self.__device)
RANDOM_RESIZED_CROP = 224
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]
self.__transform = transforms.Compose(
[
transforms.ToPILImage(), # Numpy -> PIL
transforms.Resize(RANDOM_RESIZED_CROP),# W*H -> 224*224
transforms.CenterCrop(RANDOM_RESIZED_CROP), # square
transforms.ToTensor(),
transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
]
)
def inference(self, image):
inputs = self.__transform(image) # 入力形式への変換
inputs = inputs.unsqueeze(0) # バッチ次元の追加 (3,224,224) => (1,3,224,244)
inputs = inputs.to(self.__device)
outputs = self.__model(inputs) # 推論
batch_probs = F.softmax(outputs, dim=1)
batch_probs, batch_indices = batch_probs.sort(dim=1, descending=True)
return (batch_probs, batch_indices)
import cv2
# Webカメラ
DEVICE_ID = 0
WIDTH = 800
HEIGHT = 600
GST_STR = ('v4l2src device=/dev/video{} ! video/x-raw, width=(int){}, height=(int){} ! videoconvert ! appsink').format(DEVICE_ID, WIDTH, HEIGHT)
MODEL_PATH = "./model"
def main():
cap = cv2.VideoCapture(GST_STR, cv2.CAP_GSTREAMER)
# モデルクラスとクラス名の初期化
model = Model(MODEL_PATH)
json_str = open("{}/class_label_to_prediction_index.json".format(MODEL_PATH), 'r')
class_names = list(json.load(json_str))
print("class_names:{}".format(class_names))
while True:
# カメラ画像取得
_, frame = cap.read()
if(frame is None):
continue
image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
start = time.time()
batch_probs, batch_indices = model.inference(image)
processing_time = time.time() - start
for probs, indices in zip(batch_probs, batch_indices):
name = class_names[indices[0].int()]
confidence = probs[0]
str = "{} {:.1f}% {:.2f} sec".format(name, confidence, processing_time)
print(str)
cv2.putText(frame, str, (10, HEIGHT - 100), cv2.FONT_HERSHEY_PLAIN, 4, (255, 255, 255), 5, cv2.LINE_AA)
# 画像表示
cv2.imshow('frame', frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cap.release()
cv2.destroyAllWindows()
if __name__ == '__main__':
main()
5 最後に
JumpStartでファインチューニングした、ResNet18のモデルを、Jetson Nanoで利用してみました。
JetPackは、現在、最新が4.5になっていますが、こちらに対応したPyTorchのコンパイル済みモジュールは、まだ、提供されていなかったため、今回は、4.4としました。
ちょっと弱気ですが・・・公式で提供されている手順だと、つまずきが無くて助かります。