この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。
SageMaker Sparkを使用して、Amazon SageMaker上で手書き数字の分類モデルを作成するノートブックをやってみたので、その内容を紹介します。
概要
SageMaker Sparkを使って、Sparkをローカル上に展開し、MNISTのデータを読み込み、SageMaker上でXGBoostの分類モデルの学習とホストを行います。その後、作成したモデルを使ってテストデータの分類を行います。
- SageMaker Spark
- Apache SparkのSageMaker用のオープンソースのライブラリです。Sparkに読み込んだデータを使ったSageMaker上での学習や、SparkのMLlibとSageMakerを連携させること等が出来ます。
- MNISTのデータセット
- 手書き数字画像とラベルデータを含んだデータセットです。データ分析のチュートリアルでよく使われるデータセットの一つです。
- XGBoost
- 勾配ブースティングツリーという理論のオープンソースの実装で、分類や回帰によく使われる機械学習アルゴリズムの一つです。
基本的にはawslabsのノートブックに沿って進めますが、一部変更している箇所があります。
やってみた
ノートブックの作成
SageMakerのノートブックインスタンスを立ち上げて、
SageMaker Examples
↓
Sagemaker Spark
↓
pyspark_mnist_xgboost.ipynb
↓
use
でサンプルからノートブックをコピーして、開きます。
ノートブックインスタンスの作成についてはこちらをご参照ください。
セットアップ
IAMロールの取得とSparkセッションの構築を行います。 今回はローカル上にSparkアプリケーションを実行し、セッションに繋ぎます。 リモートのSparkクラスタと接続する場合はAmazon SageMaker PySparkのGitHubでの解説をご参照ください。
import os
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
import sagemaker
from sagemaker import get_execution_role
import sagemaker_pyspark
# 実行しているIAMロールを取得する
role = get_execution_role()
# SageMaker Sparkが依存するjarを取得する
jars = sagemaker_pyspark.classpath_jars()
classpath = ":".join(sagemaker_pyspark.classpath_jars())
# ローカルで実行するSparkアプリケーションとのセッションを取得する
spark = SparkSession.builder.config("spark.driver.extraClassPath", classpath)\
.master("local[*]").getOrCreate()
データの読み込み
MNISTのデータセットをSpark Dataframeに読み込みます。 データセットはLibSVM形式のものがS3上で公開されています。 LibSVM形式についてはこちらのエントリをご覧ください。
学習と推論のデータセットと使用するには2種類のカラムが必要です。一つはダブル型のカラム(デフォルトではlabel
という名前)、二つ目はダブル型のベクトル(デフォルトではfeatures
という名前)です。
import boto3
region = boto3.Session().region_name
trainingData = spark.read.format('libsvm')\
.option('numFeatures', '784')\
.option('vectorType', 'dense')\
.load('s3a://sagemaker-sample-data-{}/spark/mnist/train/'.format(region))
testData = spark.read.format('libsvm')\
.option('numFeatures', '784')\
.option('vectorType', 'dense')\
.load('s3a://sagemaker-sample-data-{}/spark/mnist/test/'.format(region))
trainingData.show()
モデルの学習とホスティング
Estimatorを作成し、学習とモデルをホストするエンドポイントの作成を行います。 Estimatorには学習とエンドポイントのパラメータ、学習時のハイパーパラメータを設定します。 XGBoostのハイパーパラメータについてはドキュメントをご参照ください。
import random
from sagemaker_pyspark import IAMRole, S3DataPath
from sagemaker_pyspark.algorithms import XGBoostSageMakerEstimator
from sagemaker_pyspark.S3Resources import S3DataPath
# モデルアーティファクトの出力先を指定(ノートブックにはなく、追加した内容です。)
output_s3_data = S3DataPath('bucket_name', 'sagemaker/spark-xgb/output')
# estimatorの設定
xgboost_estimator = XGBoostSageMakerEstimator(
sagemakerRole=IAMRole(role), # 学習時とエンドポイントの作成時に使用するIAMロール
trainingInstanceType='ml.m4.xlarge', # 学習に使用するインスタンスタイプ
trainingInstanceCount=1, # 学習に使用するインスタンス数
endpointInstanceType='ml.m4.xlarge', # モデルをホストするエンドポイントのインスタンスタイプ
endpointInitialInstanceCount=1, # モデルをホストするエンドポイントのインスタンス数
trainingOutputS3DataPath = output_s3_data) # モデルアーティファクトの出力先
# ハイパーパラメータの設定
xgboost_estimator.setEta(0.2)
xgboost_estimator.setGamma(4)
xgboost_estimator.setMinChildWeight(6)
xgboost_estimator.setSilent(0)
xgboost_estimator.setObjective("multi:softmax")
xgboost_estimator.setNumClasses(10)
xgboost_estimator.setNumRound(10)
# 学習処理開始
model = xgboost_estimator.fit(trainingData)
推論
テストデータをまとめて分類(推論)します。この処理を行うために内部的には、featureカラムをLibSVM形式に変換して、モデルをホストするエンドポイントに投げます。その後、エンドポイントでモデルが分類した結果をCSV形式で受け取って、DataFrameに格納します。この処理をtransform
がまとめてやってくれます。便利です。
※ ノートブックではtrainDataを読み込んでいますが、testDataの方が適切だと思われるので、testDataを読み込むように変更しています。
transformedData = model.transform(testData)
transformedData.show()
入力データに加えてprediction
カラムが追加されています。
分類結果ごとに入力画像を表示してみて、正しく分類できたかを確認します。
from pyspark.sql.types import DoubleType
import matplotlib.pyplot as plt
import numpy as np
# 数字を表示するための補助関数
def show_digit(img, caption='', xlabel='', subplot=None):
if subplot==None:
_,(subplot)=plt.subplots(1,1)
imgr=img.reshape((28,28))
subplot.axes.get_xaxis().set_ticks([])
subplot.axes.get_yaxis().set_ticks([])
plt.title(caption)
plt.xlabel(xlabel)
subplot.imshow(imgr, cmap='gray')
# 入力画像データを取得
images = np.array(transformedData.select("features").cache().take(250))
# 分類結果を取得
clusters = transformedData.select("prediction").cache().take(250)
# 分類結果ごとに入力画像を表示する
for cluster in range(10):
print('\n\n\nCluster {}:'.format(int(cluster)))
digits=[ img for l, img in zip(clusters, images) if int(l.prediction) == cluster ]
height=((len(digits) - 1) // 5) + 1
width=5
plt.rcParams["figure.figsize"] = (width,height)
_, subplots = plt.subplots(height, width)
subplots=np.ndarray.flatten(subplots)
for subplot, image in zip(subplots, digits):
show_digit(image, subplot=subplot)
for subplot in subplots[len(digits):]:
subplot.axis('off')
plt.show()
実際には0~9までの数字に対して結果が表示されますが、ここでは0と9の結果だけ紹介します。 0に分類された入力画像は全て正しく0を表していそうです。
9に分類された入力画像は基本的に正しそうですが、7や8が紛れ込んでいます。
概ね高い精度で分類出来ていそうです。
エンドポイントの削除
余計な費用が掛からないようにエンドポイントを削除します。
from sagemaker_pyspark import SageMakerResourceCleanup
resource_cleanup = SageMakerResourceCleanup(model.sagemakerClient)
resource_cleanup.deleteResources(model.getCreatedResources())
おわりに
SageMaker Sparkライブラリを使用して、Amazon SageMaker上で手書き数字の分類を行えました。SageMaker Sparkを使うことで、Amazon SageMakerのモデルをSpark上で扱うことができます。大きなデータを扱ったり、Apache SparkのMLlibとAmazon SageMakerを連携させる際には非常に便利そうです。
これからAmazon SageMakerとApache Sparkを試してみたいと思っている方の参考になれば幸いです。 最後までお読み頂き有難うございましたー!