[Lobe] Lobeで作成したモデルをTensorflow Lite形式でエクスポートしてMacで使用してみました

[Lobe] Lobeで作成したモデルをTensorflow Lite形式でエクスポートしてMacで使用してみました

Clock Icon2020.11.07

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

1 はじめに

CX事業本部の平内(SIN)です。

前回、Microsoftによって公開されている機械学習ツール(Lobe)を試してました。

今回は、上記で作成した、モデルをTensorflow Lite形式でエクスポートして、Mac上で簡単なプログラムを作成して見ました。

最初に、作成したプログラムを実行しているようすです。

2 エクスポート

LobeのメニューからExportが選択可能です。

Tensorflow Liteを選択しています。

最適化するかどうかの選択がありますが、ここは「有効」としました。

少し、待つと、最適化及び出力が終わります。

3 出力内容

出力されたファイルは以下のとおりです。

※ Lobeでのモデル作成は、Windowsで実行しましたが、エクスポートしたファイルをMacにコピーして確認しています。

  • saved_model.tflite TFLite形式のモデル

  • signature.json ラベル名、入力形式などモデルに関する情報

  • example/tflite_example.py モデルを利用するサンプル

  • example/requirements.txt サンプル実行に必要なPythonライブラリ

  • example/README.md ドキュメント

4 サンプル実行

(1) フォルダ構成

出力されたファイルから、モデル及び、サンプルスクリプト等を以下のように配置しました。TOMATO.png 及び、AHIRU.pngは、推論の確認のために用意した画像ファイルです。

AHIRU.png

TOMATO.png

(2) 環境作成

Anacondaで作業環境(tflite)を新規に作成しています。

% conda create --name tflite python=3.7
% conda activate tflite
(tflite) %

(3) 依存関係

エクスポート出力に含まれているrequirements.txtを使用して依存関係をインストールします。

(tflite) % python -m pip install --upgrade pip && pip install -r requirements.txt

(4) tflite_runtime

tflite_runtimeが、インストールされますので、確認しておきます。

(tflite) % python3
>>> import tflite_runtime
>>> print(tflite_runtime.__version__)
2.1.0.post1

(5) 推論

サンプルスクリプトを実行すると、ちゃんと分類出来ていることが確認できます。

(tflite)  % python3 tflite_example.py ./AHIRU.png
{'Prediction': 'AHIRU', 'Confidences': [1.0, 0.0]}
(tflite)  % python3 tflite_example.py ./TOMATO.png
{'Prediction': 'TOMATO', 'Confidences': [3.0978231048070515e-14, 1.0]}

5 プログラム作成

サンプルスクリプトを参考にさせて頂いて、簡単にWebカメラの入力を推論するプログラムを作成してみました。

入力形式は、[1, 244, 244, 3] とし、signature.jsonは使用していません。また、サンプルでは、短い方(縦)の長さで正方形に切り取り、224*224にリサイズしてましたが、下記では省略してそのままリサイズしてます。(ちょっと縦横比が歪んでますが、とりあえず精度が出たので良しとしました)

# -*- coding: utf-8 -*-
import cv2
import time
import numpy as np
import tflite_runtime.interpreter as tflite

# モデル
model_file = "./saved_model.tflite"
SHAPE = 224

# Webカメラ
DEVICE_ID = 0 
WIDTH = 800
HEIGHT = 600
FPS = 24

def main():

    cap = cv2.VideoCapture (DEVICE_ID)

    # フォーマット・解像度・FPSの設定
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, WIDTH)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, HEIGHT)
    cap.set(cv2.CAP_PROP_FPS, FPS)

    # モデル初期化
    interpreter = tflite.Interpreter(model_file)
    interpreter.allocate_tensors()

    while True:

        # カメラ画像取得
        _, frame = cap.read()
        if(frame is None):
            continue

        # 入力形式に変換する [600, 800, 3] => [1, 244, 244, 3]    
        image = cv2.resize(frame, (SHAPE, SHAPE)) # => [244, 244, 3]
        image = image[np.newaxis, :] # => [1, 244, 244, 3]
        image = np.asarray(image) / 255.0 # 0..255 => 0..1
        image = image.astype(np.float32) # float64 => float32

        # 推論
        start = time.time()
        interpreter.set_tensor(0, image)
        interpreter.invoke()
        elapsed_time = time.time() - start

        # 結果表示
        output_details = interpreter.get_output_details()
        prediction = interpreter.get_tensor(output_details[0]['index'])[0].decode()
        confidences = interpreter.get_tensor(output_details[1]['index'])[0][0]
        print("{} {} {:.2f} sec".format(prediction, confidences, elapsed_time))
        frame = cv2.putText(frame, prediction, (30, HEIGHT - 100),cv2.FONT_HERSHEY_PLAIN, 10, (255, 255, 255), 10, cv2.LINE_AA)

        cv2.imshow('frame', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()

if __name__ == '__main__':
    main()

6 最後に

超簡単にモデルが作成できるLobeですが、作成したモデルは、エクスポートできることで、各種のフレームワークで利用可能です。

Object Detectionに対応するのが、待ち遠しいです。

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.