[Amazon SageMaker] 組み込みアルゴリズムのオブジェクト検出(ResNet-50)をJetson Nano上のMXNetで利用してみました

2020.06.23

1 はじめに

CX事業本部の平内(SIN)です。

Amazon SageMake(以下、SageMaker)の組み込みアルゴリズムであるオブジェクト検出は、デバイスにインストールしたMXNetフレームワークの上で利用することが可能です。

この動作に関しては、これまで、MacOSとRaspberryPiで確認してみました。
[Amazon SageMaker] 組み込みアルゴリズムのオブジェクト検出(ResNet-50)をMac上のMXNetで利用してみました
[Amazon SageMaker] 組み込みアルゴリズムのオブジェクト検出(ResNet-50)をRaspberryPi上のMXNetで利用してみました

今回は、The NVIDIA® Jetson Nano™ 開発者キット(以下、Jetson Nano)にセットアップしたMXNetでこれを利用してみました。

最初に、動作している様子です。 GPUの使用率が99%となっており、推論は、約0.85秒です。

ちなみに、GPUを使用しない場合の推論は、3秒弱でした。

2 モデル

確認に使用したモデルは、下記のものと同じです。

SageMakerのオブジェクト検出(組み込みアルゴリズム)で作成したモデルを、MXNet用に変換しています。

3 Jetson Nano

Jetson Nanoのセットアップで使用したSDカードのイメージは、最新のJP 4.4DP 2020/04/21です。

https://developer.nvidia.com/embedded/downloads

$ uname -a
Linux nvidia-desktop 4.9.140-tegra #1 SMP PREEMPT Wed Apr 8 18:10:49 PDT 2020 aarch64 aarch64 aarch64 GNU/Linux

4 MXNet

mxnetのインストールは、下記に置かれている、mxnet_cu102_arch53-1.6.0-py2.py3-none-linux_aarch64.whlを利用させて頂きました。
http://mxnet-public.s3.amazonaws.com/

$ pip3 install https://mxnet-public.s3.amazonaws.com/install/jetson/1.6.0/mxnet_cu102_arch53-1.6.0-py2.py3-none-linux_aarch64.whl

上記のインストールだけでは、共有オブジェクトの不足が表示されていたので、追加しました。

OSError: libopenblas.so.0: cannot open shared object file: No such file or directory

$ sudo apt-get install libopenblas-base

インストールされたバージョンは、1.6.0となっており、gpuが利用できていることを確認できます。

$ python3
Python 3.6.9 (default, Apr 18 2020, 01:56:04)
[GCC 8.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import mxnet
>>> mxnet.__version__
'1.6.0'
>>> a = mxnet.nd.ones((2, 3), mxnet.gpu())
>>> a
[[1. 1. 1.]
 [1. 1. 1.]]
<NDArray 2x3 @gpu(0)>
>>>

なお、利用時に、下記のような表示があったため、環境変数も追加しています。

enviroment variable MXNET_CUDNN_AUTOUNE_DEFAULT to 0 to desable

export MXNET_CUDNN_AUTOTUNE_DEFALUT=0

5 コード

使用しているコードです。

Jetson Nano Development KitにセットアップされているOpenCVは、GStream有効でコンパイルされているので、Webカメラの入力は、GStreamerから取得しています。

取得した画像を、入力インターフェースに合わせて変換し、推論に回されています。

import mxnet as mx
import cv2
import numpy as np
from collections import namedtuple
import time

# Webカメラ
DEVICE_ID = 0 
WIDTH = 800
HEIGHT = 600
GST_STR = ('v4l2src device=/dev/video{} ! video/x-raw, width=(int){}, height=(int){} ! videoconvert ! appsink').format(DEVICE_ID, WIDTH, HEIGHT)

MODEL_PATH = './model/deploy_model_algo_1'
CLASSES = ['ASPARA','CRATZ','PRETZEL','PRIME','OREO']
COLORS = [(128, 0, 0),(0, 128, 0),(0, 0, 128),(128, 128, 0),(0, 128,128)]

def main():

    # Model Initialize
    SHAPE = 512
    input_shapes=[('data', (1, 3, SHAPE, SHAPE))]
    Batch = namedtuple('Batch', ['data'])
    sym, arg_params, aux_params = mx.model.load_checkpoint(MODEL_PATH, 0)
    #mod = mx.mod.Module(symbol=sym, label_names=[], context=mx.cpu())
    mod = mx.mod.Module(symbol=sym, label_names=[], context=mx.gpu(0))
    mod.bind(for_training=False, data_shapes=input_shapes)
    mod.set_params(arg_params, aux_params)

    # Video Initialize
    cap = cv2.VideoCapture(GST_STR, cv2.CAP_GSTREAMER)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
    fps = cap.get(cv2.CAP_PROP_FPS)
    print("fps:{} width:{} height:{}".format(fps, width, height))

    while True:
        # カメラ画像取得
        _, frame = cap.read()
        if(frame is None):
            continue

        # 入力用画像の生成
        frame = frame[0 : int(height), 0 : int(height)] # 800*600 -> 600*600
        frame = cv2.resize(frame, (SHAPE, SHAPE)) # 600*600 -> 512*512
        img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # BGR -> RGB
        img = img.transpose((2, 0, 1)) # 512,512,3 -> 3,512,512
        img = img[np.newaxis, :] # 3,512,512 -> 1,3,512,512

        # 推論
        mod.forward(Batch([mx.nd.array(img)]))
        start = time.time()
        prob = mod.get_outputs()[0].asnumpy()
        elapsed_time = time.time() - start
        prob = np.squeeze(prob)
        index = int(prob[:, 0][0])
        confidence = prob[0][1]
        x1 = int(prob[0][2] * SHAPE)
        y1 = int(prob[0][3] * SHAPE)
        x2 = int(prob[0][4] * SHAPE)
        y2 = int(prob[0][5] * SHAPE)

        # 表示
        print("[{}] {:.1f} {:.2f}[Sec] {}, {}, {}, {}".format(CLASSES[index], confidence, elapsed_time, x1, y1, x2, y2))

        if(confidence > 0.2): # 信頼度
            frame = cv2.rectangle(frame,(x1, y1), (x2, y2), COLORS[index],2)
            frame = cv2.rectangle(frame,(x1, y1), (x1 + 150,y1-20), COLORS[index], -1)
            label = "{} {:.2f}".format(CLASSES[index], confidence)
            frame = cv2.putText(frame,label,(x1+2, y1-2), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 1, cv2.LINE_AA)

        # 画像表示
        cv2.imshow('frame', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()

if __name__ == '__main__':
    main()

6 最後に

↑との指摘を社内で頂いたので・・・更に、確認を進めたいと思います。

7 参考にさせて頂いたリンク


Jetson Nano Developer KitでAWS IoT GreengrassのML Inferenceを試す(GPU編)
Install MXNet on a Jetson