Amazon SageMakerで作成したFactorization Machinesの分類モデルをローカルで推論させる

SageMakerには組み込みアルゴリズムの1つとしてFactorization Machines(FM)が用意されています。

今回はSageMakerで学習させたFMモデルを使ったローカルでの推論を試してみたので、その内容を紹介します。

やってみる

今回の焦点はローカルでのFMモデルの推論ですので、SageMakerでのモデルの学習部分は割愛します。 データセットはMNISTを使用し、モデルは以下のノートブックに則ってSageMakerで作成したFMの二値分類モデルを使用します。

事前準備

ライブラリを読み込み、モデルアーティファクトを保存している場所を定義しておきます。

import boto3
import json
import pickle
import gzip
import urllib.request
import mxnet as mx
import numpy as np
from os import path

session = boto3.Session()

bucket = 'hogehoge'
model_path = 'sagemaker/DEMO-fm-mnist/output/factorization-machines-2019-08-15-09-38-15-958/output/model.tar.gz'
model_file_name = path.basename(model_path)

データ準備

MNISTの手書き数字のデータセットをダウンロードし、展開します。 データセットには学習と検証、テスト用がありますが、今回は検証用を使って進めます。
元々は0~9の数字のどれかを予測する分類問題ですが、今回のモデルは0かどうかの二値分類のモデルなので、ラベルを書き換えます。

urllib.request.urlretrieve("http://deeplearning.net/data/mnist/mnist.pkl.gz", "mnist.pkl.gz")
with gzip.open('mnist.pkl.gz', 'rb') as f:
    train_set, valid_set, test_set = pickle.load(f, encoding='latin1')
vectors, labels = valid_set

# 0かどうかの二値分類にする
labels = (labels == 0).astype('float32')

モデルのダウンロードと展開

S3にあるモデルアーティファクトをダウンロード&展開し、MXNetで読み込める形に名前を変更します。

# FMモデルをダウンロードして展開
session.client('s3').download_file(bucket, model_path, model_file_name)
os.system(f'tar xzvf {model_file_name}')
os.system(f'unzip -o model_algo-1')
os.system(f'mv symbol.json model-symbol.json')
os.system(f'mv params model-0000.params')

モデルの読み込み

モデルのメタ情報がjsonファイルで提供されているので、読み込んでみます。

with open('meta.json', 'r') as f:
    meta = json.load(f)
print(meta)

出力

{'label_names': ['out_label'], 'training_parameters': {'factors_lr': '0.0001', 'linear_init_sigma': '0.01', 'epochs': 1, 'feature_dim': '784', 'num_factors': '10', '_wd': '1.0', '_num_kv_servers': 'auto', 'use_bias': 'true', 'factors_init_sigma': '0.001', '_log_level': 'info', 'bias_init_method': 'normal', 'linear_init_method': 'normal', 'linear_lr': '0.001', 'factors_init_method': 'normal', '_tuning_objective_metric': '', 'bias_wd': '0.01', 'use_linear': 'true', 'bias_lr': '0.1', 'mini_batch_size': '200', '_use_full_symbolic': 'true', 'batch_metrics_publish_interval': '500', 'predictor_type': 'binary_classifier', 'bias_init_sigma': '0.01', '_num_gpus': 'auto', '_data_format': 'record', 'factors_wd': '0.00001', 'linear_wd': '0.001', '_kvstore': 'auto', '_learning_rate': '1.0', '_optimizer': 'adam'}, 'version': 1, 'epoch_number': 1}

学習時のパラメータなどを確認することができます。

次に、MXNetでモデルを読み込みます。

m = mx.module.Module.load('./model', epoch=0, label_names=meta['label_names'])

データをMXNetで扱いやすい形に変換し、そのデータ形式をモデルに設定します。

# MXNetの対応するデータに変換
validation_iter = mx.io.NDArrayIter(vectors, labels, label_name=meta['label_names'][0])

# データ形式を設定
m.bind(data_shapes=validation_iter.provide_data, label_shapes=validation_iter.provide_label)

推論

検証用のデータを推論し、推論結果をndarray形式に変換したものを表示します。

pred = m.predict(eval_data=validation_iter)
print(pred.asnumpy())

出力

[[0.]
 [0.]
 [0.]
 ...
 [0.]
 [0.]
 [0.]]

※ 0が並んでいますが、出力は0から1の実数です。

さいごに

SageMakerで学習させたFMモデルを使ってローカルで推論する方法について紹介しました。誰かの役に立てば嬉しいです。

参考