機械学習のモデルを変換する(PyTorchからTensorFlow Lite)
はじめに
現在、カフェのシステムでは、機械学習を用いて、カメラを用いて動画を撮影し、商品の前にいる人物の骨格や手を検出することで、どのユーザがどの商品を取り出したかを判定しています。
今までは、骨格検出モデルを用いてエッジデバイスで動画を推論処理(撮影した画像から映っている人物の骨格の座標を検出する処理)を実行する、という構成で処理をしていました。今後、エッジ側のデバイスの費用を下げたり、骨格検出以外の処理を増やすことを考えているため、エッジデバイスからクラウドに動画を送信し、クラウド側で様々な処理を実行する、という構成を検討しています。
前回までの記事で、エッジデバイスでの動画処理(エンコード・送信)と、クラウド側の処理(動画の取り出し)について記載しました。
撮影した動画をリアルタイムにエンコードする方法【GStreamer】
目的
今回は、クラウド側のメインの処理となる、機械学習のモデルを使った推論処理を、TensorFlow Liteで行うために、モデルファイルを変換する方法について調べます。今まではPyTorchのモデルを用いていたため、PyTorchからTensorFlow Liteへモデルファイルを変換する方法について記載します。
モデルを変換する方法
モデルの変換方法を調べてみると、PyTorchから直接TensorFlow Lite(以下、TFLite)に変換する方法はなく、ONNXを経由する方法が紹介されていることが多かったです。
実装したコードは以下のとおりです。途中、onnx-tfを使用するため、予め「pip install onnx-tf」でインストールしておいてください。また、converter.pyのmain部分で、models/load_model.pyのload_model_checkpoint()を利用していますが、変換したい対象のモデルをここで読み込むように、お使いの環境に合わせて修正してください(modelのクラスはtorch.nn.Moduleを継承して、ネットワークを定義しているもので、pthファイルから重みを読み込んであるものです)。
model_convert.convertの引数のうち、INPUT_SHAPEはONNX形式に変換する際のパラメータで、モデルのサイズを変更することができます。また、DO_OPTIMIZE,とOPTIMIZATIONSは、TFLiteに変換する際に指定するパラメータで、モデルの量子化(最適化)の種類を指定します。
- convert/model_convert.py
import os import subprocess import tensorflow as tf import torch def convert_torch_to_onnx(model, onnx_filepath, input_shape): dummy_input = torch.randn(input_shape) torch.onnx.export(model, dummy_input, onnx_filepath, verbose=True, input_names=['input'], output_names=['output']) assert os.path.exists(onnx_filepath) return onnx_filepath def convert_onnx_to_tensorflow(onnx_filepath, pb_filepath): subprocess.run( f"onnx-tf convert -i {onnx_filepath} -o {pb_filepath}", shell=True) assert os.path.exists(pb_filepath) def convert_tensorflow_to_tflite(pb_filepath, tflite_filepath, do_optimize, optimizations): # input_arrays = ["input"] # output_arrays = ["output"] converter = tf.lite.TFLiteConverter.from_saved_model(pb_filepath) if do_optimize: converter.optimizations = optimizations # 量子化する時のみ # converter.target_spec.supported_types = [tf.float16] # converter.inference_input_type = tf.float16 # converter.inference_output_type = tf.float16 tflite_model = converter.convert() with open(tflite_filepath, "wb") as f: f.write(tflite_model) assert os.path.exists(tflite_filepath) def convert( model, onnx_filepath, pb_filepath, tflite_filepath, input_shape, do_optimize, optimizations, ): print("################ torch to onnx ###############") if os.path.exists(onnx_filepath): print(f"already exists {onnx_filepath}") else: convert_torch_to_onnx(model, onnx_filepath, input_shape) print("################ onnx to tensorflow ################") if os.path.exists(pb_filepath): print(f"already exists {pb_filepath}") else: convert_onnx_to_tensorflow(onnx_filepath, pb_filepath) print("################ tensorflow to tflite ################") if os.path.exists(tflite_filepath): print(f"already exists {tflite_filepath}") else: convert_tensorflow_to_tflite(pb_filepath, tflite_filepath, do_optimize, optimizations) return tflite_filepath
- converter.py
import os import cv2 import tensorflow as tf from models import load_model from convert import model_convert CHECKPOINT_PATH = "../model_checkpoint/checkpoint_iter_3500.pth" DO_OPTIMIZE = True OPTIMIZATIONS = [tf.lite.Optimize.DEFAULT] MODEL_INPUT_WIDTH = 832 MODEL_INPUT_HEIGHT = 472 INPUT_SHAPE = (1, 3, MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH) IMG_PATH = "../sample/imgs/0020.jpg" def infer(tflite_filepath, img_path): interpreter = tf.lite.Interpreter(model_path=tflite_filepath) interpreter.allocate_tensors() # allocate memory # 入力層の構成情報を取得する input_details = interpreter.get_input_details() # 入力層に合わせて、画像を変換する img = cv2.imread(img_path) img = (img - 128) / 256 # 明度を正規化(モデル学習時のデータ前処理に合わせる) input_shape = input_details[0]['shape'] input_dtype = input_details[0]['dtype'] input_data = cv2.resize(img, (input_shape[2], input_shape[3])).transpose( (2, 0, 1)).reshape(input_shape).astype(input_dtype) # indexにテンソルデータのポインタをセット interpreter.set_tensor(input_details[0]['index'], input_data) # 推論実行 interpreter.invoke() # 出力層から結果を取り出す output_details = interpreter.get_output_details() output_data = interpreter.get_tensor(output_details[0]['index']) return output_data if __name__ == "__main__": CHECKPOINT_DIR = os.path.dirname(CHECKPOINT_PATH) CHECKPOINT_BASE = os.path.basename(CHECKPOINT_PATH) OUT_DIR = os.path.join(CHECKPOINT_DIR, f"{MODEL_INPUT_WIDTH}_{MODEL_INPUT_HEIGHT}") os.makedirs(OUT_DIR, exist_ok=True) onnx_filepath = os.path.join(OUT_DIR, f"{CHECKPOINT_BASE}.onnx") pb_filepath = os.path.join(OUT_DIR, f"{CHECKPOINT_BASE}.pb") tflite_filepath = os.path.join(OUT_DIR, f"{CHECKPOINT_BASE}.tflite") # convert model model = load_model.load_model_checkpoint(CHECKPOINT_PATH, use_cuda=False) # 変換したいモデルに合わせる tflite_filepath = model_convert.convert( model, onnx_filepath, pb_filepath, tflite_filepath, INPUT_SHAPE, DO_OPTIMIZE, OPTIMIZATIONS, ) # test output_data = infer(tflite_filepath, IMG_PATH) print(output_data)
infer()を実行した、output_dataの出力結果は、以下のようになっており、処理が動いたことが確認できました。
[[[[-8.4298387e-29 -6.0723546e-28 9.7049355e-30 ... 4.1735702e-29 -8.3994018e-30 -7.3987608e-31] [ 3.2861993e-28 -4.4360002e-28 4.5395772e-28 ... 1.0161931e-28 8.4211281e-29 -9.2654239e-29] [ 2.4088350e-28 -1.4278935e-28 -3.3845556e-28 ... -8.7949690e-28 -9.8673037e-28 -1.1030160e-27] ... [ 4.9395231e-29 -1.3791511e-28 5.4587205e-28 ... 8.6757241e-28 3.3906473e-28 -4.7563079e-28] [ 1.0788619e-28 4.9395231e-29 4.2192695e-28 ... -4.7563079e-28 -4.4360002e-28 -1.8961254e-27] [ 5.2846405e-28 1.3852432e-28 1.5593234e-28 ... -6.4274782e-28 -8.8228218e-28 -2.1983287e-27]] [[-3.0552871e-24 -2.6462554e-23 3.4922001e-24 ... 4.7034853e-24 2.3136525e-24 2.5755520e-24] [ 1.6914550e-23 -1.8179983e-23 2.3953099e-23 ... 7.0605812e-24 6.1766701e-24 -2.1386390e-24] [ 1.3477119e-23 -4.4957345e-24 -1.4120541e-23 ... -3.9819430e-23 -4.4631835e-23 -4.9149600e-23] ... [ 1.7243785e-24 -4.4302598e-24 2.7456004e-23 ... 4.2711653e-23 1.8387736e-23 -1.9751381e-23] [ 4.3761112e-24 4.3761112e-24 2.1661478e-23 ... -1.9424005e-23 -1.6510374e-23 -8.3654862e-23] [ 2.4967960e-23 8.4355531e-24 9.3194639e-24 ... -2.7739316e-23 -3.6742112e-23 -9.7568273e-23]] [[ 2.2140335e-25 1.2056203e-24 -5.0388737e-26 ... -1.1353236e-25 -1.4698870e-26 -2.9798432e-26] [-6.0632717e-25 8.4323078e-25 -9.2616333e-25 ... -2.3158348e-25 -1.9726629e-25 1.5139631e-25] [-4.7317647e-25 2.5022981e-25 6.3595501e-25 ... 1.7025330e-24 1.9125542e-24 2.1431657e-24] ... [ 1.8245638e-26 2.4062099e-25 -1.1073580e-24 ... -1.7415396e-24 -7.2575095e-25 8.7342995e-25] [-9.4314743e-26 -1.2863193e-25 -8.6301965e-25 ... 9.3108276e-25 8.1577704e-25 3.6339043e-24] [-9.6597122e-25 -3.0433591e-25 -3.3865311e-25 ... 1.3181807e-24 1.7080237e-24 4.2612223e-24]] ... [[ 2.7168613e-26 1.3634412e-25 -2.9802734e-27 ... -9.9845601e-27 9.7867132e-28 -6.9626682e-28] [-6.4648445e-26 9.6145603e-26 -1.0012669e-25 ... -2.3079531e-26 -1.9272853e-26 1.9402992e-26] [-4.9878541e-26 3.0366224e-26 7.3153278e-26 ... 1.8933308e-25 2.1186861e-25 2.3349053e-25] ... [ 4.6330818e-27 2.9300352e-26 -1.2022594e-25 ... -1.9057335e-25 -7.7895688e-26 9.9495488e-26] [-7.8528207e-27 -1.1659498e-26 -9.3122404e-26 ... 9.8125077e-26 8.4573309e-26 3.9641634e-25] [-1.0454244e-25 -3.1149689e-26 -3.4956365e-26 ... 1.3588732e-25 1.7745625e-25 4.6006403e-25]] [[ 3.2497330e-24 3.1778299e-23 -4.7302853e-24 ... -6.2065884e-24 -3.2938818e-24 -3.6130823e-24] [-2.1089322e-23 2.1683577e-23 -2.9667844e-23 ... -9.0793946e-24 -8.0020930e-24 2.1325308e-24] [-1.6899813e-23 5.0053374e-24 1.6735965e-23 ... 4.8057539e-23 5.3922852e-23 5.9429063e-23] ... [-2.5756802e-24 4.9255371e-24 -3.3937153e-23 ... -5.2530595e-23 -2.2884826e-23 2.3598780e-23] [-5.8075877e-24 -5.8075877e-24 -2.6874834e-23 ... 2.3199781e-23 1.9648670e-23 1.0148376e-22] [-3.0904746e-23 -1.0755199e-23 -1.1832501e-23 ... 3.3334403e-23 4.4306928e-23 1.1844129e-22]] [[ 2.5969926e-24 1.2147017e-23 -4.0252417e-26 ... -6.5294569e-25 3.0605259e-25 1.5953884e-25] [-5.4346174e-24 8.6306897e-24 -8.5380427e-24 ... -1.7984159e-24 -1.4654304e-24 1.9177025e-24] [-4.1426336e-24 2.8767006e-24 6.6194573e-24 ... 1.6782176e-23 1.8753450e-23 2.0644809e-23] ... [ 6.2571855e-25 2.7834647e-24 -1.0296206e-23 ... -1.6449777e-23 -6.5934075e-24 8.9237176e-24] [-4.6647387e-25 -7.9945935e-25 -7.9253495e-24 ... 8.8038423e-24 7.6184139e-24 3.4896587e-23] [-8.9243061e-24 -2.5043450e-24 -2.8373307e-24 ... 1.2107059e-23 1.5743261e-23 4.0464106e-23]]]]
まとめ
機械学習モデルの変換方法を調べました。ONNX形式を経由することで、PyTorchからTensorFlow Liteに変換することができました。
参考にさせていただいたページ・サイト
PyTorchのモデルをtfliteに変換する - Qiita
初心者に優しくないTensorflow Lite の公式サンプル - Qiita
Serverless ML Inferencing with AWS Lambda and TensorFlow Lite
8 Serverless deep learning · Machine Learning Bookcamp MEAP V10