Amazon SageMakerで学習過程の可視化が出来るようになりました!
こんにちは、大阪DI部の大澤です。
Amazon SageMakerはJupyterノートブックを使ったデータ探索から機械学習モデルの作成、エンドポイントの展開等が行える機械学習のフルマネージドサービスです。
先日、そのSageMakerでモデルを学習させた際にlossや精度といったメトリクスの変化を簡単に可視化することが、SageMakerとCloudWatchの組み合わせで可能となりました!
メトリクスの変化の可視化によってモデルの学習状態が分かりやすくなり、過学習なのか未学習なのかの判断を下しやすくなります。個人的に非常に嬉しい機能です!
機能概要
今回更新された機能は、SageMakerを使った学習において、予め指定したlossや精度といったメトリクスをCloudWatchで確認できるという内容です。また、SageMaker Python SDKを使ってメトリクスの取得・描画を行うことも可能です。
メトリクスの設定項目
学習ジョブ作成時に以下の2項目を各メトリクスに対して設定することで、学習ジョブの実行中にCloudWatchへデータが送られるようになります。
- メトリクス名
- 文字通り、メトリクスを特定するための名前です。 例: train:error
- トレーニングメトリクスパターン
- 学習時に出力されるログから対象のメトリクスを抽出するための正規化表現です。 例: .*\\[[0-9]+\\]#011train-error:(\\S+).*
やってみる
xgboostの分類モデルを学習ジョブを実行し、学習過程を見てみます。
学習ジョブの実行
マネジメントコンソールから学習ジョブを作成します。
組み込みアルゴリズムを使用するかそうでないかで設定内容が変わります。いずれの場合もアルゴリズムの設定箇所以外は通常のジョブ作成と同じです。
組み込みアルゴリズムの場合
今回は組み込みアルゴリズムのxgboostを選択します。組み込みアルゴリズムの場合はすでにメトリクスが設定済みのため、自分で設定する必要はありません。
アルゴリズムの選択箇所でCustomを選ぶことで、自らでメトリクスの設定を行うことができます。
組み込みアルゴリズム以外の場合
アルゴリズムの場所で独自の ECR コンテナパスを使用します。を選択することで組み込みアルゴリズム以外を使った学習ジョブを設定できます。その場合は、自らでメトリクスの設定が必要になります。
SageMaker Python SDKや低レベルAPIを用いる場合
SageMaker Python SDKの場合にはEstimatorオブジェクトを初期化する際のmetric_definitions、低レベルAPIの場合にはAlgorithmSpecificationのMetricDefinitionsにメトリクス名とトレーニングメトリクスパターンの設定が必要になります。組み込みアルゴリズムの場合はマネジメントコンソールでジョブを作成する場合と同様に、設定は不要です。
詳細についてはドキュメントをご覧ください。
学習過程の確認
CloudWatch上での確認
学習ジョブを実行し、しばらくするとCloudWatchにデータが送られて確認出来るようになります。学習ジョブの詳細画面からCloudWatchのメトリクス画面に行くことができます。
表示可能なメトリクスが一覧化されているため、他のCloudWatchメトリクスと同じように、メトリクスを選択してグラフに描画することができます。
validation:merrorを期間:5分で描画。
学習時間が短い場合には全てのエポックでのメトリクスが表示されない場合があります。各メトリクスの期間を狭めることで描画点数が増えることがあります。 以下のグラフは上と同じメトリクスを期間:1分で描画したものです。描画点数が増えていることが分かると思います。
SageMaker Python SDKを使って確認
SageMaker Python SDKを使うことで、メトリクスの取得と可視化が可能です。 ※SDKのバージョンが低い場合はエラーが出ます。エラーが出た場合はSDKをアップグレードすることで動くようになることがあります。
import boto3 import sagemaker from sagemaker.analytics import TrainingJobAnalytics # ジョブ名 training_job_name = 'job_name' # メトリクス名 metric_name = 'validation:merror' # セッション作成 botosess = boto3.Session(region_name='ap-northeast-1') sess = sagemaker.Session(botosess) # メトリクスデータをデータフレーム形式で取得 metrics_dataframe = TrainingJobAnalytics(training_job_name=training_job_name,metric_names=[metric_name], sagemaker_session=sess).dataframe() # プロット metrics_dataframe.plot(x='timestamp', y='value', legend=False).set_ylabel(metric_name)
最後に
Amazon SageMakerを使うことで物体検出や分類、回帰といった色々なモデルを簡単に学習させることができるんですが、これまでの場合は学習中のlossや精度の変化はCloudWatch Logsからテキストベースで確認するのが主な方法でした。それが今回のアップデートによって、CloudWatchのグラフで学習過程を可視化できるようになりました。lossや精度の確認は機械学習をする上で重要です。この機能を活用して過学習や未学習のモデルを作らないよう、適切な学習を行いたいですね。
最後までお読みくださりありがとうございました〜!