MomentoとTensorFlow.jsで推論キャッシュパターン

2023.01.12

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

Introduction

DevelopersIOで度々紹介している、
クラウドネイティブな高速キャッシュサービスMomento Serverless Cache(以下Momento)。
↓のような特徴をもっています。 

  • セットアップがとても簡単
  • プロビジョニングの必要がない(自動でうまくやってくれる)
  • 料金はデータの転送量($0.15/GB)のみ。月50GBまでは無料

※Momento自体の情報についてはここを確認してみてください

今回はJavaScript版のMomento SDKをつかって、
推論キャッシュパターンの例を紹介します。

Inference cache?

推論キャッシュパターンとは、
機械学習モデルを用いて推論した結果をキャッシュに格納することで、
同じパラメータの推論をせずにキャッシュから結果を返すことで
処理の高速化やサーバ負荷を軽減します。

同じパラメータの推論リクエストが発生する、
推論結果が実行した時間に依存しないなどいくつか制約がありますが、
これらをクリアすれば高速な処理が可能になります。

処理の流れは次のようになります。

  1. クライアントから推論処理をリクエスト
  2. パラメータをキーにしてMomentoを検索
  3. 結果があればそれを返す
  4. 結果がなければ推論実行し、結果をMomentoへ保存してから返す

Environment

今回試した環境は以下のとおりです。

  • MacBook Pro (13-inch, M1, 2020)
  • OS : MacOS 12.4
  • Node : v18.2.0

Setup

今回は機械学習ライブラリにTEnsorFlow.js、キャッシュにMomentoを使います。
関数の結果を予測するモデルを作成し、
パラメータを渡して推論を実行、キャッシュにデータがあればそれを返します。

では各種環境をセットアップしましょう。

Momentoのセットアップ

最初はキャッシュ機能であるMomentoのセットアップです。
このあたりを参考に認証トークンを取得しましょう。
あとでjavascriptのprocess.envから値を取得するので、
コンソールでexportしておきます。

% export MOMENTO_AUTH_TOKEN=<認証トークン>

Tensorflow.jsでモデルの作成

次はTensorFlow.jsをつかって機械学習モデルを作成します。

Node用のTensorflowをインストールします。

% mkdir tensorflow-example && cd tensorflow-example
% npm install @tensorflow/tfjs-node --save

今回は単純な関数(y=2x-2)を学習させてモデルを作成します。
package.jsonに「"type": "module"」を追加して、
↓のmjsファイル(save-model.mjs)を作成します。

//save-model.mjs

import * as tf from "@tensorflow/tfjs";
import * as tfn from "@tensorflow/tfjs-node";

/** y = 2x - 2の関数を学習させる **/

// 入力データと期待する出力データをそれぞれtensorとして定義
const xs = tf.tensor([-1, 0, 1, 2, 3, 4]);
const ys = tf.tensor([-4, -2, 0, 2, 4, 6]);

// シーケンシャルモデル
const model = tf.sequential();

//レイヤー定義
model.add(
  tf.layers.dense({
    inputShape: 1, //値は1つ-入力tensor
    units: 1, //ニューロン1つ-出力tensor
  })
);

//オプティマイザと損失関数を設定
model.compile({
  optimizer: "sgd",
  loss: "meanSquaredError",
});

//モデル構造出力
model.summary();

//訓練開始(エポック数=300)
let history = await model.fit(xs, ys, { epochs: 300 });

//推論
let x = 10;
const inputTensor = tf.tensor([x]);
const answer = model.predict(inputTensor);
console.log(`x = ${x}, results = ${Math.round(answer.dataSync())}`);

//モデルの保存
await model.save('file://./my-function-model');

//cleanup
tf.dispose([xs, ys, model, answer, inputTensor]);

実行するとmy-function-modelディレクトリが作成されます。
このモデルを動的にLoadして推論処理を行います。

% node save-model.mjs

モデルをLoadする処理はこんな感じです。

//predict.mjs

import * as tf from "@tensorflow/tfjs";
import * as tfn from "@tensorflow/tfjs-node";

//さきほど保存したモデルをload
const model = await tf.loadLayersModel('file://./my-function-model/model.json');

//推論
let x = 5;
const inputTensor = tf.tensor([x]);
const answer = model.predict(inputTensor);
console.log(`x = ${x}, results = ${Math.round(answer.dataSync())}`);
% node predict.mjs
x = 5, results = 8

Implement inference cache

Expressで推論実行

次はExpressをつかって、先程のモデルをつかった推論APIを実装します。
最初に必要パッケージをインストールしましょう。

% mkdir ./express-app && cd express-app
% npm install express @tensorflow/tfjs-node @gomomento/sdk --save

さきほど保存したモデルのディレクトリをもってきます。

% cp -R /path/to/your/my-function-model .

下記内容でserver.jsファイルを作成します。

// server.js
const tf = require('@tensorflow/tfjs')
const tfn = require('@tensorflow/tfjs-node')

const express = require('express')
const app = express()
const port = 3000

var model = null;

//モデルをLoadする
async function load_model() {
  model = await tf.loadLayersModel(
    'file://my-function-model/model.json',
  );  
}

//推論を実行
function predict(x) {
  const inputTensor = tf.tensor([x]);
  const answer = model.predict(inputTensor);
  const return_value = Math.round(answer.dataSync());
  return return_value;
}

app.get('/predict', async (req, res) => {

  //loadしたモデルで推論を実施
  let x = Number(req.query.x);
  let return_value = predict(x);

  let msg = `x = ${x}, results = ${return_value}`;
  res.send(msg)
})

app.listen(port, async () => {
  await load_model();
  console.log(`Example app listening on port ${port}`);
});

作成したプログラムを実行してサーバ起動。

% node server.js
Example app listening on port 3000

http://localhost:3000/predict?x=5
みたいにアクセスすると、推論実行して結果をかえします。
この場合、毎回推論処理を実行することになります。

Momentoをつかって推論キャッシュする

ではMomentoで推論結果をキャッシュするようにしてみましょう。
server.jsを少し修正します。

キャッシュデータを登録する関数と取得する関数をそれぞれ作成します。
(set/getのレスポンス形式がv0.16で変更されました)

var client = new momento.SimpleCacheClient(process.env.MOMENTO_AUTH_TOKEN, 10);

async function setMomentoCache(cache_name,key,value) {
  console.log("setMomentoCache");
  const setResult = await client.set(cache_name, String(key),String(value));
  if (setResult instanceof momento.CacheSet.Success) {
    console.log(
      'Set Cache Data successfully with value ' + setResult.valueString()
    );
  } else if (setResult instanceof momento.CacheSet.Error) {
    console.log('Error Set key: ' + setResult.message());
  }
}

async function getMomentoCache(cache_name,key){
  const getResult = await client.get(cache_name, String(key));
  if (getResult instanceof momento.CacheGet.Hit) {
    console.log(`Cache hit: ${String(getResult.valueString())}`);
  } else if (getResult instanceof momento.CacheGet.Miss) {
    console.log('Cache miss');
    return null;
  } else if (getResult instanceof momento.CacheGet.Error) {
    console.log(`Error: ${getResult.message()}`);
    return null;
  }

  return getResult.valueString();
}

ここではクエリパラメータの値をそのままキーにしています。
もし複数のパラメータがあれば、
それらのパラメータを任意の記号(_とか)で
つないだ文字列をキーにすることになるかと思います。

次に/predictにアクセスされたときの処理を記述します。
クエリパラメータ(xの値)をキーにMomentoからキャッシュデータを取得します。
キャッシュがあればそれを返し、なければ推論を実行して結果をMomentoに保存します。

app.get('/predict', async (req, res) => {

  var msg = null;
  let x = Number(req.query.x);

  let cache_data = await getMomentoCache('default-cache',x);

  if(cache_data) {
    //キャッシュにデータがあればそれを返す
    msg = `x = ${x}, results = ${cache_data}`;
  } else {
    //キャッシュにデータがなければ推論してキャッシュに保存
    let return_value = predict(x);
    msg = `x = ${x}, results = ${return_value}`;

    await setMomentoCache('default-cache',x,return_value);
  }

  res.send(msg)
})

server.jsの全文はこちらです。

//server.js
const tf = require('@tensorflow/tfjs')
const tfn = require('@tensorflow/tfjs-node')
const momento = require('@gomomento/sdk');

const express = require('express')
const app = express()
const port = 3000

var model = null;
var client = new momento.SimpleCacheClient(process.env.MOMENTO_AUTH_TOKEN, 10);


app.get('/predict', async (req, res) => {

  var msg = null;
  let x = Number(req.query.x);

  let cache_data = await getMomentoCache('default-cache',x);

  if(cache_data) {
    //キャッシュにデータがあればそれを返す
    msg = `x = ${x}, results = ${cache_data}`;
  } else {
    //キャッシュにデータがなければ推論してキャッシュに保存
    let return_value = predict(x);
    msg = `x = ${x}, results = ${return_value}`;

    await setMomentoCache('default-cache',x,return_value);
  }

  res.send(msg)
})

app.listen(port, async () => {
  await load_model();
  console.log(`Example app listening on port ${port}`);
});


async function load_model() {
  model = await tf.loadLayersModel(
    'file://my-function-model/model.json',
  );  
}

function predict(x) {
  console.log("Execute Predict!");
  const inputTensor = tf.tensor([x]);
  const answer = model.predict(inputTensor);
  const return_value = Math.round(answer.dataSync());
  return return_value;
}

async function setMomentoCache(cache_name,key,value) {
  console.log("setMomentoCache");
  const setResult = await client.set(cache_name, String(key),String(value));
  //console.log(`setResult:${setResult.text()}`);
  if (setResult instanceof momento.CacheSet.Success) {
    console.log(
      'Set Cache Data successfully with value ' + setResult.valueString()
    );
  } else if (setResult instanceof momento.CacheSet.Error) {
    console.log('Error Set key: ' + setResult.message());
  }
}

async function getMomentoCache(cache_name,key){
  const getResult = await client.get(cache_name, String(key));
  if (getResult instanceof momento.CacheGet.Hit) {
    console.log(`Cache hit: ${String(getResult.valueString())}`);
  } else if (getResult instanceof momento.CacheGet.Miss) {
    console.log('Cache miss');
    return null;
  } else if (getResult instanceof momento.CacheGet.Error) {
    console.log(`Error: ${getResult.message()}`);
    return null;
  }

  return getResult.valueString();
}

再度server.jsを実行してアクセスしてみましょう。
キャッシュがあれば無駄な推論処理をせずにキャッシュデータを返しています。

% node server.js
Example app listening on port 3000

% curl http://localhost:3000/predict?x=5

============= log ===============
Cache miss  #キャッシュがみつからない
Execute Predict! #推論実行
setMomentoCache
Set Cache Data successfully with value 8
=================================

% curl http://localhost:3000/predict?x=5

============= log ===============
Cache hit: 8  #キャッシュがみつかったのでそのまま返す
=================================

Summary

今回はMomentoをつかって推論キャッシュを実装してみました。
推論はけっこうCPU負荷がかかるので、
キャッシュをつかって効率よく処理しましょう。

この記事でキャッシュに使用したMomentoについてのお問い合わせはこちらです。
お気軽にお問い合わせください。

Momentoセミナーのお知らせ

2023年1月13日(金) 16:00からMomentoのセミナーを開催します。
興味があるかたはぜひご参加ください。

References