[Rust] PyTorchで作成したONNXモデルをBurnで変換して使う [Deep Learning]

[Rust] PyTorchで作成したONNXモデルをBurnで変換して使う [Deep Learning]

Clock Icon2023.10.19

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

Introduction

burnはRust用Deep Learningフレームワークです。
現在アクティブに開発が進められているようで、
今後が期待できるプロダクトです。

公開されているMNISTデモはこちら

今回はこのburnを用いて、ONNX形式の既存モデルを
burn用モデルに変換して使ってみます。

Burn?

burnは2021年にリリースされた新しめの深層学習フレームワークです。
少し使ってみた感じだと、PyTorchに近い感じです。

burnの特徴は、以下のとおりです。

Tensor

Tensor(テンソル)は、深層学習フレームワークを使う際の
基本的なデータ構造であり、
多次元の数値データを表現するために使用します。
burnでも例によってTensor構造体を使います。
このあたりも既存のフレームワークを使い慣れている人なら
馴染みやすいかと思います。

バックエンド

burnではだいたいの実装がBackendトレイトに基づいており、
バックエンドを切り替えることによりいろいろな実装でTensorの演算を使用できます。
現在は下記のようにTorch、ndarray、wgpuの3種類のバックエンドに対応。

  • Torch : CPU・GPUともにサポート
  • Ndarray : CPUのみ。no_stdもサポート
  • WebGPU : GPU専用。WASMも?

Ndarrayはno_std対応しているので、使用環境の幅が広がりますね。

サンプル付属

ここに、いろいろなサンプルがありますので、すぐにビルドして動作確認できます。

Datasetやimport

burnは機械学習データパイプラインの作成プロセスを
効率化するためのいろいろなデータセット実装や変換ロジック、
データソースを提供します。

また、burn-importは、ONNX形式のモデルを変換し、Burn用にRustコードを生成します。
(今回試すやつ)

その他、burnの情報についてはこことかを参考にしてください。

Environment

  • MacBook Pro (13-inch, M1, 2020)
  • OS : MacOS 13.5.2
  • Python : 3.8.8
  • Rust : 1.74.0

Try

では、PyTorchで適当なモデルを作成してONNXでexportし、
そのモデルをburnで変換して使ってみます。

PyTorchで適当なモデルを作成

まずはベースとなるモデルをPyTorchで作成。
↓みたいな適当なCSV(train_data)を学習させ、
Weightを渡したらHeightを推論するモデルを作成します。

Weight,Height
60.3,164.1
59.0,168.5
44.7,172.1
47.9,166.4
57.1,164.7
38.9,163.5
・
・
・

必要なライブラリをインストール。

pip3 install torch numpy pandas

PyTorchで学習&推論するサンプルを作成します。

import torch
import torch.nn as nn
import pandas as pd
import torch.onnx

# データセットの読み込み
df = pd.read_csv('train_data.csv')

# 入力と出力を分割
weights = df['Weight'].values
heights = df['Height'].values

# データの前処理
weights = (weights - weights.mean()) / weights.std()  # 標準化

# PyTorchのテンソルに変換
weights = torch.tensor(weights, dtype=torch.float).view(-1, 1)
heights = torch.tensor(heights, dtype=torch.float).view(-1, 1)

# モデルの定義
class RegressionModel(nn.Module):
    def __init__(self):
        super(RegressionModel, self).__init__()
        self.linear = nn.Linear(1, 1)
    def forward(self, x):
        return self.linear(x)

model = RegressionModel()

# 損失関数と最適化関数の定義
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 学習
num_epochs = 1000
for epoch in range(num_epochs):
    # forward
    outputs = model(weights)
    loss = criterion(outputs, heights)

    # backwardとパラメータ更新
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 途中結果の表示
    if (epoch+1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# モデルの推論
weights = [60.,70.,80.]
test_weights = torch.tensor(weights, dtype=torch.float32).view(-1, 1) 

predicted_height = model(test_weights)
for i in range(len(test_weights)):
    print(f'Test Sample {i + 1}:  Weights = {weights[i]} , Predicted Height={predicted_height[i][0]}')

実行してみます。
結果はともかく作成したモデルで推論が動いてます。

% python3 main.py
Epoch [100/1000], Loss: 526.2801
Epoch [200/1000], Loss: 32.2470
Epoch [300/1000], Loss: 23.5581
Epoch [400/1000], Loss: 23.4052
Epoch [500/1000], Loss: 23.4026
Epoch [600/1000], Loss: 23.4025
Epoch [700/1000], Loss: 23.4025
Epoch [800/1000], Loss: 23.4025
Epoch [900/1000], Loss: 23.4025
Epoch [1000/1000], Loss: 23.4025
Test Sample 1:  Weights = 60.0 , Predicted Height=185.85348510742188
Test Sample 2:  Weights = 70.0 , Predicted Height=189.27638244628906
Test Sample 3:  Weights = 80.0 , Predicted Height=192.69927978515625

ONNXで出力

ONNXはOpen Neural Network eXchangeの略で、
機械学習モデルを表現するためのフォーマットです。
TensorflowやPyTorchなど主要なフレームワークはONNX変換できるので、
burnでそのままimportできます。

さきほどのPyTorchのプログラムを↓のようにしてONNXで出力します。

# ダミーのinputデータを生成
dummy_input = torch.zeros(1, 1)  # ダミーのテンソル

# モデルをONNX形式にエクスポート
torch.onnx.export(model,                     # モデル
                  dummy_input,               # ダミー入力データ
                  'model.onnx',              # 出力ファイル名
                  export_params=True,        # パラメータを含めるかどうか
                  opset_version=9,           # ONNXのバージョン
                  do_constant_folding=True,  # 定数折りたたみを行うかどうか
                  input_names=['input'],     # 入力の名前
                  output_names=['output'])   # 出力の名前

Rust

Cargoでプロジェクトを作成し、さきほどのonnxファイルをコピーしておきます。

% cargo new burn_example && cd burn_example
% mkdir ./src/ptmodel
% cp /<先程のonnxファイルパス>/model.onnx ./src/ptmodel/

Cargo.tomlに依存ライブラリを記述。
build-dependenciesにburn-importを追加するのを忘れずに。

[dependencies]
burn = { version = "0.9.0", features = ["ndarray", "std", "wgpu", "tch", "train"] }
serde = "1.0.189"

[build-dependencies]
burn-import = "0.9.0"

ビルド時にonnxからrsファイルを生成するため、
burn_exampleディレクトリにbuild.rsを作成します。

use burn_import::onnx::{ModelGen, ONNXGraph};

fn main() {
    ModelGen::new()
        .input("src/ptmodel/model.onnx")
        .out_dir("ptmodel/")
        .run_from_script();
}

ModelGenを使うことで、指定したパスにある
onnxファイルを任意のディレクトリに出力できます。
上記のbuild.rsで実際にビルドすると、
burn_example/target/debug/build/burn_example-xxx/out/ptmodel
に出力されます。

ptmodel/mod.rsを作成し、
ModelGenで出力したmodel.rsファイルをincludeマクロで読み込みます。

pub mod model {
    include!(concat!(env!("OUT_DIR"), "/ptmodel/model.rs"));
}

main.rsではbuild.rsで生成されるrsファイルを使ってコードを記述します。

mod ptmodel;

use ptmodel::model::*;
use burn::tensor::Tensor;
use burn::backend::NdArrayBackend;

type Backend = NdArrayBackend;

fn main() {
    // Create Model
    let model: Model<Backend> = Model::default();
    // Create a new input tensor
    let input = Tensor::<NdArrayBackend<f32>, 2>::from_data([[60.],[70.],[80.]]);
    // Run the model
    let output = model.forward(input);
    // Print the output
    println!("{:?}", output);
}

実行してみます。推論できてますね。

% cargo run
・
・
・
Tensor { primitive: NdArrayTensor { 
array: [[185.85349],[189.27638], [192.69928]], 
shape=[3, 1], strides=[1, 1], layout=CFcf (0xf), dynamic ndim=2 } }

ちなみに、main.rsで↓のようにすれば生成されたファイルの中身がみれます。

println!("{}",include_str!(concat!(env!("OUT_DIR"), "/ptmodel/model.rs")));

Summary

今回はburnを使ってonnx形式のモデルを変換して使ってみました。
簡単にモデル変換と推論ができました。

なお、ここにburnのドキュメントがあるのですが、
目次に「8.3. WebAssembly」「8.4. No-Std」などおもしろそうなのがあるので、
(追加されたら)確認してみようかと思います。  

References

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.