[Amazon SageMaker] JumpStartのファインチューニングで作成したResNet18のモデルをJetson NanoのWebカメラで使用してみました

2021.02.13

1 はじめに

CX事業本部の平内(SIN)です。

Amazon SageMaker JumpStart(以下、JumpStart)は、TensorFlow Hub や PyTorch Hub に公開されているモデルをGUIで簡単にデプロイして利用できます。

以下は、PyTorch HubのResNet50でファインチューニングしてみた例です。

今回は、上記と同じ要領でResNet18から学習したモデルをJetson Nanoで使用してみました。

最初に、動作している様子です。推論は、GPU上で行われており、処理時間は0.04sec 〜 0.07sec程度で動作しています。

2 PyTorchとtorchvisionのインストール

使用した、Jetsonは、JetPack 4.4.1で動作しています。JetsonにPytorchの環境を構築する要領は、下記の公式ページで案内されています。
PyTorch for Jetson - version 1.7.0 now available

使用したのは、JetPack4.4用にコンパイル済みで公開されているPyTorch 1.7.0です。

$ wget https://nvidia.box.com/shared/static/cs3xn3td6sfgtene6jdvsxlr366m2dhq.whl -O torch-1.7.0-cp36-cp36m-linux_aarch64.whl
$ sudo apt-get install python3-pip libopenblas-base libopenmpi-dev 
$ pip3 install Cython
$ pip3 install numpy torch-1.7.0-cp36-cp36m-linux_aarch64.whl

また、PyTorch 1.7に対応するtorchvisionは、v0.8.1とのことですので、バージョンを指定してコンパイル・インストールしています。

$ sudo apt-get install libjpeg-dev zlib1g-dev libpython3-dev libavcodec-dev libavformat-dev libswscale-dev
$ git clone --branch v0.8.1 https://github.com/pytorch/vision torchvision
$ cd torchvision
$ export BUILD_VERSION=0.8.1
$ python3 setup.py install --user
$ cd ../  
$ pip3 install 'pillow<7'

以下は、インストールされた状況を確認しているようすです。

nvidia@nvidia-desktop:~/pytorch$ python3
Python 3.6.9 (default, Oct  8 2020, 12:12:24)
[GCC 8.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.__version__
'1.7.0'
>>> import torchvision
>>> torchvision.__version__
'0.8.1'
>>> import PIL
>>> PIL.__version__
'6.2.2'

3 モデル

JumpStartでファインチューニングしたモデルの配置は、以下のとおりです。 モデル本体及び、ラベルは、modelというフォルダ配下に置かれています。

4 コード

Jetson Nano上で、動作しているコードです。

Webカメラの入力は、GStreamer経由でOpenCVで処理されています。 入力画像は、学習時に使用したtransforms.Composeをそのまま使用して、入力形式に合わせています。

import json
import time
import torch
from torch.nn import functional as F
from torchvision import transforms
import numpy as np

class Model():
    def __init__(self, model_path):
        self.__model = torch.load("{}/model.pt".format(model_path))
        self.__model.eval()  # 推論モード

        self.__device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("device:{}".format(self.__device))
        self.__model.to(self.__device)

        RANDOM_RESIZED_CROP = 224
        NORMALIZE_MEAN = [0.485, 0.456, 0.406]
        NORMALIZE_STD = [0.229, 0.224, 0.225]
        self.__transform = transforms.Compose(
            [
                transforms.ToPILImage(), # Numpy -> PIL
                transforms.Resize(RANDOM_RESIZED_CROP),#  W*H -> 224*224
                transforms.CenterCrop(RANDOM_RESIZED_CROP), # square
                transforms.ToTensor(), 
                transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
            ]
        )

    def inference(self, image):
        inputs = self.__transform(image) # 入力形式への変換
        inputs = inputs.unsqueeze(0) # バッチ次元の追加 (3,224,224) => (1,3,224,244)
        inputs = inputs.to(self.__device)
        outputs = self.__model(inputs) # 推論
        batch_probs = F.softmax(outputs, dim=1)
        batch_probs, batch_indices = batch_probs.sort(dim=1, descending=True)   
        return (batch_probs, batch_indices)


import cv2

# Webカメラ
DEVICE_ID = 0
WIDTH = 800
HEIGHT = 600
GST_STR = ('v4l2src device=/dev/video{} ! video/x-raw, width=(int){}, height=(int){} ! videoconvert ! appsink').format(DEVICE_ID, WIDTH, HEIGHT)
MODEL_PATH = "./model"

def main():
    cap = cv2.VideoCapture(GST_STR, cv2.CAP_GSTREAMER)

    # モデルクラスとクラス名の初期化
    model = Model(MODEL_PATH)
    json_str = open("{}/class_label_to_prediction_index.json".format(MODEL_PATH), 'r')
    class_names = list(json.load(json_str))
    print("class_names:{}".format(class_names))

    while True:

        # カメラ画像取得
        _, frame = cap.read()
        if(frame is None):
            continue
        image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        start = time.time()
        batch_probs, batch_indices = model.inference(image)
        processing_time = time.time() - start

        for probs, indices in zip(batch_probs, batch_indices):
            name = class_names[indices[0].int()]
            confidence = probs[0]
            str = "{} {:.1f}% {:.2f} sec".format(name, confidence, processing_time)
            print(str)
            cv2.putText(frame, str, (10, HEIGHT - 100), cv2.FONT_HERSHEY_PLAIN, 4, (255, 255, 255), 5, cv2.LINE_AA)

        # 画像表示
        cv2.imshow('frame', frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()

if __name__ == '__main__':
    main()

5 最後に

JumpStartでファインチューニングした、ResNet18のモデルを、Jetson Nanoで利用してみました。

JetPackは、現在、最新が4.5になっていますが、こちらに対応したPyTorchのコンパイル済みモジュールは、まだ、提供されていなかったため、今回は、4.4としました。

ちょっと弱気ですが・・・公式で提供されている手順だと、つまずきが無くて助かります。