AmazonSageMakerのXGBoostでMNISTの手書き文字を分類してみた

AmazonSageMakerのXGBoostでMNISTの手書き文字を分類してみた

Clock Icon2018.08.13

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

こんにちは、大澤です。

当エントリではAmazon SageMakerの組み込みアルゴリズムの1つ、「XGBoost」を用いた分類方法についてご紹介していきたいと思います。

「組み込みアルゴリズム」の解説については下記エントリをご参照ください。

目次

概要説明:XGBoostとは

XGBoostとは勾配ブースティングツリーと言う理論のオープンソースの実装で、分類や回帰に使われています。 当エントリは分類を使った例について紹介します。

組み込みアルゴリズム:XGBoostの実践

Amazon SageMaker ExamplesのXGBoostの例を参考にして進めていきます。 基本的には同じ内容ですが、例ではboto3のSageMaker用低レベルAPIを使用しているのに対して、当エントリではSageMakerのPython用SDKを使ったものとなります。

定番のMNISTの手書き数字のデータを元に、書かれている数字の分類するという内容です。

ノートブックの作成

SageMakerのノートブックインスタンスを立ち上げて、 SageMaker Examples ↓ Introduction to Amazon algorithms ↓ xgboost_mnist.ipynb ↓ use でサンプルからノートブックをコピーして、開きます。 ノートブックインスタンスの作成についてはこちらをご参照ください。

環境変数とロールの確認

学習データ等を保存するS3のバケット名と保存オブジェクト名の接頭辞を決めます。

%%time

import os
import boto3
import re
import copy
import time
from time import gmtime, strftime
from sagemaker import get_execution_role

role = get_execution_role()

region = boto3.Session().region_name

bucket='<your_s3_bucket_name_here>' # put your s3 bucket name here, and create s3 bucket
prefix = 'sagemaker/DEMO-xgboost-multiclass-classification'
# customize to your bucket where you have stored the data
bucket_path = 'https://s3-{}.amazonaws.com/{}'.format(region,bucket)

データ取得

MNISTのデータを取得してきます。

%%time
import pickle, gzip, numpy, urllib.request, json

# Load the dataset
urllib.request.urlretrieve("http://deeplearning.net/data/mnist/mnist.pkl.gz", "mnist.pkl.gz")
f = gzip.open('mnist.pkl.gz', 'rb')
train_set, valid_set, test_set = pickle.load(f, encoding='latin1')
f.close()

MNISTのデータをLIBSVM形式に加工して、S3に保存します。 まずはそのための関数を定義します。

%%time

import struct
import io
import boto3

 
def to_libsvm(f, labels, values):
     f.write(bytes('\n'.join(
         ['{} {}'.format(label, ' '.join(['{}:{}'.format(i + 1, el) for i, el in enumerate(vec)])) for label, vec in
          zip(labels, values)]), 'utf-8'))
     return f


def write_to_s3(fobj, bucket, key):
    return boto3.Session().resource('s3').Bucket(bucket).Object(key).upload_fileobj(fobj)

def get_dataset():
  import pickle
  import gzip
  with gzip.open('mnist.pkl.gz', 'rb') as f:
      u = pickle._Unpickler(f)
      u.encoding = 'latin1'
      return u.load()

def upload_to_s3(partition_name, partition):
    labels = [t.tolist() for t in partition[1]]
    vectors = [t.tolist() for t in partition[0]]
    num_partition = 5                                 # partition file into 5 parts
    partition_bound = int(len(labels)/num_partition)
    for i in range(num_partition):
        f = io.BytesIO()
        to_libsvm(f, labels[i*partition_bound:(i+1)*partition_bound], vectors[i*partition_bound:(i+1)*partition_bound])
        f.seek(0)
        key = "{}/{}/examples{}".format(prefix,partition_name,str(i))
        url = 's3n://{}/{}'.format(bucket, key)
        print('Writing to {}'.format(url))
        write_to_s3(f, bucket, key)
        print('Done writing to {}'.format(url))

def download_from_s3(partition_name, number, filename):
    key = "{}/{}/examples{}".format(prefix,partition_name, number)
    url = 's3n://{}/{}'.format(bucket, key)
    print('Reading from {}'.format(url))
    s3 = boto3.resource('s3')
    s3.Bucket(bucket).download_file(key, filename)
    try:
        s3.Bucket(bucket).download_file(key, 'mnist.local.test')
    except botocore.exceptions.ClientError as e:
        if e.response['Error']['Code'] == "404":
            print('The object does not exist at {}.'.format(url))
        else:
            raise        
        
def convert_data():
    train_set, valid_set, test_set = get_dataset()
    partitions = [('train', train_set), ('validation', valid_set), ('test', test_set)]
    for partition_name, partition in partitions:
        print('{}: {} {}'.format(partition_name, partition[0].shape, partition[1].shape))
        upload_to_s3(partition_name, partition)

先ほど定義した関数を使って、加工・保存処理を実行します。

%%time

convert_data()

学習

XGBoost用のコンテナの名前を取得します。

from sagemaker.amazon.amazon_estimator import get_image_uri
container = get_image_uri(boto3.Session().region_name, 'xgboost')

ハイパーパラメータや出力先、学習用コンテナイメージ等の学習に必要な設定を行い、学習処理を実行します。 ハイパーパラメータに関する詳細はドキュメントをご確認ください。

import boto3
import sagemaker
 
sess = sagemaker.Session()

 
xg = sagemaker.estimator.Estimator(container,
                                   role, 
                                   train_instance_count=1, 
                                   train_instance_type='ml.m4.4xlarge',
                                   output_path=bucket_path + "/"+ prefix + "/xgboost",
                                   train_max_run = 900,
                                   sagemaker_session=sess)
xg.set_hyperparameters(max_depth=5,
        eta=0.2,
        gamma=4,
        min_child_weight=6,
        silent=0,
        objective= "multi:softmax",
        num_class=10,
        num_round=10)

train_data = sagemaker.session.s3_input( bucket_path + "/"+ prefix+ '/train/', distribution='FullyReplicated', 
                        content_type='libsvm', s3_data_type='S3Prefix')
validation_data = sagemaker.session.s3_input( bucket_path + "/"+ prefix+ '/validation/', distribution='FullyReplicated', 
                        content_type='libsvm', s3_data_type='S3Prefix')

data_channels = {'train': train_data, 'validation': validation_data}
 
xg.fit(inputs=data_channels, logs=True)

モデルの展開

エンドポイントを作成し、先ほど学習したモデルを展開します。

xg_predictor = xg.deploy(initial_instance_count=1,
                           instance_type='ml.m4.xlarge')

モデルの確認

エンドポイントへのリクエスト形式を指定します。

xg_predictor.content_type = 'text/x-libsvm'

テスト用にデータをダウンロードし、そこから確認用に1行だけ書き出します。

download_from_s3('test', 0, 'mnist.local.test') # reading the first part file within test
!head -1 mnist.local.test > mnist.single.test

エンドポイントにデータを投げて予測結果を受け取って表示してみます。

%%time
import json

file_name = 'mnist.single.test' #customize to your test file 'mnist.single.test' if use the data above

with open(file_name, 'r') as f:
    payload = f.read()

result = xg_predictor.predict(payload).decode('utf-8')
print('Predicted label is {}.'.format(result))

分類ができてる感じがしますね。

次はテストデータ全体を分類してみます。 まずは関数定義。

import sys
def do_predict(data):
    payload = '\n'.join(data)
    result = xg_predictor.predict(payload).decode('utf-8')
    preds = [float(num) for num in result.split(',')]
    return preds

def batch_predict(data, batch_size):
    items = len(data)
    arrs = []
    for offset in range(0, items, batch_size):
        arrs.extend(do_predict(data[offset:min(offset+batch_size, items)]))
        sys.stdout.write('.')
    return(arrs)

定義した関数を使って分類処理を実行し、誤差率を表示します。

%%time
import json

file_name = 'mnist.local.test'
with open(file_name, 'r') as f:
    payload = f.read().strip()

labels = [float(line.split(' ')[0]) for line in payload.split('\n')]
test_data = payload.split('\n')
preds = batch_predict(test_data, 100)

print ('\nerror rate=%f' % ( sum(1 for i in range(len(preds)) if preds[i]!=labels[i]) /float(len(preds))))

誤差率が約11%とそれなりの結果が出ました。 分類したデータを俯瞰するために分類結果と正しいラベルの混同行列を表示してみます。

import numpy
def error_rate(predictions, labels):
    """Return the error rate and confusions."""
    correct = numpy.sum(predictions == labels)
    total = predictions.shape[0]

    error = 100.0 - (100 * float(correct) / float(total))

    confusions = numpy.zeros([10, 10], numpy.int32)
    bundled = zip(predictions, labels)
    for predicted, actual in bundled:
        confusions[int(predicted), int(actual)] += 1
    
    return error, confusions
import matplotlib.pyplot as plt
%matplotlib inline  

NUM_LABELS = 10  # change it according to num_class in your dataset
test_error, confusions = error_rate(numpy.asarray(preds), numpy.asarray(labels))
print('Test error: %.1f%%' % test_error)

plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.grid(False)
plt.xticks(numpy.arange(NUM_LABELS))
plt.yticks(numpy.arange(NUM_LABELS))
plt.imshow(confusions, cmap=plt.cm.jet, interpolation='nearest');

for i, cas in enumerate(confusions):
    for j, count in enumerate(cas):
        if count > 0:
            xoff = .07 * len(str(count))
            plt.text(j-xoff, i+.2, int(count), fontsize=9, color='white')

ちらほら外れたものもありますが、大部分は対角線上に分布していますね。 やはり形が似ている4と9みたいな似ているものはやはり分類も難しいんでしょうか、他と比べて誤分類が目立ってますね。

エンドポイントの削除

余分なお金を使わないように、エンドポイントを削除します。

import sagemaker

sagemaker.Session().delete_endpoint(xg_predictor.endpoint)

まとめ

Amazon SageMakerの組み込みアルゴリズムの一つであるXGBoostの分類モデルを用いることで、MNISTの手書き文字を分類することができました。

分類ができるアルゴリズムは色々ありますが、XGBoostはその中でも特に人気が高く、有効な手法です。しかし、チューニングすべきパラメータが多く、使いこなすのは少々難しいかもしれません。これを機に色々な例に対して試してみたいですね。

以下シリーズではAmazon SageMakerのその他の組み込みアルゴリズムについても解説しています。宜しければ御覧ください。

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.