Introduction
以前、BurnというRustの機械学習フレームワークで
ONNXファイルを変換して使うという記事を書きました。
問題なく変換して推論までできましたが、onnxファイルをそのまま使いたいケースもあります。
というわけで、今回はONNX RuntimeのRustラッパー「ort」を使ってみます。
また、wasmpackを使ってWASMにしてChrome Extensionから使ってみます。
[補足] ONNX?
ONNXは、さまざまな機械学習フレームワーク間で使用できる共通フォーマットです。
これを使うことにより、PytorchでトレーニングしたモデルをTensorFlowで使う
みたいなことが容易にできます。
ort?
ortは、ONNXランタイム用のRustバインディングです。
ここで紹介されていますが、ortとONNX Runtimeを併用することで、
さまざまなMLモデル (YOLOv8、BERT、LLaMAなど) を(ほぼ)すべてのハードウェア上で実行でき、
さらに多くのケースでPyTorchよりも高速に実行させることができます。
(機械学習モデルをONNXグラフに変換することで最適化も可能となる)
Environment
- MacBook Pro (14-inch, M3, 2023)
- OS : MacOS 14.5
- Rust : 1.78.0
- wasm-pack 0.12.1
- gh : 2.49.2
Try
では、ortをつかってMNISTの数値認識をやってみましょう。 まずはcargoでプロジェクトを作成します。
% cargo new ort-rs
% cd ort-rs
Cargo.tomlは↓のようになってます。
cargo addでortいれるとversion1が入るので注意。
[dependencies]
image = "0.25.1"
ndarray = "0.15.6"
ort = "2.0.0-rc.2"
tracing-subscriber = { version = "0.3", features = [ "env-filter", "fmt" ] }
このあたりからmnistのonnxファイルを持ってきます。
あとはテストで確認するための数値画像はこのへんとかで用意。
それらのファイルはプロジェクトroot(ort-rsの下)においておきます。
main.rsはこんな感じです。
ort::Sessionでmnist.onnxをloadして、
ndarrayで画像を変換後にSession::runで推論を実行します。
use ort::{GraphOptimizationLevel, Session, Value};
use std::error::Error;
use ndarray::{Array, Ix4, ArrayD};
use std::collections::HashMap;
use image::io::Reader as ImageReader;
/// メイン関数
fn main() -> Result<(), Box<dyn Error>> {
// ログの初期化
tracing_subscriber::fmt::init();
// モデルのセッションを作成
let model = create_session(include_bytes!("../mnist.onnx"))?;
// 画像をロードして前処理を行う(8の画像)
let img_array = preprocess_image(include_bytes!("../mnist_8.jpg"))?;
// 推論を実行して結果を表示
run_inference(&model, img_array)?;
Ok(())
}
fn create_session(model_data: &[u8]) -> Result<Session, Box<dyn Error>> {
let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(4)?
.commit_from_memory(model_data)?;
Ok(session)
}
fn preprocess_image(img_data: &[u8]) -> Result<Array<f32, Ix4>, Box<dyn Error>> {
// 画像をロード
let img = ImageReader::new(std::io::Cursor::new(img_data))
.with_guessed_format()?
.decode()?
.to_luma8();
// 画像を28x28にリサイズ
let img = image::imageops::resize(&img, 28, 28, image::imageops::FilterType::Nearest);
// 画像データを正規化してndarrayに変換
let img_array = Array::from_shape_vec(
(1, 1, 28, 28),
img.iter().map(|&p| p as f32 / 255.0).collect(),
)?;
Ok(img_array)
}
fn run_inference(model: &Session, img_array: Array<f32, Ix4>) -> Result<(), Box<dyn Error>> {
// 入力データをHashMapに格納
let input_tensor = Value::from_array(img_array)?;
let mut inputs = HashMap::new();
inputs.insert("Input3", input_tensor);
// 推論を実行
let outputs = model.run(inputs)?;
// 結果を表示
for (name, tensor) in outputs.iter() {
println!("Output {}: {:?}", name, tensor);
let array_view: ndarray::ArrayViewD<f32> = tensor.try_extract_tensor()?;
let array: ArrayD<f32> = array_view.to_owned();
println!("Tensor values: {:?}", array);
// 最も高い値を持つインデックスを見つける
let max_index = array.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.map(|(index, _)| index)
.unwrap();
println!("Predicted digit: {}", max_index);
}
Ok(())
}
ちなみに、入力パラメータ名って何つかえばいいかわからなかったので、
GenAIに聞いたら↓で調べろといわれた。
import onnx
# ONNXモデルをロード
model = onnx.load("mnist.onnx")
# モデルの入力名を表示
for input in model.graph.input:
print(input.name)
実行するとこんな感じです。
一応、画像はちゃんと判定されてますね。
% cargo run
Output Plus214_Output_0: Value { ・・・ }
Tensor values: [[1.1330107, -5.077395, 8.586635, 4.278495, -7.7753954, 1.1661655, -4.3016477, -10.987418, 14.0324135, 1.9482136]], shape=[1, 10], strides=[10, 1], layout=CFcf (0xf), dynamic ndim=2
Predicted digit: 8
Convert to WASM & Use in Chrome Extension
では次に、mnistプログラムをWASM化して
Chrome Extensionで使ってみます。
onnxのWASM化はまだexperimentalとのことですが、
ortのリポジトリにはWASMのサンプルがあるので、それをつかってみます。
(↑のコードを使おうとしたらうまくいかなかった)
ortのリポジトリをcloneしましょう。
% gh repo clone pykeio/ort
ort/examples/webassemblyにそのまま動くサンプルがあります。
これを少しだけ変えてChrome Extensionで動かしてみます。
Cargo.tomlにcrateを追加します。
serde-wasm-bindgen = "0.6.5"
webassembly/src/lib.rsを下記のように少し修正。
サンプルではortファイルをloadして使ってます。
use image::{ImageBuffer, Luma, Pixel};
use ort::{ArrayExtensions, Session};
use wasm_bindgen::prelude::*;
use wasm_bindgen::JsValue;
use ndarray::Array4;
use serde_wasm_bindgen::to_value;
static MODEL_BYTES: &[u8] = include_bytes!("mnist.ort");
#[wasm_bindgen]
pub fn classify_image(image_bytes: &[u8]) -> Result<JsValue, JsValue> {
let session_builder = match Session::builder() {
Ok(builder) => builder,
Err(e) => return Err(JsValue::from_str(&format!("Could not create session builder: {:?}", e))),
};
let session = match session_builder.commit_from_memory_directly(MODEL_BYTES) {
Ok(s) => s,
Err(e) => return Err(JsValue::from_str(&format!("Could not read model from memory: {:?}", e))),
};
let image_buffer: ImageBuffer<Luma<u8>, Vec<u8>> = match image::load_from_memory(image_bytes) {
Ok(img) => img.to_luma8(),
Err(e) => return Err(JsValue::from_str(&format!("Could not load image from memory: {:?}", e))),
};
let array = Array4::from_shape_fn((1, 1, 28, 28), |(_, _, j, i)| {
let pixel = image_buffer.get_pixel(i as u32, j as u32);
let channels = pixel.channels();
(channels[0] as f32) / 255.0
});
let inputs = match ort::inputs![array] {
Ok(i) => i,
Err(e) => return Err(JsValue::from_str(&format!("Error creating inputs: {:?}", e))),
};
let outputs = match session.run(inputs) {
Ok(o) => o,
Err(e) => return Err(JsValue::from_str(&format!("Error during inference: {:?}", e))),
};
let probabilities: Vec<f32> = match outputs[0].try_extract_tensor() {
Ok(tensor) => tensor.softmax(ndarray::Axis(1)).iter().copied().collect(),
Err(e) => return Err(JsValue::from_str(&format!("Error extracting tensor: {:?}", e))),
};
// 確率をパーセンテージ形式で小数点第10位までフォーマット
let formatted_probabilities: Vec<String> = probabilities.iter().map(|&x| format!("{:.10}%", x * 100.0)).collect();
Ok(to_value(&formatted_probabilities).map_err(|e| JsValue::from_str(&format!("Error serializing output: {:?}", e)))?)
}
#[cfg(test)]
mod tests {
use super::*;
use wasm_bindgen_test::console_log;
use wasm_bindgen_test::wasm_bindgen_test;
use serde_wasm_bindgen::from_value;
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
#[wasm_bindgen_test]
fn run_test() {
use tracing::Level;
use tracing_subscriber::fmt;
use tracing_subscriber_wasm::MakeConsoleWriter;
#[cfg(target_arch = "wasm32")]
ort::wasm::initialize();
fmt()
.with_ansi(false)
.with_max_level(Level::DEBUG)
.with_writer(MakeConsoleWriter::default().map_trace_level_to(Level::DEBUG))
.without_time()
.init();
std::panic::set_hook(Box::new(console_error_panic_hook::hook));
let image_bytes: &[u8] = include_bytes!("../../../tests/data/mnist_5.jpg");
let result = classify_image(image_bytes).unwrap();
// JsValueをVec<String>に変換
let formatted_probabilities: Vec<String> = from_value(result).unwrap();
console_log!("Probabilities: {:?}", formatted_probabilities);
}
}
テストしてみる
wasmpackでtestできます。
必要なパッケージをインストールします。
% brew install chromedriver
このままだとテスト実行時にchromedriverが起動しないので、
ctrlを押しながらクリックして警告ダイアログがでないようにします。
そしてテスト実行。動いてます。
% wasm-pack test --headless --chrome
[INFO]: 🎯 Checking for the Wasm target...
Running headless tests in Chrome on `http://127.0.0.1:56986/`
Try find `webdriver.json` for configure browser's capabilities:
Not found
running 1 test
test ortwasm::tests::run_test ... ok
test result: ok. 1 passed; 0 failed; 0 ignored; 0 filtered out
・・・
buildコマンドでWASMを作成します。
成功するとpkgディレクトリにwasmやjsができてます。
% wasm-pack build --release --target web
% ls -l pkg/
total 13696
-rw-r--r--@ 1 2903 Jun 5 21:10 ortwasm.d.ts
-rw-r--r--@ 1 9310 Jun 5 21:10 ortwasm.js
-rw-r--r--@ 1 6985877 Jun 5 21:10 ortwasm_bg.wasm
-rw-r--r--@ 1 056 Jun 5 21:10 ortwasm_bg.wasm.d.ts
-rw-r--r--@ 1 231 Jun 5 21:10 package.json
適当なChrome Extensionを作成
数値を書いてincludeしたWASMで推論するサンプルを作ります。
まずはextensionディレクトリを作成して、
さきほどのpkgディレクトリをコピーしておきます。
あとはExtensionのコードを作成しましょう。
GenAIで「フリーハンドでキャンバスに数値書いてjpgにして、
そのデータをwasmに渡すChrome Extension作って」
と言ったらほとんど生成してくれます。
extension/manifest.jsonは下記。
WASMを実行するためにcontent_security_policyが必要です。
{
"manifest_version": 3,
"name": "MNIST WASM Chrome Extension",
"version": "1.0",
"permissions": ["storage"],
"action": {
"default_popup": "popup.html",
"default_icon": {
"16": "images/icon16.png"
}
},
"background": {
"service_worker": "background.js"
},
"content_security_policy": {
"extension_pages": "script-src 'self' 'wasm-unsafe-eval'"
}
}
extension/popup.htmlです。
scriptタグの「type="module"」を指定して
Javascriptモジュールを使用します。
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Draw and Predict</title>
<style>
#canvas {
border: 1px solid black;
}
</style>
</head>
<body>
<h1>Draw a Digit</h1>
<canvas id="canvas" width="280" height="280"></canvas>
<br>
<button id="predictButton">Predict</button>
<pre id="result"></pre>
<script type="module" src="popup.js"></script>
</body>
</html>
extension/popup.jsは下記のようになっています。(抜粋)
Rustコードで指定したclassify_imageをimportしてます。
import init, { classify_image } from './pkg/ortwasm.js';
document.addEventListener('DOMContentLoaded', () => {
init('./pkg/ortwasm_bg.wasm').then(() => {
//キャンパス描画処理
・・・・・・・・・・・
// Predictボタンが押されたときの処理
predictButton.addEventListener('click', () => {
//キャンパスをJPEGに変換など
const resizedCanvas = document.createElement('canvas');
・・・・
resizedCanvas.toBlob((blob) => {
const reader = new FileReader();
reader.onloadend = () => {
const arrayBuffer = reader.result;
const uint8Array = new Uint8Array(arrayBuffer);
// WASMのMNISTモデルに送信
predictDigit(uint8Array);
};
reader.readAsArrayBuffer(blob);
}, 'image/jpeg'); // JPEG形式で保存
});
// WASMのMNISTモデルに送信する関数
async function predictDigit(imageData) {
try {
const result = await classify_image(imageData);
document.getElementById('result').textContent = JSON.stringify(result, null, 2);
} catch (error) {
console.error('Error during prediction:', error);
}
}
}).catch(console.error);
});
ちなみに、テスト時に書いたキャンバスをjpgとしてダウンロードしたかったとき、
↓の関数実行したらそのままダウンロードできて便利だった。
// ローカルにJPEG画像を保存する関数
function saveImage(blob) {
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = 'draw_image.jpg';
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
}
chrome://extensions/で、「パッケージ化されたいない拡張機能を読み込む」
を押してextensionディレクトリを指定してインストールします。
実行してみると↓みたいな感じです。ボタンを押すと推論結果が表示されてます。
Summary
今回はONNXランタイム用Rustライブラリortをつかってみました。
onnxがそのまま使えるのは便利ですし、WASMで使えるのも良いです。