機械学習のモデルを変換する(PyTorchからTensorFlow Lite)

2021.06.02

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

はじめに

現在、カフェのシステムでは、機械学習を用いて、カメラを用いて動画を撮影し、商品の前にいる人物の骨格や手を検出することで、どのユーザがどの商品を取り出したかを判定しています。

今までは、骨格検出モデルを用いてエッジデバイスで動画を推論処理(撮影した画像から映っている人物の骨格の座標を検出する処理)を実行する、という構成で処理をしていました。今後、エッジ側のデバイスの費用を下げたり、骨格検出以外の処理を増やすことを考えているため、エッジデバイスからクラウドに動画を送信し、クラウド側で様々な処理を実行する、という構成を検討しています。

前回までの記事で、エッジデバイスでの動画処理(エンコード・送信)と、クラウド側の処理(動画の取り出し)について記載しました。

撮影した動画をリアルタイムにエンコードする方法【GStreamer】

【Kinesis Video Streams】Pythonで動画ファイルを送信する

Lambda関数でKinesis Video Streamsから動画を取得する

目的

今回は、クラウド側のメインの処理となる、機械学習のモデルを使った推論処理を、TensorFlow Liteで行うために、モデルファイルを変換する方法について調べます。今まではPyTorchのモデルを用いていたため、PyTorchからTensorFlow Liteへモデルファイルを変換する方法について記載します。

モデルを変換する方法

モデルの変換方法を調べてみると、PyTorchから直接TensorFlow Lite(以下、TFLite)に変換する方法はなく、ONNXを経由する方法が紹介されていることが多かったです。

実装したコードは以下のとおりです。途中、onnx-tfを使用するため、予め「pip install onnx-tf」でインストールしておいてください。また、converter.pyのmain部分で、models/load_model.pyのload_model_checkpoint()を利用していますが、変換したい対象のモデルをここで読み込むように、お使いの環境に合わせて修正してください(modelのクラスはtorch.nn.Moduleを継承して、ネットワークを定義しているもので、pthファイルから重みを読み込んであるものです)。

model_convert.convertの引数のうち、INPUT_SHAPEはONNX形式に変換する際のパラメータで、モデルのサイズを変更することができます。また、DO_OPTIMIZE,とOPTIMIZATIONSは、TFLiteに変換する際に指定するパラメータで、モデルの量子化(最適化)の種類を指定します。

  • convert/model_convert.py
import os
import subprocess
import tensorflow as tf
import torch

def convert_torch_to_onnx(model, onnx_filepath, input_shape):
    dummy_input = torch.randn(input_shape)
    torch.onnx.export(model, dummy_input, onnx_filepath,
                      verbose=True, input_names=['input'], output_names=['output'])
    assert os.path.exists(onnx_filepath)

    return onnx_filepath

def convert_onnx_to_tensorflow(onnx_filepath, pb_filepath):
    subprocess.run(
        f"onnx-tf convert -i {onnx_filepath} -o {pb_filepath}", shell=True)
    assert os.path.exists(pb_filepath)

def convert_tensorflow_to_tflite(pb_filepath, tflite_filepath, do_optimize, optimizations):
    # input_arrays = ["input"]
    # output_arrays = ["output"]

    converter = tf.lite.TFLiteConverter.from_saved_model(pb_filepath)
    if do_optimize:
        converter.optimizations = optimizations  # 量子化する時のみ
        # converter.target_spec.supported_types = [tf.float16]
        # converter.inference_input_type = tf.float16
        # converter.inference_output_type = tf.float16
    tflite_model = converter.convert()

    with open(tflite_filepath, "wb") as f:
        f.write(tflite_model)
    assert os.path.exists(tflite_filepath)

def convert(
    model, onnx_filepath, pb_filepath, tflite_filepath,
    input_shape, do_optimize, optimizations,
):
    print("################ torch to onnx ###############")
    if os.path.exists(onnx_filepath):
        print(f"already exists {onnx_filepath}")
    else:
        convert_torch_to_onnx(model, onnx_filepath, input_shape)

    print("################ onnx to tensorflow ################")
    if os.path.exists(pb_filepath):
        print(f"already exists {pb_filepath}")
    else:
        convert_onnx_to_tensorflow(onnx_filepath, pb_filepath)

    print("################ tensorflow to tflite ################")
    if os.path.exists(tflite_filepath):
        print(f"already exists {tflite_filepath}")
    else:
        convert_tensorflow_to_tflite(pb_filepath, tflite_filepath, do_optimize, optimizations)

    return tflite_filepath
  • converter.py
import os
import cv2
import tensorflow as tf

from models import load_model
from convert import model_convert

CHECKPOINT_PATH = "../model_checkpoint/checkpoint_iter_3500.pth"
DO_OPTIMIZE = True
OPTIMIZATIONS = [tf.lite.Optimize.DEFAULT]

MODEL_INPUT_WIDTH = 832
MODEL_INPUT_HEIGHT =  472
INPUT_SHAPE = (1, 3, MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH)

IMG_PATH = "../sample/imgs/0020.jpg"

def infer(tflite_filepath, img_path):
    interpreter = tf.lite.Interpreter(model_path=tflite_filepath)
    interpreter.allocate_tensors()  # allocate memory

    # 入力層の構成情報を取得する
    input_details = interpreter.get_input_details()

    # 入力層に合わせて、画像を変換する
    img = cv2.imread(img_path)
    img = (img - 128) / 256  # 明度を正規化(モデル学習時のデータ前処理に合わせる)
    input_shape = input_details[0]['shape']
    input_dtype = input_details[0]['dtype']
    input_data = cv2.resize(img, (input_shape[2], input_shape[3])).transpose(
        (2, 0, 1)).reshape(input_shape).astype(input_dtype)
    # indexにテンソルデータのポインタをセット
    interpreter.set_tensor(input_details[0]['index'], input_data)

    # 推論実行
    interpreter.invoke()

    # 出力層から結果を取り出す
    output_details = interpreter.get_output_details()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    return output_data

if __name__ == "__main__":
    CHECKPOINT_DIR = os.path.dirname(CHECKPOINT_PATH)
    CHECKPOINT_BASE = os.path.basename(CHECKPOINT_PATH)
    OUT_DIR = os.path.join(CHECKPOINT_DIR, f"{MODEL_INPUT_WIDTH}_{MODEL_INPUT_HEIGHT}")
    os.makedirs(OUT_DIR, exist_ok=True)

    onnx_filepath = os.path.join(OUT_DIR, f"{CHECKPOINT_BASE}.onnx")
    pb_filepath = os.path.join(OUT_DIR, f"{CHECKPOINT_BASE}.pb")
    tflite_filepath = os.path.join(OUT_DIR, f"{CHECKPOINT_BASE}.tflite")

    # convert model
    model = load_model.load_model_checkpoint(CHECKPOINT_PATH, use_cuda=False) # 変換したいモデルに合わせる
    tflite_filepath = model_convert.convert(
        model, onnx_filepath, pb_filepath, tflite_filepath,
        INPUT_SHAPE, DO_OPTIMIZE, OPTIMIZATIONS,
    )

    # test
    output_data = infer(tflite_filepath, IMG_PATH)
    print(output_data)

infer()を実行した、output_dataの出力結果は、以下のようになっており、処理が動いたことが確認できました。

[[[[-8.4298387e-29 -6.0723546e-28  9.7049355e-30 ...  4.1735702e-29
    -8.3994018e-30 -7.3987608e-31]
   [ 3.2861993e-28 -4.4360002e-28  4.5395772e-28 ...  1.0161931e-28
     8.4211281e-29 -9.2654239e-29]
   [ 2.4088350e-28 -1.4278935e-28 -3.3845556e-28 ... -8.7949690e-28
    -9.8673037e-28 -1.1030160e-27]
   ...
   [ 4.9395231e-29 -1.3791511e-28  5.4587205e-28 ...  8.6757241e-28
     3.3906473e-28 -4.7563079e-28]
   [ 1.0788619e-28  4.9395231e-29  4.2192695e-28 ... -4.7563079e-28
    -4.4360002e-28 -1.8961254e-27]
   [ 5.2846405e-28  1.3852432e-28  1.5593234e-28 ... -6.4274782e-28
    -8.8228218e-28 -2.1983287e-27]]

  [[-3.0552871e-24 -2.6462554e-23  3.4922001e-24 ...  4.7034853e-24
     2.3136525e-24  2.5755520e-24]
   [ 1.6914550e-23 -1.8179983e-23  2.3953099e-23 ...  7.0605812e-24
     6.1766701e-24 -2.1386390e-24]
   [ 1.3477119e-23 -4.4957345e-24 -1.4120541e-23 ... -3.9819430e-23
    -4.4631835e-23 -4.9149600e-23]
   ...
   [ 1.7243785e-24 -4.4302598e-24  2.7456004e-23 ...  4.2711653e-23
     1.8387736e-23 -1.9751381e-23]
   [ 4.3761112e-24  4.3761112e-24  2.1661478e-23 ... -1.9424005e-23
    -1.6510374e-23 -8.3654862e-23]
   [ 2.4967960e-23  8.4355531e-24  9.3194639e-24 ... -2.7739316e-23
    -3.6742112e-23 -9.7568273e-23]]

  [[ 2.2140335e-25  1.2056203e-24 -5.0388737e-26 ... -1.1353236e-25
    -1.4698870e-26 -2.9798432e-26]
   [-6.0632717e-25  8.4323078e-25 -9.2616333e-25 ... -2.3158348e-25
    -1.9726629e-25  1.5139631e-25]
   [-4.7317647e-25  2.5022981e-25  6.3595501e-25 ...  1.7025330e-24
     1.9125542e-24  2.1431657e-24]
   ...
   [ 1.8245638e-26  2.4062099e-25 -1.1073580e-24 ... -1.7415396e-24
    -7.2575095e-25  8.7342995e-25]
   [-9.4314743e-26 -1.2863193e-25 -8.6301965e-25 ...  9.3108276e-25
     8.1577704e-25  3.6339043e-24]
   [-9.6597122e-25 -3.0433591e-25 -3.3865311e-25 ...  1.3181807e-24
     1.7080237e-24  4.2612223e-24]]

  ...

  [[ 2.7168613e-26  1.3634412e-25 -2.9802734e-27 ... -9.9845601e-27
     9.7867132e-28 -6.9626682e-28]
   [-6.4648445e-26  9.6145603e-26 -1.0012669e-25 ... -2.3079531e-26
    -1.9272853e-26  1.9402992e-26]
   [-4.9878541e-26  3.0366224e-26  7.3153278e-26 ...  1.8933308e-25
     2.1186861e-25  2.3349053e-25]
   ...
   [ 4.6330818e-27  2.9300352e-26 -1.2022594e-25 ... -1.9057335e-25
    -7.7895688e-26  9.9495488e-26]
   [-7.8528207e-27 -1.1659498e-26 -9.3122404e-26 ...  9.8125077e-26
     8.4573309e-26  3.9641634e-25]
   [-1.0454244e-25 -3.1149689e-26 -3.4956365e-26 ...  1.3588732e-25
     1.7745625e-25  4.6006403e-25]]

  [[ 3.2497330e-24  3.1778299e-23 -4.7302853e-24 ... -6.2065884e-24
    -3.2938818e-24 -3.6130823e-24]
   [-2.1089322e-23  2.1683577e-23 -2.9667844e-23 ... -9.0793946e-24
    -8.0020930e-24  2.1325308e-24]
   [-1.6899813e-23  5.0053374e-24  1.6735965e-23 ...  4.8057539e-23
     5.3922852e-23  5.9429063e-23]
   ...
   [-2.5756802e-24  4.9255371e-24 -3.3937153e-23 ... -5.2530595e-23
    -2.2884826e-23  2.3598780e-23]
   [-5.8075877e-24 -5.8075877e-24 -2.6874834e-23 ...  2.3199781e-23
     1.9648670e-23  1.0148376e-22]
   [-3.0904746e-23 -1.0755199e-23 -1.1832501e-23 ...  3.3334403e-23
     4.4306928e-23  1.1844129e-22]]

  [[ 2.5969926e-24  1.2147017e-23 -4.0252417e-26 ... -6.5294569e-25
     3.0605259e-25  1.5953884e-25]
   [-5.4346174e-24  8.6306897e-24 -8.5380427e-24 ... -1.7984159e-24
    -1.4654304e-24  1.9177025e-24]
   [-4.1426336e-24  2.8767006e-24  6.6194573e-24 ...  1.6782176e-23
     1.8753450e-23  2.0644809e-23]
   ...
   [ 6.2571855e-25  2.7834647e-24 -1.0296206e-23 ... -1.6449777e-23
    -6.5934075e-24  8.9237176e-24]
   [-4.6647387e-25 -7.9945935e-25 -7.9253495e-24 ...  8.8038423e-24
     7.6184139e-24  3.4896587e-23]
   [-8.9243061e-24 -2.5043450e-24 -2.8373307e-24 ...  1.2107059e-23
     1.5743261e-23  4.0464106e-23]]]]

まとめ

機械学習モデルの変換方法を調べました。ONNX形式を経由することで、PyTorchからTensorFlow Liteに変換することができました。

参考にさせていただいたページ・サイト

PyTorchのモデルを別形式に変換する方法いろいろ(TorchScript, ONNX, TensorRT, CoreML, OpenVINO, Tensorflow, TFLite) - Qiita

PyTorchのモデルをtfliteに変換する - Qiita

初心者に優しくないTensorflow Lite の公式サンプル - Qiita

Kerasのモデルをtfliteに変換する

Serverless ML Inferencing with AWS Lambda and TensorFlow Lite

8 Serverless deep learning · Machine Learning Bookcamp MEAP V10