SPL04-R:  Creating and Tuning Models with Amazon SageMakerの内容を復習してみた #reinvent

SPL04-R: Creating and Tuning Models with Amazon SageMakerの内容を復習してみた #reinvent

Clock Icon2018.12.20

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

はじめに

好物はインフラとフロントエンドのかじわらゆたかです。 re:invent で受講したSpotlight labの内容を復習してみたので、そのエントリーになります。

概要

このSpotlight Labで受講した内容としては下記の通りです。 * 毎年14% の顧客が離脱している * どの顧客が離脱するかを機械学習を用いて予測したい

今回、この予測モデルを作成するアルゴリズムとしてはXGBoostを用います。

この演習に用いたJupyterNotebookはこちらにあります。

GitHub等のRepositoryとSagemaker連携ができるようになってますので、そちらを用いることで簡単にCloneして、当該のNotebookを立ち上げることができます。

Amazon SageMakerにてGit連携が行えるようになりました。#reinvent

Setup

Setup に記載されている項目を実行します。 bucketの変数には自分が用いるBucket名を代入します。

Data

使用するデータセットは公開されており、Daniel T. Laroseの「Discovering Knowledge in Data」という書籍に記載されているものになります。 これは、カリフォルニア大学アーバイン校の機械学習データセットの著者に帰属しています。

データ構造

データ構造は下記のようになっています。

項目名 項目内容
State 居住する米国の州で、2文字の略語で示されます。 例: OH , NJ
Account Length このアカウントが有効になっている日数
Area Code 3桁の市外局番
Phone 7桁の電話番号
Int’l Plan 顧客が国際電話プランを持っているかどうか
VMail Plan 顧客にボイスメール機能があるかどうか
VMail Message 一月あたりのボイスメールメッセージの平均数
Day Mins 一日あたりの通話時間の合計数
Day Calls 一日あたりの総コール数
Day Charge 昼間の通話料金
Eve Mins, Eve Calls, Eve Charge 夕方にかけられた通話料金
Night Mins, Night Calls, Night Charge 夜間にかけられたコールの料金
Intl Mins, Intl Calls, Intl Charge 国際通話料金
CustServ Calls カスタマーサービスにかけられたコールの数
Churn? 顧客がサービスを辞めたかどうか

Churn? が機械学習で予測させたい属性となります。 今回のChurn? はバイナリのため、2クラス分類で分類することが可能です。

各データの分布を確認します。

for column in churn.select_dtypes(include=['object']).columns:
    display(pd.crosstab(index=churn[column], columns='% observations', normalize='columns'))

上記のようにすることで、各カラムのデータがどのような頻度で出ているかを確認することができます。

また、下記のようにすることで各項目の平均、標準偏差等を求めることが可能です。

display(churn.describe())

出た各項目は以下のようになっています。

【Python】pandasのdescribeで出力される項目の意味について【データ分析】

項目名 意味
count そのカラムの件数
mean 平均
std 標準偏差
min 最小値
25% 第一四分位数
50% 第二四分位数
75% 第三四分位数
max 最大値

ヒストグラムを表示させます。

%matplotlib inline
hist = churn.hist(bins=30, sharey=True, figsize=(10, 10))

分布から下記の事がわかります。

  • Stateカラムは均等に配分されている模様です
  • Phoneはユニークな値をとるが、番号についてのコンテキストがない場合は用いないほうが良いでしょう
  • ChurnがTrueの人間が14%ほどいますが、これについては特に気にする必要はないでしょう
  • 数字のほとんどきれいな鐘状に配置されているがわかります。
  • VMail Messageの項目はその中で例外的な配置となります。

不要な項目を除外していきます。

まず、上記にあったとおりPhoneの項目を除外します。

churn = churn.drop('Phone', axis=1)

次にAreaCodeの項目は数字ではなくオブジェクトとして扱うように型変換を行います。

churn['Area Code'] = churn['Area Code'].astype(object)

各カラムとChurn?のクロス集計を行い関連性を見ていきます。

for column in churn.select_dtypes(include=['object']).columns:
    if column != 'Churn?':
        display(pd.crosstab(index=churn[column], columns=churn['Churn?'], normalize='columns'))

for column in churn.select_dtypes(exclude=['object']).columns:
    print(column)
    hist = churn[[column, 'Churn?']].hist(by='Churn?', bins=30)
    plt.show()

以下のことがわかります。

  • stateArea Codeを見る限り、地理的に均等に分散している

  • Int’l Planは関連がありそうである
  • VVMail Planは関連が低そうである

  • 解約を行った人のCustServ Callsが平均より高いか低いといった特徴がある

加えて、解約した人間はDay MinsDay Chargeといった非常によく似た特徴を持っていることもわかります。

値ごとの関連を見てみます。

display(churn.corr())
pd.plotting.scatter_matrix(churn, figsize=(12, 12))
plt.show()

相関関係になっている組み合わせがいくつかあることがわかります。 例: Day ChargeDay Mins こういった組み合わせは機械学習の際に致命的な問題を発生することがあるので、 値に含ませないことにします。

churn = churn.drop(['Day Charge', 'Eve Charge', 'Night Charge', 'Intl Charge'], axis=1)

トレーニング用データ作成

ここまで精査したデータをトレーニング用のデータと検証用のデータとして作成していきS3にアップロードします。

model_data = pd.get_dummies(churn)
model_data = pd.concat([model_data['Churn?_True.'], model_data.drop(['Churn?_False.', 'Churn?_True.'], axis=1)], axis=1)
train_data, validation_data, test_data = np.split(model_data.sample(frac=1, random_state=1729), [int(0.7 * len(model_data)), int(0.9 * len(model_data))])
train_data.to_csv('train.csv', header=False, index=False)
validation_data.to_csv('validation.csv', header=False, index=False)

Train

トレーニングを行っていきます。

from sagemaker.amazon.amazon_estimator import get_image_uri
container = get_image_uri(boto3.Session().region_name, 'xgboost')
s3_input_train = sagemaker.s3_input(s3_data='s3://{}/{}/train'.format(bucket, prefix), content_type='csv')
s3_input_validation = sagemaker.s3_input(s3_data='s3://{}/{}/validation/'.format(bucket, prefix), content_type='csv')
sess = sagemaker.Session()

xgb = sagemaker.estimator.Estimator(container,
                                    role, 
                                    train_instance_count=1, 
                                    train_instance_type='ml.m4.xlarge',
                                    output_path='s3://{}/{}/output'.format(bucket, prefix),
                                    sagemaker_session=sess)
xgb.set_hyperparameters(max_depth=5,
                        eta=0.2,
                        gamma=4,
                        min_child_weight=6,
                        subsample=0.8,
                        silent=0,
                        objective='binary:logistic',
                        num_round=100)

xgb.fit({'train': s3_input_train, 'validation': s3_input_validation}) 

用いたハイパーパラメータの解説は以下を参考にしてください。 XGBoostのハイパーパラメータ

Host

作成したModelをデプロイします。

xgb_predictor = xgb.deploy(initial_instance_count=1,
                           instance_type='ml.m4.xlarge')

Evaluate

検証用の関数を実装し、実際に検証用のデータを用いて推論を行い、推論結果をクロス集計します。

xgb_predictor.content_type = 'text/csv'
xgb_predictor.serializer = csv_serializer
xgb_predictor.deserializer = None

def predict(data, rows=500):
    split_array = np.array_split(data, int(data.shape[0] / float(rows) + 1))
    predictions = ''
    for array in split_array:
        predictions = ','.join([predictions, xgb_predictor.predict(array).decode('utf-8')])

    return np.fromstring(predictions[1:], sep=',')

predictions = predict(test_data.as_matrix()[:, 1:])

pd.crosstab(index=test_data.iloc[:, 0], columns=np.round(predictions), rownames=['actual'], colnames=['predictions'])

この例では、48名の顧客がサービスを辞め、39名に対しては推論を行うことができました。 ですが、9名は推論ではサービスを辞めないと判断したが、実際には辞めた顧客ということになります。(false negative) これはXGBoostの推論時に出てきた値を四捨五入で処理した結果であるため、しきい値を調整することでfalse nagativeな判定を減らしていきたいと思います。

まずは、今回の推論の数値の分布を確認します。

plt.hist(predictions)
plt.show()

例えば、これのしきい値を0.3に変更することでクロス集計の結果は下記になります。

pd.crosstab(index=test_data.iloc[:, 0], columns=np.where(predictions > 0.3, 1, 0))

今回この調整を行うことで、false nagativeの人間を一人減らすことがでましたが、 他の推論の結果も変わってしまったことがわかります。

コストを最適化するしきい値はどこか?

しきい値の調整はどの層に対して、どれだけのコストを支払うかといって決めることが可能です。 今回の例ですと、以下のようにかかるコストを割り当てるものとします。

  • false nagativeと判定した人間に対しては$500のコストがかかるものとする
    • これは、推論結果は辞めないと判断し、実際は辞めた人間となる
    • こういった顧客は、売上が減ることになり次の顧客の獲得のためのコストが掛かるため一番コストのかかる顧客となる
  • true positive / false positiveと判定した人間に対しては$100 のコストが掛かるものとする
  • true negativeと判定した人間に対してはコストはかけない

この点わかりにくかったのでかけるコストの対応がどのようになるのか、しきい値0.3のデータと併記してみました。

顧客の状態 / 推論結果 離脱しないと予測 離脱すると予測
離脱しない 279 7
離脱した 8 40
顧客の状態 / 推論結果 離脱しないと予測 離脱すると予測
離脱しない $0 $100
離脱した $500 $100
顧客の状態 / 推論結果 離脱しないと予測 離脱すると予測
離脱しない true negative false positive
離脱した false negative true positive

これらの条件で一番コストが安くなるところを求めていきます。

cutoffs = np.arange(0.01, 1, 0.01)
costs = []
for c in cutoffs:
    costs.append(np.sum(np.sum(np.array([[0, 100], [500, 100]]) * 
                               pd.crosstab(index=test_data.iloc[:, 0], 
                                           columns=np.where(predictions > c, 1, 0)))))

costs = np.array(costs)
plt.plot(cutoffs, costs)
plt.show()
print('Cost is minimized near a cutoff of:', cutoffs[np.argmin(costs)], 'for a cost of:', np.min(costs))

これを動かすことで、以下のような結果が得られます。

しきい値0.46あたりがもっともコストが低くなります。

まとめ

実際に機械学習を用いて出てきたモデルをどう経営判断につなげていくかといったことも、 学べたSpotLabとなっていました。 また、Sagemaker Neo版も出ていたので、今回のと違いも見てみたいと思います。

おまけ

今回立ち上げたエンドポイントは削除し、ノートブックインスタンスも課金対象のため不要であれば止めておきます。 エンドポイントの削除は以下になります。

sagemaker.Session().delete_endpoint(xgb_predictor.endpoint)

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.