[Rust] ortでonnxモデルを使って推論したりWASMにしたりしてみる

2024.06.06

Introduction

以前、BurnというRustの機械学習フレームワークで
ONNXファイルを変換して使うという記事を書きました。
問題なく変換して推論までできましたが、onnxファイルをそのまま使いたいケースもあります。
というわけで、今回はONNX RuntimeのRustラッパー「ort」を使ってみます。

また、wasmpackを使ってWASMにしてChrome Extensionから使ってみます。

[補足] ONNX?

ONNXは、さまざまな機械学習フレームワーク間で使用できる共通フォーマットです。
これを使うことにより、PytorchでトレーニングしたモデルをTensorFlowで使う
みたいなことが容易にできます。

ort?

ortは、ONNXランタイム用のRustバインディングです。
ここで紹介されていますが、ortとONNX Runtimeを併用することで、
さまざまなMLモデル (YOLOv8、BERT、LLaMAなど) を(ほぼ)すべてのハードウェア上で実行でき、
さらに多くのケースでPyTorchよりも高速に実行させることができます。
(機械学習モデルをONNXグラフに変換することで最適化も可能となる)

Environment

  • MacBook Pro (14-inch, M3, 2023)
  • OS : MacOS 14.5
  • Rust : 1.78.0
  • wasm-pack 0.12.1
  • gh : 2.49.2

Try

では、ortをつかってMNISTの数値認識をやってみましょう。 まずはcargoでプロジェクトを作成します。

% cargo new ort-rs
% cd ort-rs

Cargo.tomlは↓のようになってます。
cargo addでortいれるとversion1が入るので注意。

[dependencies]
image = "0.25.1"
ndarray = "0.15.6"
ort = "2.0.0-rc.2"
tracing-subscriber = { version = "0.3", features = [ "env-filter", "fmt" ] }

このあたりからmnistのonnxファイルを持ってきます。
あとはテストで確認するための数値画像はこのへんとかで用意。
それらのファイルはプロジェクトroot(ort-rsの下)においておきます。

main.rsはこんな感じです。
ort::Sessionでmnist.onnxをloadして、
ndarrayで画像を変換後にSession::runで推論を実行します。

use ort::{GraphOptimizationLevel, Session, Value};
use std::error::Error;
use ndarray::{Array, Ix4, ArrayD};
use std::collections::HashMap;
use image::io::Reader as ImageReader;

/// メイン関数
fn main() -> Result<(), Box<dyn Error>> {
    // ログの初期化
    tracing_subscriber::fmt::init();

    // モデルのセッションを作成
    let model = create_session(include_bytes!("../mnist.onnx"))?;

    // 画像をロードして前処理を行う(8の画像)
    let img_array = preprocess_image(include_bytes!("../mnist_8.jpg"))?;

    // 推論を実行して結果を表示
    run_inference(&model, img_array)?;

    Ok(())
}

fn create_session(model_data: &[u8]) -> Result<Session, Box<dyn Error>> {
    let session = Session::builder()?
        .with_optimization_level(GraphOptimizationLevel::Level3)?
        .with_intra_threads(4)?
        .commit_from_memory(model_data)?;
    Ok(session)
}

fn preprocess_image(img_data: &[u8]) -> Result<Array<f32, Ix4>, Box<dyn Error>> {
    // 画像をロード
    let img = ImageReader::new(std::io::Cursor::new(img_data))
        .with_guessed_format()?
        .decode()?
        .to_luma8();

    // 画像を28x28にリサイズ
    let img = image::imageops::resize(&img, 28, 28, image::imageops::FilterType::Nearest);

    // 画像データを正規化してndarrayに変換
    let img_array = Array::from_shape_vec(
        (1, 1, 28, 28),
        img.iter().map(|&p| p as f32 / 255.0).collect(),
    )?;
    Ok(img_array)
}

fn run_inference(model: &Session, img_array: Array<f32, Ix4>) -> Result<(), Box<dyn Error>> {
    // 入力データをHashMapに格納
    let input_tensor = Value::from_array(img_array)?;
    let mut inputs = HashMap::new();
    inputs.insert("Input3", input_tensor);

    // 推論を実行
    let outputs = model.run(inputs)?;

    // 結果を表示
    for (name, tensor) in outputs.iter() {
        println!("Output {}: {:?}", name, tensor);

        let array_view: ndarray::ArrayViewD<f32> = tensor.try_extract_tensor()?;
        let array: ArrayD<f32> = array_view.to_owned();
        println!("Tensor values: {:?}", array);

        // 最も高い値を持つインデックスを見つける
        let max_index = array.iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(index, _)| index)
            .unwrap();

        println!("Predicted digit: {}", max_index);
    }

    Ok(())
}

ちなみに、入力パラメータ名って何つかえばいいかわからなかったので、
GenAIに聞いたら↓で調べろといわれた。

import onnx

# ONNXモデルをロード
model = onnx.load("mnist.onnx")

# モデルの入力名を表示
for input in model.graph.input:
    print(input.name)

実行するとこんな感じです。
一応、画像はちゃんと判定されてますね。

% cargo run

Output Plus214_Output_0: Value { ・・・ }

Tensor values: [[1.1330107, -5.077395, 8.586635, 4.278495, -7.7753954, 1.1661655, -4.3016477, -10.987418, 14.0324135, 1.9482136]], shape=[1, 10], strides=[10, 1], layout=CFcf (0xf), dynamic ndim=2

Predicted digit: 8

Convert to WASM & Use in Chrome Extension

では次に、mnistプログラムをWASM化して
Chrome Extensionで使ってみます。
onnxのWASM化はまだexperimentalとのことですが、
ortのリポジトリにはWASMのサンプルがあるので、それをつかってみます。
(↑のコードを使おうとしたらうまくいかなかった)

ortのリポジトリをcloneしましょう。

% gh repo clone pykeio/ort

ort/examples/webassemblyにそのまま動くサンプルがあります。
これを少しだけ変えてChrome Extensionで動かしてみます。

Cargo.tomlにcrateを追加します。

serde-wasm-bindgen = "0.6.5"

webassembly/src/lib.rsを下記のように少し修正。
サンプルではortファイルをloadして使ってます。

use image::{ImageBuffer, Luma, Pixel};
use ort::{ArrayExtensions, Session};
use wasm_bindgen::prelude::*;
use wasm_bindgen::JsValue;
use ndarray::Array4;
use serde_wasm_bindgen::to_value;

static MODEL_BYTES: &[u8] = include_bytes!("mnist.ort");

#[wasm_bindgen]
pub fn classify_image(image_bytes: &[u8]) -> Result<JsValue, JsValue> {
    let session_builder = match Session::builder() {
        Ok(builder) => builder,
        Err(e) => return Err(JsValue::from_str(&format!("Could not create session builder: {:?}", e))),
    };

    let session = match session_builder.commit_from_memory_directly(MODEL_BYTES) {
        Ok(s) => s,
        Err(e) => return Err(JsValue::from_str(&format!("Could not read model from memory: {:?}", e))),
    };

    let image_buffer: ImageBuffer<Luma<u8>, Vec<u8>> = match image::load_from_memory(image_bytes) {
        Ok(img) => img.to_luma8(),
        Err(e) => return Err(JsValue::from_str(&format!("Could not load image from memory: {:?}", e))),
    };

    let array = Array4::from_shape_fn((1, 1, 28, 28), |(_, _, j, i)| {
        let pixel = image_buffer.get_pixel(i as u32, j as u32);
        let channels = pixel.channels();
        (channels[0] as f32) / 255.0
    });

    let inputs = match ort::inputs![array] {
        Ok(i) => i,
        Err(e) => return Err(JsValue::from_str(&format!("Error creating inputs: {:?}", e))),
    };

    let outputs = match session.run(inputs) {
        Ok(o) => o,
        Err(e) => return Err(JsValue::from_str(&format!("Error during inference: {:?}", e))),
    };

    let probabilities: Vec<f32> = match outputs[0].try_extract_tensor() {
        Ok(tensor) => tensor.softmax(ndarray::Axis(1)).iter().copied().collect(),
        Err(e) => return Err(JsValue::from_str(&format!("Error extracting tensor: {:?}", e))),
    };

    // 確率をパーセンテージ形式で小数点第10位までフォーマット
    let formatted_probabilities: Vec<String> = probabilities.iter().map(|&x| format!("{:.10}%", x * 100.0)).collect();

    Ok(to_value(&formatted_probabilities).map_err(|e| JsValue::from_str(&format!("Error serializing output: {:?}", e)))?)
}

#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::console_log;
    use wasm_bindgen_test::wasm_bindgen_test;
    use serde_wasm_bindgen::from_value;

    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

    #[wasm_bindgen_test]
    fn run_test() {
        use tracing::Level;
        use tracing_subscriber::fmt;
        use tracing_subscriber_wasm::MakeConsoleWriter;

        #[cfg(target_arch = "wasm32")]
        ort::wasm::initialize();

        fmt()
            .with_ansi(false)
            .with_max_level(Level::DEBUG)
            .with_writer(MakeConsoleWriter::default().map_trace_level_to(Level::DEBUG))
            .without_time()
            .init();

        std::panic::set_hook(Box::new(console_error_panic_hook::hook));

        let image_bytes: &[u8] = include_bytes!("../../../tests/data/mnist_5.jpg");
        let result = classify_image(image_bytes).unwrap();

        // JsValueをVec<String>に変換
        let formatted_probabilities: Vec<String> = from_value(result).unwrap();
        console_log!("Probabilities: {:?}", formatted_probabilities);
    }
}

テストしてみる

wasmpackでtestできます。
必要なパッケージをインストールします。

% brew install chromedriver

このままだとテスト実行時にchromedriverが起動しないので、
ctrlを押しながらクリックして警告ダイアログがでないようにします。
そしてテスト実行。動いてます。

% wasm-pack test --headless --chrome
[INFO]: 🎯  Checking for the Wasm target...
Running headless tests in Chrome on `http://127.0.0.1:56986/`
Try find `webdriver.json` for configure browser's capabilities:
Not found
running 1 test

test ortwasm::tests::run_test ... ok
test result: ok. 1 passed; 0 failed; 0 ignored; 0 filtered out
・・・

buildコマンドでWASMを作成します。
成功するとpkgディレクトリにwasmやjsができてます。

% wasm-pack build --release --target web

% ls -l pkg/
total 13696
-rw-r--r--@ 1 2903 Jun  5 21:10 ortwasm.d.ts
-rw-r--r--@ 1 9310 Jun  5 21:10 ortwasm.js
-rw-r--r--@ 1 6985877 Jun  5 21:10 ortwasm_bg.wasm
-rw-r--r--@ 1 056 Jun  5 21:10 ortwasm_bg.wasm.d.ts
-rw-r--r--@ 1 231 Jun  5 21:10 package.json

適当なChrome Extensionを作成

数値を書いてincludeしたWASMで推論するサンプルを作ります。
まずはextensionディレクトリを作成して、
さきほどのpkgディレクトリをコピーしておきます。

あとはExtensionのコードを作成しましょう。
GenAIで「フリーハンドでキャンバスに数値書いてjpgにして、
そのデータをwasmに渡すChrome Extension作って」
と言ったらほとんど生成してくれます。

extension/manifest.jsonは下記。
WASMを実行するためにcontent_security_policyが必要です。

{
    "manifest_version": 3,
    "name": "MNIST WASM Chrome Extension",
    "version": "1.0",
    "permissions": ["storage"],
    "action": {
      "default_popup": "popup.html",
      "default_icon": {
        "16": "images/icon16.png"
      }
    },
    "background": {
      "service_worker": "background.js"
    },
    "content_security_policy": {
        "extension_pages": "script-src 'self' 'wasm-unsafe-eval'"
      }
}

extension/popup.htmlです。
scriptタグの「type="module"」を指定して
Javascriptモジュールを使用します。

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Draw and Predict</title>
    <style>
        #canvas {
            border: 1px solid black;
        }
    </style>
</head>
<body>
    <h1>Draw a Digit</h1>
    <canvas id="canvas" width="280" height="280"></canvas>
    <br>
    <button id="predictButton">Predict</button>
    <pre id="result"></pre>
    <script type="module" src="popup.js"></script>
</body>
</html>

extension/popup.jsは下記のようになっています。(抜粋)
Rustコードで指定したclassify_imageをimportしてます。

import init, { classify_image } from './pkg/ortwasm.js';

document.addEventListener('DOMContentLoaded', () => {
  init('./pkg/ortwasm_bg.wasm').then(() => {

    //キャンパス描画処理
    ・・・・・・・・・・・

    // Predictボタンが押されたときの処理
    predictButton.addEventListener('click', () => {
        //キャンパスをJPEGに変換など
        const resizedCanvas = document.createElement('canvas');
        ・・・・
        resizedCanvas.toBlob((blob) => {
            const reader = new FileReader();
            reader.onloadend = () => {
                const arrayBuffer = reader.result;
                const uint8Array = new Uint8Array(arrayBuffer);
                // WASMのMNISTモデルに送信
                predictDigit(uint8Array);
            };
            reader.readAsArrayBuffer(blob);
        }, 'image/jpeg'); // JPEG形式で保存
    });

    // WASMのMNISTモデルに送信する関数
    async function predictDigit(imageData) {
        try {
            const result = await classify_image(imageData);
            document.getElementById('result').textContent = JSON.stringify(result, null, 2);
        } catch (error) {
            console.error('Error during prediction:', error);
        }
    }
  }).catch(console.error);
});

ちなみに、テスト時に書いたキャンバスをjpgとしてダウンロードしたかったとき、
↓の関数実行したらそのままダウンロードできて便利だった。

// ローカルにJPEG画像を保存する関数
function saveImage(blob) {
    const url = URL.createObjectURL(blob);
    const a = document.createElement('a');
    a.href = url;
    a.download = 'draw_image.jpg';
    document.body.appendChild(a);
    a.click();
    document.body.removeChild(a);
    URL.revokeObjectURL(url);
}

chrome://extensions/で、「パッケージ化されたいない拡張機能を読み込む」
を押してextensionディレクトリを指定してインストールします。

実行してみると↓みたいな感じです。ボタンを押すと推論結果が表示されてます。

mnist

Summary

今回はONNXランタイム用Rustライブラリortをつかってみました。
onnxがそのまま使えるのは便利ですし、WASMで使えるのも良いです。

References