[Amazon SageMaker] Neoで最適化したモデルが動作するJetson Nanoをローカルネットワークで利用するエンドポイントにしてみました
1 はじめに
CX事業本部の平内(SIN)です。
Amazon SageMaker Neoを使用すると、エッジデバイスに最適化したモデルを作成することが可能です。
しかし、実際にモデルを使用するクライアントアプリは、モデルを組み込んだエッジデバイスで動作していない場合もあると思います。 今回は、Neoで最適化したモデルが動作するJetson Nanoをエンドポイントにしてみました。
動画では、MacOSからHTTP(POST)で画像(Webカメラのフレーム画像)をエンドポイントに送って、推論しています。
2 モデル
下記の要領で、商品を回転台に乗せて撮影した動画から作成した、40種類お菓子を分類するモデル(イメージ分類)です。
3 Amazon SageMaker Neo
上記のモデルを下記の諸元で、SageMaker Neo最適化しています。
- Machine learning framework MXNET
- input {"data":[1,3,224,224]}
- Target device jetson_nano
最適化されたモデルは、以下のようになっています。
Jetson Nanoで動作させるための、DLRのインストール等は、下記をご参照下さい。
参考:[Amazon SageMaker] イメージ分類のモデルをNeoで最適化して、Jetson Nano+OpenCV+Webカメラで使用してみました
参考:[Amazon SageMaker] オブジェクト検出のモデルをNeoで最適化して、Jetson Nano+OpenCV+Webカメラで使用してみました
4 エンドポイント
Jetson Nano上で、エンドポイントとして動作しているコードは、以下のとおりです。POSTで画像を受信し、推論結果をJSONで返しています。
server.py
import dlr import time import json import cv2 import base64 import numpy as np from PIL import ImageFont, ImageDraw, Image from http.server import BaseHTTPRequestHandler, HTTPServer hostName = "0.0.0.0" hostPort = 9002 MODEL_PATH = './model' class Model(): def __init__(self, model_path): self.__model = dlr.DLRModel(MODEL_PATH, 'gpu') print("input_dtypes: {}".format(self.__model.get_input_dtypes())) print("input_names: {}".format(self.__model.get_input_names())) print("model initialize finished. ") print("start dry run.") for _ in range(10): dmy = np.random.rand(1, 3, 224, 224) self.__model.run({'data':dmy}) print("model ready.") def inference(self, img): # 入力画像生成 SHAPE = 224 img = cv2.resize(img, dsize=(SHAPE, SHAPE)) # height * height => 224 * 224 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR => RGB img = img.transpose((2, 0, 1)) # 224,244,3 => 3,224,224 img = img[np.newaxis, :] # 3,224,224 => 1,3,224,224 print("img.shape: {}".format(img.shape)) # 推論 start = time.time() # 時間計測 out = self.__model.run({'data': img}) processing_time = time.time() - start print(processing_time) prob = np.max(out) index = np.argmax(out[0]) print("class:{} {:3f} {:3f}[sec]".format(index, prob, processing_time)) return (index, prob, processing_time) class Server(BaseHTTPRequestHandler): def do_POST(self): try: content_len=int(self.headers.get('content-length')) body = json.loads(self.rfile.read(content_len).decode('utf-8')) img = base64.b64decode(body["image"]) img = np.fromstring(img, dtype='uint8') img = cv2.imdecode(img, 1) (index, prob, processing_time) = model.inference(img) response = { 'status' : 200, 'result' : { 'index' : int(index), 'prob' : str(prob), 'processing_time': str(processing_time) } } self.__response(response) except Exception as e: print("ERROR:{}".format(e)) response = { 'status' : 500, 'result' : { 'msg': 'Server Error' } } self.__response(response) def __response(self, body): self.send_response(200) self.send_header('Content-type', 'application/json') self.end_headers() self.wfile.write(json.dumps(body).encode('utf-8')) model = Model(MODEL_PATH) server = HTTPServer((hostName, hostPort), Server) print("start - {}:{}".format(hostName, hostPort)) server.serve_forever() server.server_close()
5 クライント
下記が、クライアント側のコードです。推論は、単純なHTTPリクエストとなています。
client.py
import numpy as np import cv2 import time import json import base64 import requests from PIL import ImageFont, ImageDraw, Image # Webカメラ DEVICE_ID = 1 WIDTH = 800 HEIGHT = 600 FPS = 24 HOST='10.0.0.11' PORT=9002 CLASSES = ["ポリッピー(GREEN)","OREO","カントリーマム","ポリッピー(RED)","柿の種(わさび)","通のとうもろこし","CHEDDER_CHEESE","ピーナッツ","ストーンチョコ","PRETZEL(YELLOW)","海味鮮","柿の種","カラフルチョコ","フルグラ(BROWN)","NOIR","BANANA(BLOWN)","チーズあられ","俺のおやつ","PRIME","CRATZ(RED)","CRATZ(GREEN)","揚一番","ポリッピー(YELLOW)","こつぶっこ","アスパラガス","海苔ピーパック","いちご","梅しそチーズあられ","通のえだ豆","柿の種(梅しそ)","PRETZEL(BLACK)","辛子明太子","CRATZ(ORANGE)","チョコメリゼ","フライドポテト(じゃがバター味)","BANANA(BLUE)","でん六豆","パズル","フルグラ(RED)","PRETZEL(GREEN)","フライドポテト(しお味)",] def request(img): _, encimg = cv2.imencode(".png", img) img = base64.b64encode(encimg.tostring()).decode("utf-8") data = json.dumps({'image': img}).encode('utf-8') response = requests.post("http://{}:{}/".format(HOST, PORT), data=data) data = json.loads(response.text) return data["result"] def putText(image, text, point, size, color): font = ImageFont.truetype("./GenShinGothic-Bold.ttf", size) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) image = Image.fromarray(image) draw = ImageDraw.Draw(image) draw.text(point, text, color, font) image = np.asarray(image) return cv2.cvtColor(image, cv2.COLOR_RGB2BGR) def main(): cap = cv2.VideoCapture (DEVICE_ID) # フォーマット・解像度・FPSの設定 cap.set(cv2.CAP_PROP_FRAME_WIDTH, WIDTH) cap.set(cv2.CAP_PROP_FRAME_HEIGHT, HEIGHT) cap.set(cv2.CAP_PROP_FPS, FPS) # フォーマット・解像度・FPSの取得 width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) fps = cap.get(cv2.CAP_PROP_FPS) print("fps:{} width:{} height:{}".format(fps, width, height)) while True: # カメラ画像取得 _, frame = cap.read() if(frame is None): continue # Endpointへのリクエスト response = request(frame) index = int(response["index"]) prob = response["prob"] processing_time = response["processing_time"] # 結果表示 text = "{} {}".format(CLASSES[index], prob) frame = putText(frame, text, (20, int(height) - 120), 60, (255, 255, 255)) cv2.imshow('frame', frame) # キュー入力判定(1msの待機) # waitKeyがないと、imshow()は表示できない # 'q'をタイプされたらループから抜ける if cv2.waitKey(1) & 0xFF == ord('q'): break # VideoCaptureオブジェクト破棄 cap.release() cv2.destroyAllWindows() if __name__ == '__main__': main()
6 最後に
今回は、エッジデバイスをローカルネットワークから利用する為のエンドポイントにしてみました。エンドポイントとすることで、クライアントの機種やOSを問うこと無く、モデルが利用可能になります。
コスパの良いエンドポイントとして、色々応用可能かも知れません。