Amazon SageMaker Debugger のカスタムルールを用いたTensorFlowモデルのデバッグ – 機械学習 on AWS Advent Calendar 2019 #reinvent

『機械学習 on AWS Advent Calendar 2019』の22日目

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

どうも、DA 事業本部の大澤です。

当エントリは『機械学習 on AWS Advent Calendar 2019』の 22 日目のエントリです。

今回は「Amazon SageMaker Debugger のカスタムルールを用いた TensorFlow モデルのデバッグ」についてご紹介します。

Amazon SageMaker Debugger

Amazon SageMaker Debugger はモデル学習中のデータを収集し、分析や問題の発見を可能にします。

構成要素

Amazon SageMaker Debuggerには次のような要素があります。

  • ステップ(Step): 学習が進む単位(バッチ単位)
  • コレクション(Collection): 学習実行中の各ステップの各種状態を表すテンソルの集まり
  • 例: weights、gradients、losses、full_shap、...
  • フック(Hook): ステップ単位に呼び出されるコールバックオブジェクトとして渡されるクラス
  • ルール(Rule): コレクションとして収集されるデータのモニタリングに用いる条件
  • 条件を満たした場合は例外が投げられる
  • 例: AllZeroVanishingGradientOverfitNLPSequenceRatio...
  • トライアル(Trial): 1つの学習もしくは、学習を分析するためのインターフェイス

これらのより詳しい解説についてはドキュメントを参照ください。

やってみる

Amazon SageMaker Examplesで紹介されている以下のノートブックを実際にやってみます。

次のような流れです。

  • カスタムルールを作成
  • 学習時の勾配が大きくなりすぎてないかチェックするルール
  • TensorFlow モデルの学習ジョブを作成
  • ルールを設定
  • ルール評価ジョブのログを CloudWatch Logs から確認

準備

まずはSageMaker Debugger用のライブラリをインストールします。

pip install smdebug

使用するモジュール/ライブラリの読み込みと学習用スクリプトのパスを定義します。

学習用スクリプトはtf_keras_resnet_zerocodechange.pyを使用します。画像分類用データセットであるCIFAR-10を用いて、Kerasで画像分類モデル(ResNet)を学習するという内容です。

import boto3
import os
import sagemaker
from sagemaker.tensorflow import TensorFlow
from sagemaker.debugger import Rule, DebuggerHookConfig, TensorBoardOutputConfig, CollectionConfig
import smdebug_rulesconfig as rule_configs

entrypoint_script='src/tf_keras_resnet_zerocodechange.py'

ルールの定義

今回はビルトインのルールではなく、自らルールを作成し、利用します。ルールはsmdebug.rules.rule.Ruleクラスを継承して定義します。

今回使用するルールを定義した次のスクリプトをrules/my_custom_rule.pyとして保存します。 学習時の各ステップでの勾配の平均があらかじめ指定した閾値を超えるかどうかを確認するという内容です。

rules/my_custom_rule.py

from smdebug.rules.rule import Rule

class CustomGradientRule(Rule):
    def __init__(self, base_trial, threshold=10.0):
        super().__init__(base_trial)
        self.threshold = float(threshold)

    def invoke_at_step(self, step):
        for tname in self.base_trial.tensor_names(collection="gradients"):
            t = self.base_trial.tensor(tname)
            abs_mean = t.reduction_value(step, "mean", abs=True)
            if abs_mean > self.threshold:
                return True
        return False

Ruleクラスのコンストラクタ

smdebug.rules.rule.Ruleクラスのコンストラクタの引数は次の通りです。

  • base_trial (Trial): 実行中のトライアルデータ
  • other_trials (list[Trial]): (Optional) 実行中以外のトライアルデータリスト
  • 他のトライアルデータと比較が必要となるルールで使用する
  • sagemaker.debugger.Rule.custom()other_trials_s3_input_pathsで指定した場所に保存されているトライアルデータが読み込まれる
  • **kwargs (str): ルールの評価に使用するパラメータ
  • 型は文字列(str)のみサポートされています
  • sagemaker.debugger.Rule.custom()rule_parametersでルールで使用する各パラメータの値を設定できます

invoke_at_step

invoke_at_stepsmdebug.rules.rule.Ruleクラスのメソッドです。 invoke_at_stepは現在のステップ番号stepを引数にとり、該当ステップでのコレクション/テンソルを参照し、条件を満たしているかをブール値で返します。Trueが返された場合には例外が投げられて、そのルールの評価ステータスはIssuesFoundとなります。

モデルの学習

モデルの学習時に評価するルールをsagemaker.debugger.Ruleを使って設定します。 ルール評価時に使用するコンテナイメージは事前に用意されているものを使用できます。コンテナイメージはリージョンごとにレジストリIDが違います。各リージョンのレジストリIDはドキュメントに記載されています。

custom_rule = Rule.custom(
    name='MyCustomRule', # ルールを特定するための名前
    # ルール評価に使用するコンテナイメージ
#     image_uri='759209512951.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rule-evaluator:latest', # us-west-2用
    image_uri='670969264625.dkr.ecr.ap-northeast-1.amazonaws.com/sagemaker-debugger-rule-evaluator:latest', # ap-northeast-1用
    instance_type='ml.t3.medium', # ルール評価処理を実行するインスタンスタイプ
    source='rules/my_custom_rule.py', # ルールが定義されたスクリプトファイルのパス
    rule_to_invoke='CustomGradientRule', # 評価するルールのクラス名
    volume_size_in_gb=30, # 評価処理実行用インスタンスのEBSボリュームサイズ
    collections_to_save=[CollectionConfig("gradients")], # 学習中に収集、保存するコレクション(保存したコレクションはルール評価時に使用できる)
    rule_parameters={
      "threshold": "20.0" # ルールのコンストラクタで定義されたパラメータ
    }
)

学習に関するパラメータと合わせて、先ほど作成したルール情報をTensorFlow用のEstimatorに設定します。

estimator = TensorFlow(
    role=sagemaker.get_execution_role(),
    base_job_name='smdebug-custom-rule-demo-tf-keras',
    train_instance_count=1,
    train_instance_type='ml.p2.xlarge',
    entry_point=entrypoint_script, # 学習に使用するスクリプト
    framework_version='1.15',
    py_version='py3',
    train_max_run=3600,
    script_mode=True,
    rules = [custom_rule] # 学習時に評価するルール
)

# 学習ジョブとルールの評価ジョブを実行
estimator.fit(wait=False)

ルールの評価ジョブのステータスを確認します。

import time
status = estimator.latest_training_job.rule_job_summary()
while status[0]['RuleEvaluationStatus'] == 'InProgress':
    status = estimator.latest_training_job.rule_job_summary()
    print(status)
    time.sleep(10)

問題は見つからなかったようです。

Estimatorの情報を用いて、評価ジョブのログがあるCloudWatch LogsのURLを作成します。

def _get_rule_job_name(training_job_name, rule_configuration_name, rule_job_arn):
        """Helper function to get the rule job name with correct casing"""
        return "{}-{}-{}".format(
            training_job_name[:26], rule_configuration_name[:26], rule_job_arn[-8:]
        )

def _get_cw_url_for_rule_job(rule_job_name, region):
    return "https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/ProcessingJobs;prefix={};streamFilter=typeLogStreamPrefix".format(region, region, rule_job_name)


def get_rule_jobs_cw_urls(estimator):
    training_job = estimator.latest_training_job
    training_job_name = training_job.describe()["TrainingJobName"]
    rule_eval_statuses = training_job.describe()["DebugRuleEvaluationStatuses"]

    result={}
    for status in rule_eval_statuses:
        if status.get("RuleEvaluationJobArn", None) is not None:
            rule_job_name = _get_rule_job_name(training_job_name, status["RuleConfigurationName"], status["RuleEvaluationJobArn"])
            result[status["RuleConfigurationName"]] = _get_cw_url_for_rule_job(rule_job_name, boto3.Session().region_name)
    return result

get_rule_jobs_cw_urls(estimator)

表示されたURLにアクセスし、評価ログを確認してみます。

学習が終了し、ステップデータが読み込まれ、評価処理が実行し、無事完了していることがわかります。

さいごに

『機械学習 on AWS Advent Calendar 2019』 22 日目、「Amazon SageMaker Debugger のカスタムルールを用いた TensorFlow モデルのデバッグ」についてお伝えしました。ルール定義のスクリプト自体がシンプルなので、評価ロジックさえ定まっていれば簡単に実装できそうです。事前に定義されたルールも数多くありますが、カスタムルールを用いることでより活用の幅が広がります。面白そうです。今後もAmazon SageMaker Debuggerの活用例を色々試して、紹介していきたいと思います。

参考