[Amazon SageMaker] JumpStartのファインチューニングで作成したResNet18のモデルをJetson NanoのWebカメラで使用してみました
1 はじめに
CX事業本部の平内(SIN)です。
Amazon SageMaker JumpStart(以下、JumpStart)は、TensorFlow Hub や PyTorch Hub に公開されているモデルをGUIで簡単にデプロイして利用できます。
以下は、PyTorch HubのResNet50でファインチューニングしてみた例です。
今回は、上記と同じ要領でResNet18から学習したモデルをJetson Nanoで使用してみました。
最初に、動作している様子です。推論は、GPU上で行われており、処理時間は0.04sec 〜 0.07sec程度で動作しています。
2 PyTorchとtorchvisionのインストール
使用した、Jetsonは、JetPack 4.4.1で動作しています。JetsonにPytorchの環境を構築する要領は、下記の公式ページで案内されています。
PyTorch for Jetson - version 1.7.0 now available
使用したのは、JetPack4.4用にコンパイル済みで公開されているPyTorch 1.7.0です。
$ wget https://nvidia.box.com/shared/static/cs3xn3td6sfgtene6jdvsxlr366m2dhq.whl -O torch-1.7.0-cp36-cp36m-linux_aarch64.whl $ sudo apt-get install python3-pip libopenblas-base libopenmpi-dev $ pip3 install Cython $ pip3 install numpy torch-1.7.0-cp36-cp36m-linux_aarch64.whl
また、PyTorch 1.7に対応するtorchvisionは、v0.8.1とのことですので、バージョンを指定してコンパイル・インストールしています。
$ sudo apt-get install libjpeg-dev zlib1g-dev libpython3-dev libavcodec-dev libavformat-dev libswscale-dev $ git clone --branch v0.8.1 https://github.com/pytorch/vision torchvision $ cd torchvision $ export BUILD_VERSION=0.8.1 $ python3 setup.py install --user $ cd ../ $ pip3 install 'pillow<7'
以下は、インストールされた状況を確認しているようすです。
nvidia@nvidia-desktop:~/pytorch$ python3 Python 3.6.9 (default, Oct 8 2020, 12:12:24) [GCC 8.4.0] on linux Type "help", "copyright", "credits" or "license" for more information. >>> import torch >>> torch.__version__ '1.7.0' >>> import torchvision >>> torchvision.__version__ '0.8.1' >>> import PIL >>> PIL.__version__ '6.2.2'
3 モデル
JumpStartでファインチューニングしたモデルの配置は、以下のとおりです。 モデル本体及び、ラベルは、modelというフォルダ配下に置かれています。
4 コード
Jetson Nano上で、動作しているコードです。
Webカメラの入力は、GStreamer経由でOpenCVで処理されています。 入力画像は、学習時に使用したtransforms.Composeをそのまま使用して、入力形式に合わせています。
import json import time import torch from torch.nn import functional as F from torchvision import transforms import numpy as np class Model(): def __init__(self, model_path): self.__model = torch.load("{}/model.pt".format(model_path)) self.__model.eval() # 推論モード self.__device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("device:{}".format(self.__device)) self.__model.to(self.__device) RANDOM_RESIZED_CROP = 224 NORMALIZE_MEAN = [0.485, 0.456, 0.406] NORMALIZE_STD = [0.229, 0.224, 0.225] self.__transform = transforms.Compose( [ transforms.ToPILImage(), # Numpy -> PIL transforms.Resize(RANDOM_RESIZED_CROP),# W*H -> 224*224 transforms.CenterCrop(RANDOM_RESIZED_CROP), # square transforms.ToTensor(), transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD), ] ) def inference(self, image): inputs = self.__transform(image) # 入力形式への変換 inputs = inputs.unsqueeze(0) # バッチ次元の追加 (3,224,224) => (1,3,224,244) inputs = inputs.to(self.__device) outputs = self.__model(inputs) # 推論 batch_probs = F.softmax(outputs, dim=1) batch_probs, batch_indices = batch_probs.sort(dim=1, descending=True) return (batch_probs, batch_indices) import cv2 # Webカメラ DEVICE_ID = 0 WIDTH = 800 HEIGHT = 600 GST_STR = ('v4l2src device=/dev/video{} ! video/x-raw, width=(int){}, height=(int){} ! videoconvert ! appsink').format(DEVICE_ID, WIDTH, HEIGHT) MODEL_PATH = "./model" def main(): cap = cv2.VideoCapture(GST_STR, cv2.CAP_GSTREAMER) # モデルクラスとクラス名の初期化 model = Model(MODEL_PATH) json_str = open("{}/class_label_to_prediction_index.json".format(MODEL_PATH), 'r') class_names = list(json.load(json_str)) print("class_names:{}".format(class_names)) while True: # カメラ画像取得 _, frame = cap.read() if(frame is None): continue image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) start = time.time() batch_probs, batch_indices = model.inference(image) processing_time = time.time() - start for probs, indices in zip(batch_probs, batch_indices): name = class_names[indices[0].int()] confidence = probs[0] str = "{} {:.1f}% {:.2f} sec".format(name, confidence, processing_time) print(str) cv2.putText(frame, str, (10, HEIGHT - 100), cv2.FONT_HERSHEY_PLAIN, 4, (255, 255, 255), 5, cv2.LINE_AA) # 画像表示 cv2.imshow('frame', frame) if cv2.waitKey(1) & 0xFF == ord('q'): break cap.release() cv2.destroyAllWindows() if __name__ == '__main__': main()
5 最後に
JumpStartでファインチューニングした、ResNet18のモデルを、Jetson Nanoで利用してみました。
JetPackは、現在、最新が4.5になっていますが、こちらに対応したPyTorchのコンパイル済みモジュールは、まだ、提供されていなかったため、今回は、4.4としました。
ちょっと弱気ですが・・・公式で提供されている手順だと、つまずきが無くて助かります。