AmazonSageMakerのXGBoostでMNISTの手書き文字を分類してみた
こんにちは、大澤です。
当エントリではAmazon SageMakerの組み込みアルゴリズムの1つ、「XGBoost」を用いた分類方法についてご紹介していきたいと思います。
「組み込みアルゴリズム」の解説については下記エントリをご参照ください。
目次
概要説明:XGBoostとは
XGBoostとは勾配ブースティングツリーと言う理論のオープンソースの実装で、分類や回帰に使われています。 当エントリは分類を使った例について紹介します。
組み込みアルゴリズム:XGBoostの実践
Amazon SageMaker ExamplesのXGBoostの例を参考にして進めていきます。 基本的には同じ内容ですが、例ではboto3のSageMaker用低レベルAPIを使用しているのに対して、当エントリではSageMakerのPython用SDKを使ったものとなります。
定番のMNISTの手書き数字のデータを元に、書かれている数字の分類するという内容です。
ノートブックの作成
SageMakerのノートブックインスタンスを立ち上げて、 SageMaker Examples ↓ Introduction to Amazon algorithms ↓ xgboost_mnist.ipynb ↓ use でサンプルからノートブックをコピーして、開きます。 ノートブックインスタンスの作成についてはこちらをご参照ください。
環境変数とロールの確認
学習データ等を保存するS3のバケット名と保存オブジェクト名の接頭辞を決めます。
%%time import os import boto3 import re import copy import time from time import gmtime, strftime from sagemaker import get_execution_role role = get_execution_role() region = boto3.Session().region_name bucket='<your_s3_bucket_name_here>' # put your s3 bucket name here, and create s3 bucket prefix = 'sagemaker/DEMO-xgboost-multiclass-classification' # customize to your bucket where you have stored the data bucket_path = 'https://s3-{}.amazonaws.com/{}'.format(region,bucket)
データ取得
MNISTのデータを取得してきます。
%%time import pickle, gzip, numpy, urllib.request, json # Load the dataset urllib.request.urlretrieve("http://deeplearning.net/data/mnist/mnist.pkl.gz", "mnist.pkl.gz") f = gzip.open('mnist.pkl.gz', 'rb') train_set, valid_set, test_set = pickle.load(f, encoding='latin1') f.close()
MNISTのデータをLIBSVM形式に加工して、S3に保存します。 まずはそのための関数を定義します。
%%time import struct import io import boto3 def to_libsvm(f, labels, values): f.write(bytes('\n'.join( ['{} {}'.format(label, ' '.join(['{}:{}'.format(i + 1, el) for i, el in enumerate(vec)])) for label, vec in zip(labels, values)]), 'utf-8')) return f def write_to_s3(fobj, bucket, key): return boto3.Session().resource('s3').Bucket(bucket).Object(key).upload_fileobj(fobj) def get_dataset(): import pickle import gzip with gzip.open('mnist.pkl.gz', 'rb') as f: u = pickle._Unpickler(f) u.encoding = 'latin1' return u.load() def upload_to_s3(partition_name, partition): labels = [t.tolist() for t in partition[1]] vectors = [t.tolist() for t in partition[0]] num_partition = 5 # partition file into 5 parts partition_bound = int(len(labels)/num_partition) for i in range(num_partition): f = io.BytesIO() to_libsvm(f, labels[i*partition_bound:(i+1)*partition_bound], vectors[i*partition_bound:(i+1)*partition_bound]) f.seek(0) key = "{}/{}/examples{}".format(prefix,partition_name,str(i)) url = 's3n://{}/{}'.format(bucket, key) print('Writing to {}'.format(url)) write_to_s3(f, bucket, key) print('Done writing to {}'.format(url)) def download_from_s3(partition_name, number, filename): key = "{}/{}/examples{}".format(prefix,partition_name, number) url = 's3n://{}/{}'.format(bucket, key) print('Reading from {}'.format(url)) s3 = boto3.resource('s3') s3.Bucket(bucket).download_file(key, filename) try: s3.Bucket(bucket).download_file(key, 'mnist.local.test') except botocore.exceptions.ClientError as e: if e.response['Error']['Code'] == "404": print('The object does not exist at {}.'.format(url)) else: raise def convert_data(): train_set, valid_set, test_set = get_dataset() partitions = [('train', train_set), ('validation', valid_set), ('test', test_set)] for partition_name, partition in partitions: print('{}: {} {}'.format(partition_name, partition[0].shape, partition[1].shape)) upload_to_s3(partition_name, partition)
先ほど定義した関数を使って、加工・保存処理を実行します。
%%time convert_data()
学習
XGBoost用のコンテナの名前を取得します。
from sagemaker.amazon.amazon_estimator import get_image_uri container = get_image_uri(boto3.Session().region_name, 'xgboost')
ハイパーパラメータや出力先、学習用コンテナイメージ等の学習に必要な設定を行い、学習処理を実行します。 ハイパーパラメータに関する詳細はドキュメントをご確認ください。
import boto3 import sagemaker sess = sagemaker.Session() xg = sagemaker.estimator.Estimator(container, role, train_instance_count=1, train_instance_type='ml.m4.4xlarge', output_path=bucket_path + "/"+ prefix + "/xgboost", train_max_run = 900, sagemaker_session=sess) xg.set_hyperparameters(max_depth=5, eta=0.2, gamma=4, min_child_weight=6, silent=0, objective= "multi:softmax", num_class=10, num_round=10) train_data = sagemaker.session.s3_input( bucket_path + "/"+ prefix+ '/train/', distribution='FullyReplicated', content_type='libsvm', s3_data_type='S3Prefix') validation_data = sagemaker.session.s3_input( bucket_path + "/"+ prefix+ '/validation/', distribution='FullyReplicated', content_type='libsvm', s3_data_type='S3Prefix') data_channels = {'train': train_data, 'validation': validation_data} xg.fit(inputs=data_channels, logs=True)
モデルの展開
エンドポイントを作成し、先ほど学習したモデルを展開します。
xg_predictor = xg.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')
モデルの確認
エンドポイントへのリクエスト形式を指定します。
xg_predictor.content_type = 'text/x-libsvm'
テスト用にデータをダウンロードし、そこから確認用に1行だけ書き出します。
download_from_s3('test', 0, 'mnist.local.test') # reading the first part file within test !head -1 mnist.local.test > mnist.single.test
エンドポイントにデータを投げて予測結果を受け取って表示してみます。
%%time import json file_name = 'mnist.single.test' #customize to your test file 'mnist.single.test' if use the data above with open(file_name, 'r') as f: payload = f.read() result = xg_predictor.predict(payload).decode('utf-8') print('Predicted label is {}.'.format(result))
分類ができてる感じがしますね。
次はテストデータ全体を分類してみます。 まずは関数定義。
import sys def do_predict(data): payload = '\n'.join(data) result = xg_predictor.predict(payload).decode('utf-8') preds = [float(num) for num in result.split(',')] return preds def batch_predict(data, batch_size): items = len(data) arrs = [] for offset in range(0, items, batch_size): arrs.extend(do_predict(data[offset:min(offset+batch_size, items)])) sys.stdout.write('.') return(arrs)
定義した関数を使って分類処理を実行し、誤差率を表示します。
%%time import json file_name = 'mnist.local.test' with open(file_name, 'r') as f: payload = f.read().strip() labels = [float(line.split(' ')[0]) for line in payload.split('\n')] test_data = payload.split('\n') preds = batch_predict(test_data, 100) print ('\nerror rate=%f' % ( sum(1 for i in range(len(preds)) if preds[i]!=labels[i]) /float(len(preds))))
誤差率が約11%とそれなりの結果が出ました。 分類したデータを俯瞰するために分類結果と正しいラベルの混同行列を表示してみます。
import numpy def error_rate(predictions, labels): """Return the error rate and confusions.""" correct = numpy.sum(predictions == labels) total = predictions.shape[0] error = 100.0 - (100 * float(correct) / float(total)) confusions = numpy.zeros([10, 10], numpy.int32) bundled = zip(predictions, labels) for predicted, actual in bundled: confusions[int(predicted), int(actual)] += 1 return error, confusions
import matplotlib.pyplot as plt %matplotlib inline NUM_LABELS = 10 # change it according to num_class in your dataset test_error, confusions = error_rate(numpy.asarray(preds), numpy.asarray(labels)) print('Test error: %.1f%%' % test_error) plt.xlabel('Actual') plt.ylabel('Predicted') plt.grid(False) plt.xticks(numpy.arange(NUM_LABELS)) plt.yticks(numpy.arange(NUM_LABELS)) plt.imshow(confusions, cmap=plt.cm.jet, interpolation='nearest'); for i, cas in enumerate(confusions): for j, count in enumerate(cas): if count > 0: xoff = .07 * len(str(count)) plt.text(j-xoff, i+.2, int(count), fontsize=9, color='white')
ちらほら外れたものもありますが、大部分は対角線上に分布していますね。 やはり形が似ている4と9みたいな似ているものはやはり分類も難しいんでしょうか、他と比べて誤分類が目立ってますね。
エンドポイントの削除
余分なお金を使わないように、エンドポイントを削除します。
import sagemaker sagemaker.Session().delete_endpoint(xg_predictor.endpoint)
まとめ
Amazon SageMakerの組み込みアルゴリズムの一つであるXGBoostの分類モデルを用いることで、MNISTの手書き文字を分類することができました。
分類ができるアルゴリズムは色々ありますが、XGBoostはその中でも特に人気が高く、有効な手法です。しかし、チューニングすべきパラメータが多く、使いこなすのは少々難しいかもしれません。これを機に色々な例に対して試してみたいですね。
以下シリーズではAmazon SageMakerのその他の組み込みアルゴリズムについても解説しています。宜しければ御覧ください。