[Amazon SageMaker] Neoで最適化したモデルが動作するJetson Nanoをローカルネットワークで利用するエンドポイントにしてみました

2020.07.09

1 はじめに

CX事業本部の平内(SIN)です。

Amazon SageMaker Neoを使用すると、エッジデバイスに最適化したモデルを作成することが可能です。

しかし、実際にモデルを使用するクライアントアプリは、モデルを組み込んだエッジデバイスで動作していない場合もあると思います。 今回は、Neoで最適化したモデルが動作するJetson Nanoをエンドポイントにしてみました。

動画では、MacOSからHTTP(POST)で画像(Webカメラのフレーム画像)をエンドポイントに送って、推論しています。

2 モデル

下記の要領で、商品を回転台に乗せて撮影した動画から作成した、40種類お菓子を分類するモデル(イメージ分類)です。

3 Amazon SageMaker Neo

上記のモデルを下記の諸元で、SageMaker Neo最適化しています。

  • Machine learning framework MXNET
  • input {"data":[1,3,224,224]}
  • Target device jetson_nano

最適化されたモデルは、以下のようになっています。

Jetson Nanoで動作させるための、DLRのインストール等は、下記をご参照下さい。
参考:[Amazon SageMaker] イメージ分類のモデルをNeoで最適化して、Jetson Nano+OpenCV+Webカメラで使用してみました
参考:[Amazon SageMaker] オブジェクト検出のモデルをNeoで最適化して、Jetson Nano+OpenCV+Webカメラで使用してみました

4 エンドポイント

Jetson Nano上で、エンドポイントとして動作しているコードは、以下のとおりです。POSTで画像を受信し、推論結果をJSONで返しています。

server.py

import dlr
import time
import json
import cv2
import base64
import numpy as np
from PIL import ImageFont, ImageDraw, Image
from http.server import BaseHTTPRequestHandler, HTTPServer

hostName = "0.0.0.0"
hostPort = 9002
MODEL_PATH = './model'

class Model():
    def __init__(self, model_path):
        self.__model = dlr.DLRModel(MODEL_PATH, 'gpu')
        print("input_dtypes: {}".format(self.__model.get_input_dtypes()))
        print("input_names: {}".format(self.__model.get_input_names()))
        print("model initialize finished. ")
        print("start dry run.")
        for _ in range(10):
            dmy = np.random.rand(1, 3, 224, 224)
            self.__model.run({'data':dmy})
        print("model ready.")


    def inference(self, img):

        # 入力画像生成
        SHAPE = 224
        img = cv2.resize(img, dsize=(SHAPE, SHAPE)) # height * height => 224 * 224
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR => RGB
        img = img.transpose((2, 0, 1)) # 224,244,3 => 3,224,224
        img = img[np.newaxis, :] # 3,224,224 => 1,3,224,224
        print("img.shape: {}".format(img.shape))

        # 推論
        start = time.time() # 時間計測
        out = self.__model.run({'data': img})
        processing_time = time.time() - start
        print(processing_time)
        prob = np.max(out)
        index = np.argmax(out[0])
        print("class:{} {:3f}  {:3f}[sec]".format(index, prob, processing_time))
        return (index, prob, processing_time)

class Server(BaseHTTPRequestHandler):

    def do_POST(self):
        try:
            content_len=int(self.headers.get('content-length'))
            body = json.loads(self.rfile.read(content_len).decode('utf-8'))
            img = base64.b64decode(body["image"])
            img = np.fromstring(img, dtype='uint8')
            img = cv2.imdecode(img, 1)

            (index, prob, processing_time) = model.inference(img)
            response = { 
                'status' : 200,
                'result' : { 
                    'index' : int(index),
                    'prob' : str(prob),
                    'processing_time': str(processing_time)
                }
            }
            self.__response(response)
        except Exception as e:
            print("ERROR:{}".format(e))
            response = { 
                'status' : 500,
                'result' : { 'msg': 'Server Error' }
            }
            self.__response(response)

    def __response(self, body):
            self.send_response(200)
            self.send_header('Content-type', 'application/json')
            self.end_headers()
            self.wfile.write(json.dumps(body).encode('utf-8'))

model = Model(MODEL_PATH)

server = HTTPServer((hostName, hostPort), Server)
print("start - {}:{}".format(hostName, hostPort))
server.serve_forever()
server.server_close()

5 クライント

下記が、クライアント側のコードです。推論は、単純なHTTPリクエストとなています。

client.py

import numpy as np
import cv2
import time
import json
import base64
import requests
from PIL import ImageFont, ImageDraw, Image

# Webカメラ
DEVICE_ID = 1 

WIDTH = 800
HEIGHT = 600
FPS = 24

HOST='10.0.0.11'
PORT=9002

CLASSES = ["ポリッピー(GREEN)","OREO","カントリーマム","ポリッピー(RED)","柿の種(わさび)","通のとうもろこし","CHEDDER_CHEESE","ピーナッツ","ストーンチョコ","PRETZEL(YELLOW)","海味鮮","柿の種","カラフルチョコ","フルグラ(BROWN)","NOIR","BANANA(BLOWN)","チーズあられ","俺のおやつ","PRIME","CRATZ(RED)","CRATZ(GREEN)","揚一番","ポリッピー(YELLOW)","こつぶっこ","アスパラガス","海苔ピーパック","いちご","梅しそチーズあられ","通のえだ豆","柿の種(梅しそ)","PRETZEL(BLACK)","辛子明太子","CRATZ(ORANGE)","チョコメリゼ","フライドポテト(じゃがバター味)","BANANA(BLUE)","でん六豆","パズル","フルグラ(RED)","PRETZEL(GREEN)","フライドポテト(しお味)",]

def request(img):
  _, encimg = cv2.imencode(".png", img)
  img = base64.b64encode(encimg.tostring()).decode("utf-8")
  data = json.dumps({'image': img}).encode('utf-8')

  response = requests.post("http://{}:{}/".format(HOST, PORT), data=data)
  data = json.loads(response.text)
  return data["result"]

def putText(image, text, point, size, color):
    font = ImageFont.truetype("./GenShinGothic-Bold.ttf", size)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(image)
    draw = ImageDraw.Draw(image)
    draw.text(point, text, color, font)
    image = np.asarray(image)
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

def main():
    cap = cv2.VideoCapture (DEVICE_ID)

    # フォーマット・解像度・FPSの設定
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, WIDTH)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, HEIGHT)
    cap.set(cv2.CAP_PROP_FPS, FPS)

    # フォーマット・解像度・FPSの取得
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
    fps = cap.get(cv2.CAP_PROP_FPS)
    print("fps:{} width:{} height:{}".format(fps, width, height))

    while True:

        # カメラ画像取得
        _, frame = cap.read()
        if(frame is None):
            continue
        # Endpointへのリクエスト
        response = request(frame)
        index = int(response["index"])
        prob = response["prob"]
        processing_time = response["processing_time"]

        # 結果表示
        text = "{} {}".format(CLASSES[index], prob)
        frame = putText(frame, text, (20, int(height) - 120), 60, (255, 255, 255))
        cv2.imshow('frame', frame)

        # キュー入力判定(1msの待機)
        # waitKeyがないと、imshow()は表示できない
        # 'q'をタイプされたらループから抜ける
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # VideoCaptureオブジェクト破棄
    cap.release()
    cv2.destroyAllWindows()

if __name__ == '__main__':
    main()

6 最後に

今回は、エッジデバイスをローカルネットワークから利用する為のエンドポイントにしてみました。エンドポイントとすることで、クライアントの機種やOSを問うこと無く、モデルが利用可能になります。

コスパの良いエンドポイントとして、色々応用可能かも知れません。