Amazon SageMakerでFactorization Machinesを使ってみる

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

こんにちは、小澤です。

当エントリではAmazon SageMakerの組み込みアルゴリズムの1つである「Factorization Machines」についての解説を書かせていただきます。

目次

Factorization Machinesとは

Factorization Machinesはレコメンドシステムでよく利用されるアルゴリズムです。 レコメンドシステムはよく、「あるユーザAと購入傾向を持つ別なユーザBにユーザAが買った商品をオススメする」や「ある商品Xと商品Yを購入する人の傾向が似ているから商品Aを買った人に商品Bをオススメする」など、ユーザベースだったりアイテムベースだったりといった説明で解説されます。

Factorization Machinesではこういったユーザや商品などの他に例えば購入した時間帯だったりとさらに他の要素を加えることができます。 これら様々な要素を"交互作用"として取り入れることが可能な仕組みとなっているので、複数の要素間の組み合わせによる影響を加味したモデルを作成できます。

(Factorization Machines with libFMより)

Exampleを実行してみる

SageMakerのExamplesにあるノートブックでどのような実装が必要になるのか見ていきましょう。 先ほどはFactorization Machinesは"レコメンドでよく利用する"と書きましたが、今回利用するExampleは二値分類を行うもののようです。

利用するExample

今回利用するExampleはノートブックの上部メニューから SageMaker Examples > Introduction to Amazon algorithms > factorization_machines_mnist.ipynb とたどったものになります。

データの準備

では、ここから実際にExampleを動かしてみましょう。

最初の方でやっていることは、いつもの流れです。

まずは、S3のバケットの設定やIAMロールの取得を行います。

bucket = '<your_s3_bucket_name_here>'
prefix = 'sagemaker/DEMO-fm-mnist'
 
# Define IAM role
import boto3
import re
from sagemaker import get_execution_role

role = get_execution_role()

bucketは利用する環境に合わせて書き換えてください。 また、データやモデルの保存先を変更したい場合はprefixも合わせて書き換えてください。

続いては、データのダウンロードを行なっています。

%%time
import pickle, gzip, numpy, urllib.request, json

# Load the dataset
urllib.request.urlretrieve("http://deeplearning.net/data/mnist/mnist.pkl.gz", "mnist.pkl.gz")
with gzip.open('mnist.pkl.gz', 'rb') as f:
    train_set, valid_set, test_set = pickle.load(f, encoding='latin1')

利用するデータはおなじみの手書き数字認識に利用されるmnistです。 mnistは手書きで書かれた0から9の数字の画像のデータセットです。 各画像は28x28ピクセルのグレースケールで表現される画像となっています。

実際どういった画像なのかは、次のコードで確認しています。

%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (2,10)


def show_digit(img, caption='', subplot=None):
    if subplot==None:
        _,(subplot)=plt.subplots(1,1)
    imgr=img.reshape((28,28))
    subplot.axis('off')
    subplot.imshow(imgr, cmap='gray')
    plt.title(caption)

show_digit(train_set[0][30], 'This is a {}'.format(train_set[1][30]))

実行結果として以下のような画像が表示されます。

今回行うのは二値分類となります。 mnistのデータに対して、書かれている数字が"0"と"それ以外"になるように正解ラベルを変更したのち、それをRecordIO形式に変換します

import io
import numpy as np
import sagemaker.amazon.common as smac

vectors = np.array([t.tolist() for t in train_set[0]]).astype('float32')
labels = np.where(np.array([t.tolist() for t in train_set[1]]) == 0, 1.0, 0.0).astype('float32')

buf = io.BytesIO()
smac.write_numpy_to_dense_tensor(buf, vectors, labels)
buf.seek(0)

あとはこのデータを学習の際に利用するためにS3に出力して前処理は完了です。

import boto3
import os

key = 'recordio-pb-data'
boto3.resource('s3').Bucket(bucket).Object(os.path.join(prefix, 'train', key)).upload_fileobj(buf)
s3_train_data = 's3://{}/{}/train/{}'.format(bucket, prefix, key)
print('uploaded training data location: {}'.format(s3_train_data))

学習の実行

では、いよいよ学習を行います。

まずは学習に利用するコンテナの設定を行います。

from sagemaker.amazon.amazon_estimator import get_image_uri
container = get_image_uri(boto3.Session().region_name, 'factorization-machines')

このget_image_uri関数、実は今まで私も知らなかったのですが、リージョンと手法を指定すればコンテナイメージのパスが取得できるようです。 今までドキュメントのパラメータリストを参照していたので、これでだいぶ楽ができそうですw

これによって以下のような文字列が取得できます。

'351501993468.dkr.ecr.ap-northeast-1.amazonaws.com/factorization-machines:1'

バージョンが"latest"(最新版)ではではなく"1"(安定版)になっている点だけ気にかけておけば良さそうです。

学習の処理はいつものように Estimatorインスタンスの作成 -> ハイパーパラメータの設定 -> fit関数の呼び出し という流れになります。

import boto3
import sagemaker

sess = sagemaker.Session()

fm = sagemaker.estimator.Estimator(container,
                                   role, 
                                   train_instance_count=1, 
                                   train_instance_type='ml.c4.xlarge',
                                   output_path=output_location,
                                   sagemaker_session=sess)
fm.set_hyperparameters(feature_dim=784,
                      predictor_type='binary_classifier',
                      mini_batch_size=200,
                      num_factors=10)

fm.fit({'train': s3_train_data})

設定しているパラメータの値などの詳細はドキュメントをご確認ください。

エンドポイント作成と評価

学習の処理が完了したのちはまずエンドポイントの作成を行います。 これも記述する内容は他の手法と同じです。 使い方が統一されているので、すごく便利です。

fm_predictor = fm.deploy(initial_instance_count=1,
                         instance_type='ml.m4.xlarge')

このエンドポイントに対してデータを投げて予測結果を見てみましょう。 まずは、レスポンスのフォーマット指定やシリアライザに関する設定を行なっています。

import json
from sagemaker.predictor import json_deserializer

def fm_serializer(data):
    js = {'instances': []}
    for row in data:
        js['instances'].append({'features': row.tolist()})
    return json.dumps(js)

fm_predictor.content_type = 'application/json'
fm_predictor.serializer = fm_serializer
fm_predictor.deserializer = json_deserializer

続いて、1件データを投げて見てどのような結果が返ってくるか確認します。

import numpy as np

predictions = []
for array in np.array_split(test_set[0], 100):
    result = fm_predictor.predict(array)
    predictions += [r['predicted_label'] for r in result['predictions']]

predictions = np.array(predictions)

以下のような結果が返ってきます。

{'predictions': [{'score': 0.0, 'predicted_label': 0.0}]}

ちょっとまぎらわしいですが、書いてある数字が0であれば1、それ以外であれば0となるように学習しているので、 先ほど確認した際に3だったこのデータは正しく予測されているようです。

では、テストデータを使ってどのくらいの予測性能がでているのか見てみましょう。 今回は、Confusion Matrixを作成しています。

import pandas as pd

pd.crosstab(np.where(test_set[1] == 0, 1, 0), predictions, rownames=['actuals'], colnames=['predictions'])

以下のような結果が得られます。

横軸が予測値、縦軸が正解の値となっています。 クロスした部分がそれぞれのカウントです。

例えば、正解が0で予測も0だったものは8855件ということになります。 非対角成分の件数が少なければ少ないほどいい結果ということになります。

最後にエンドポイントを削除しておきましょう。

import sagemaker

sagemaker.Session().delete_endpoint(fm_predictor.endpoint)

おわりに

今回はSageMakerの組み込みアルゴリズムのうち、Factorization MachinesについてExampleの流れを追ってみました。 対象が簡単なものだったため、他のアルゴリズムと比較して何が優れているのかこれだけはピンとこなかった方もいるかもしれません。 ぜひ、皆さんの持ってるデータでも試してみてください。