Amazon SageMaker Debugger で学習の異常を検出し、学習ジョブを停止する – 機械学習 on AWS Advent Calendar 2019 #reinvent

どうも、DA 事業本部の大澤です。

当エントリは『機械学習 on AWS Advent Calendar 2019』の 17 日目のエントリです。

今回は「Amazon SageMaker Debugger で学習の異常を検出し、学習ジョブを停止する」についてご紹介します。

やってみる

Amazon SageMaker Examplesで紹介されている以下のノートブックを実際にやってみます。

次のような流れです。

  • マネジメントコンソールから学習ジョブを停止するための Lambda 関数を作成
    • CloudWatch Events の SageMaker Training Job State Changeをトリガーとして利用
  • ノートブックから TensorFlow モデルの学習ジョブを作成
    • SageMaker Debugger でルールを設定
  • 学習中にルールに関するIssuesFoundになり、学習ジョブが停止するのを確認
    • SageMaker Debugger の設定したルールの評価ログを CloudWatch Logs から確認

Lambda 関数を作成

まずは、マネジメントコンソールから Lambda 関数を作成します。 ランタイムにはPython 3.8、IAM ロールはCreate a new role with basic Lambda permissionsを選択します。IAM ロールはこのままだと SageMaker に関する権限が足りないので、後ほどポリシーをアタッチします。

Lambda 関数が作成できたら、Execution roleからView the {ロール名}で IAM ロールの画面に移動します。

AmazonSageMakerFullAccessポリシーを Lambda 関数の実行ロールに付与します。

Lambda 関数の画面に戻り、スクリプトを入力し、変更内容をSaveします。

使用するスクリプトは次の通りです。 今回作成する Lambda 関数は SageMaker の学習ジョブのステータスが変更した際に実行されるものです。このスクリプトでは、SageMaker Debugger で設定したルールの評価ステータスがIssuesFoundになっている場合に学習ジョブを停止し、それ以外の場合には何もせずに処理を終了します。

import json
import boto3
import logging

def lambda_handler(event, context):
training_job_name = event.get("detail").get("TrainingJobName")
eval_statuses = event.get("detail").get("DebugRuleEvaluationStatuses", None)

    if eval_statuses is None or len(eval_statuses) == 0:
        logging.info("Couldn't find any debug rule statuses, skipping...")
        return {
            'statusCode': 200,
            'body': json.dumps('Nothing to do')
        }

    client = boto3.client('sagemaker')

    for status in eval_statuses:
        if status.get("RuleEvaluationStatus") == "IssuesFound":
            logging.info(
                'Evaluation of rule configuration {} resulted in "IssuesFound". '
                'Attempting to stop training job {}'.format(
                    status.get("RuleConfigurationName"), training_job_name
                )
            )
            try:
                client.stop_training_job(
                    TrainingJobName=training_job_name
                )
            except Exception as e:
                logging.error(
                    "Encountered error while trying to "
                    "stop training job {}: {}".format(
                        training_job_name, str(e)
                    )
                )
                raise e
    return None

では、次に CloudWatch Events でのトリガー作成に移ります。 どういうイベントが来た時に Lambda 関数を実行するかのルールを作成します。

イベントソースはSageMaker Training Job State Changeを利用します。

ターゲットは先ほど作成した Lambda 関数を指定します。

次に名前や説明を入力し、ルールの作成を完了します。

Lambda関数の画面に戻ってみると、 CloudWatch Eventsがトリガーとして設定されています。

モデルの学習

ここからはノートブックでの作業に移ります。TensorFlow のモデルを学習させます。

ライブラリの読み込みと学習用スクリプトのパス、ハイパーパラメータを定義しておきます。 今回はMNISTのデータセットを CNN で分類させます。学習用スクリプトは TensorFlow を使って実装されたmnist_zerocodechange.pyを使用します。

ファイル名にもある通り、SageMaker Debugger を使用するにあたり、学習用スクリプトの変更は不要です。


import boto3
import os
import sagemaker
from sagemaker.tensorflow import TensorFlow
from sagemaker.debugger import Rule, rule_configs

# define the entrypoint script

entrypoint_script='src/mnist_zerocodechange.py'

# these hyperparameters ensure that vanishing gradient will trigger for our tensorflow mnist script

hyperparameters = {
"num_epochs": "10",
"lr": "10.00"
}

続いて、SageMaker Debugger のルールを作成し、学習をハンドルするEstimatorに設定します。今回は勾配消失を検出するルールとロスが減少していないことを検出するルールを設定します。これらのルールは組み込みルールとして用意されているので、専用の処理を書く必要はありません。

rules=[
Rule.sagemaker(rule_configs.vanishing_gradient()),
Rule.sagemaker(rule_configs.loss_not_decreasing())
]

estimator = TensorFlow(
role=sagemaker.get_execution_role(),
base_job_name='smdebugger-demo-mnist-tensorflow',
train_instance_count=1,
train_instance_type='ml.m4.xlarge',
entry_point=entrypoint_script,
framework_version='1.15',
train_volume_size=400,
py_version='py3',
train_max_run=3600,
script_mode=True,
hyperparameters=hyperparameters, ## New parameter
rules = rules
)

実行中にルールの評価ステータス等の確認を行いたいため、学習が完了するまで待機する設定を False にして学習ジョブを実行します。

estimator.fit(wait=False)

設定したルールの評価ステータスを確認してみます。

estimator.latest_training_job.rule_job_summary()

学習開始直後はどちらのルールもInProgressステータスでした。

しばらく時間が経ってから再度確認してみると、学習が進み、2 つのルールにおいてIssuesFoundステータスになりました。

次に SageMaker Debugger のルールの評価ジョブのログを確認してみます。Estimatorに設定された情報からからログの URL を作成します。

def \_get_rule_job_name(training_job_name, rule_configuration_name, rule_job_arn):
"""Helper function to get the rule job name"""
return "{}-{}-{}".format(
training_job_name[:26], rule_configuration_name[:26], rule_job_arn[-8:]
)

def \_get_cw_url_for_rule_job(rule_job_name, region):
return "https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/ProcessingJobs;prefix={};streamFilter=typeLogStreamPrefix".format(region, region, rule_job_name)

def get_rule_jobs_cw_urls(estimator):
region = boto3.Session().region_name
training_job = estimator.latest_training_job
training_job_name = training_job.describe()["TrainingJobName"]
rule_eval_statuses = training_job.describe()["DebugRuleEvaluationStatuses"]

    result={}
    for status in rule_eval_statuses:
        if status.get("RuleEvaluationJobArn", None) is not None:
            rule_job_name = _get_rule_job_name(training_job_name, status["RuleConfigurationName"], status["RuleEvaluationJobArn"])
            result[status["RuleConfigurationName"]] = _get_cw_url_for_rule_job(rule_job_name, region)
    return result

get_rule_jobs_cw_urls(estimator)

各ルールのログの URL が表示されたので、開いて確認してみます。 VanishingGradientルールの評価ログは次の通りでした。

次のログはLossNotDecreasingルールのものです。

どちらの場合も最後にルールに抵触し、例外が投げられています。

学習ジョブのステータスを確認します。


estimator.latest_training_job.describe()["TrainingJobStatus"]

学習ジョブは無事停止されたようです。これで今回の検証は完了です。作成した Lambda 関数は今後も動き続けることになります。必要なければ混乱の元になるので、関数を削除しましょう。

さいごに

『機械学習 on AWS Advent Calendar 2019』 17 日目、「Amazon SageMaker Debugger で学習の異常を検出し、学習ジョブを停止する」についてお伝えしました。学習に時間がかかる場合には学習の途中で問題を検出し自動でジョブを停止&通知する事で、実験のイテレーションを高速化できるかと思います。また、運用ワークフローにモデルの学習が組み込まれている場合の問題検出にも役立ちそうです。

参考