Amazon SageMakerでXGBoostをフレームワークとして使ってみた – 機械学習 on AWS Advent Calendar 2019

どうも、DA事業本部の大澤です。

当エントリは『機械学習 on AWS Advent Calendar 2019』の2日目です。

今回は「Amazon SageMakerでXGBoostをフレームワークとして使ってみた」についてご紹介します。

やってみる

準備

使用するライブラリを読み込み、パラメータ等を定義しておきます。S3のバケット名とや接頭辞、IAMロールについては環境に応じて変更してください。

from sklearn import datasets, model_selection
import sagemaker
from datetime import datetime
from os import path
import xgboost as xgb
import tarfile
import pickle
import pandas as pd
from sagemaker.predictor import csv_serializer
sm_session = sagemaker.Session()

# データ等の保存場所
bucket = 'バケット名'
prefix = 'xgboost_framework/'+datetime.now().strftime('%Y%m%d-%H%M%S')
s3_prefix = f's3://{bucket}/{prefix}'

# 学習時にSageMakerが使用するIAMロール
role = 'arn:aws:iam::accountid:role/service-role/AmazonSageMaker-ExecutionRole'
# role = sagemaker.get_execution_role()

データ作成

scikit-learnのデータセットからirisデータを読み込み、学習用、検証用、テスト用に分割します。

iris = datasets.load_iris()
df = pd .DataFrame(iris.data, columns=iris.feature_names)
df['target'] = iris.target
train, test = model_selection.train_test_split(df, test_size=0.2)
train, validation = model_selection.train_test_split(train, test_size=0.3)
df

CSV形式で保存し、S3にアップロードします。

train.to_csv('train.csv', index=False)
validation.to_csv('validation.csv', index=False)
test.to_csv('test.csv', index=False)

train_path = sm_session.upload_data(path='train.csv', bucket=bucket, key_prefix=prefix)
validation_path = sm_session.upload_data(path='validation.csv', bucket=bucket, key_prefix=prefix)
test_path = sm_session.upload_data(path='test.csv', bucket=bucket, key_prefix=prefix)

学習/推論用スクリプト

SageMakerでの学習と推論で用いるスクリプトファイルを作成します。

学習時にはハイパラやデータの情報などを引数としてスクリプトが実行されます。今回は指定された場所にあるCSV形式のデータを読み込み、XGBoostで学習し、テストデータを評価し、モデルをpickle化して保存するといった処理です。 推論時にはこのスクリプトがモジュールとして扱われ、model_fninput_fnなどといった関数が呼び出されます。未定義の場合にはデフォルトとして定義されたものが使用されます。今回は未定義の場合にエラーとなるmodel_fnだけ定義しています。

エントリポイントに関する詳しい説明についてはドキュメントをご覧ください。XGBoostの項目は今の所なさそうでした。scikit-learnを使用する場合とほぼ同様なので、そちらのドキュメントを参考になります。

import xgboost as xgb
import argparse
import os
from os import path
import logging
import pickle
import json
import pandas as pd
import glob

logger = logging.getLogger(__name__)

def _read_csvs_to_dmatrix(dir_path, label='target'):
    dfs = []
    for file in glob.glob(path.join(dir_path, '*.csv')):
        dfs.append(pd.read_csv(file))
    df = pd.concat(dfs, axis=0, ignore_index=True)
    return xgb.DMatrix(df.drop(label, axis=1), df[label])

def main(args):

    params = {'objective': 'multi:softmax', 'eval_metric': ['mlogloss', 'merror'], 'num_class': 3}
    if args.num_round:
        params['num_round'] = args.num_round
    if args.max_depth:
        params['max_depth'] = args.max_depth
    if args.eta:
        params['eta'] = args.eta

    dtrain = _read_csvs_to_dmatrix(args.train)

    evallist = [(dtrain, 'train')]


    if args.validation:
        dvalidation = _read_csvs_to_dmatrix(args.validation)
        evallist.append((dvalidation, 'validation'))

    bst = xgb.train(params, dtrain, args.num_round, evallist)


    if args.test:
        dtest = _read_csvs_to_dmatrix(args.test)
        logger.info(bst.eval(dtest, 'test'))

    with open(path.join(args.model_dir, 'model.pickle'), 'wb') as f:
        pickle.dump(bst, f)


# 推論用エンドポイントを作成する際に用いるモデルの読み込み用処理
def model_fn(model_dir):
    with open(path.join(model_dir, 'model.pickle'), 'rb') as f:
        model = pickle.load(f)
    return model


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # ハイパーパラメータ
    parser.add_argument('--num_round', type=int, default=10)
    parser.add_argument('--max_depth', type=int)
    parser.add_argument('--eta', type=float)
    # モデルの出力先パス
    parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR'))
    # モデル以外のデータの出力パス
    parser.add_argument('--output_dir', type=str, default=os.environ.get('SM_OUTPUT_DIR'))
    # 学習データの入力パス
    parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN'))
    parser.add_argument('--validation', type=str, default=os.environ.get('SM_CHANNEL_VALIDATION'))
    parser.add_argument('--test', type=str, default=os.environ.get('SM_CHANNEL_TEST'))

    parser.add_argument('--log_level', type=int, default=os.environ.get('SM_LOG_LEVEL'))
    args = parser.parse_args()


    logging.basicConfig(level=args.log_level)

    main(args)

学習

SageMaker Python SDKのXGBoost用のEstimatorで必要な項目を設定し、学習を開始します。 エントリポイントとして、先ほど作成したtrain.pyを指定します。

train_instance_type = 'ml.c5.2xlarge' # 'local'とすることで、ローカルでの実行が可能
hyperparameters = {'max_depth': 2, 'eta': 1, 'num_round':2}
estimator = sagemaker.xgboost.XGBoost(entry_point = 'train.py',
                    output_path=s3_prefix,
                    train_instance_type=train_instance_type,
                    train_instance_count=1,
                    hyperparameters=hyperparameters,
                    role=role,
                    base_job_name='xgboost-framework-test',
                    train_max_run=3600, # トレーニングジョブの最大実行時間
                    train_use_spot_instances=True, # マネージドスポットトレーニングを有効化
                    train_max_wait=3600, # トレーニングジョブの最大待機時間 (>=train_max_run)
                    framework_version='0.90-1',
                    py_version='py3')
estimator.fit({'train':train_path,
               'validation':validation_path,
               'test':test_path})

fitを叩くことでSageMakerが学習に使用するインスタンスを用意し、train.pyに記した学習処理を実行してくれます。学習ジョブは数分ほどで完了します。

モデルの読み込み

学習させたモデルをダウンロードしてきて、ローカルで使ってみます。

model_uri = estimator.model_data
model_path = path.join('./', path.basename(model_uri))
sagemaker.utils.download_file_from_url(model_uri, model_path, sm_session)
with tarfile.open(model_path, 'r') as tarf:
    model = pickle.load(tarf.extractfile(tarf.getmembers()[0]))

テスト用データを推論してみます。

dtest = xgb.DMatrix(test.drop('target', axis=1), test['target'])
model.predict(dtest)

推論用エンドポイント

推論用エンドポイントを作成します。推論の際も先ほど作成したtrain.pyが自動的にモジュールとして使用されます。

predictor = estimator.deploy(
    instance_type='ml.c5.large', # 'local'とすることで、ローカルでの実行が可能
    initial_instance_count=1)

エンドポイントの作成は数分ほどで完了します。CSV用のシリアライザを設定し、CSV形式でエンドポイントにデータを投げるようにします。

predictor.serializer = csv_serializer
predictor.content_type = 'text/csv'

テストデータを推論してみます。先ほどシリアライザを設定したため、ndarrayを自動的にCSVに変換しエンドポイントに投げてくれます。

predictor.predict(test.drop('target', axis=1).values)

無事推論できたので、エンドポイントを削除します。

estimator.delete_endpoint()

さいごに

『機械学習 on AWS Advent Calendar 2019』 2日目、「Amazon SageMakerでXGBoostをフレームワークとして使ってみた」についてお伝えしました。ローカルで動かす感じとそこまで大きな変更をせずに、SageMakerでXGBoostを使うことができました。

参考