Amazon SageMaker Debuggerを試してみた #reinvent

2019.12.05

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

最初に

こんにちはデータアナリティクス事業本部のyoshimです。
re:Invent2019にて「Amazon SageMaker Debugger」のリリースが発表されたので、早速試してみました。

「Amazon SageMaker Debugger」はその名前の通り「自分で書いたモデルのデバッグ」の際に利用します。
自分でモデルを書いていると、色々と意図しないバグが残ってしまいがちですが、開発の最初から「Amazon SageMaker Debugger」を使っていると、バグを減らして、開発時間も短縮できるかもしれません。

(一部記述内容に誤りがあったため、2019年12月6日に修正しました)

目次

1.Amazon SageMaker Debuggerとは

最初に、「Amazon SageMaker Debugger」とはなんのか、についてですがざっくり言うと「トレーニング中のロスや精度等のスカラー値をモニタリングすることで、モデルの開発や評価をサポートするサービス」と言えそうです。

参照:debugger-how-it-works

具体的には、「指定したルール」で「トレーニング中のロスや精度等のスカラー値をモニタリング」し、「その結果からアラート(CloudWatch Event)をあげる」、「後から結果の詳細を確認する」といったことが可能です。
また、一般的なフレームワーク(TensorFlow, PyTorch, Apache MXNet, XGBoost)がサポートされており、既存のスクリプトの修正は不要です。
(「TensorFlow Horovod」を使って分散学習をさせたものについては、2019年12月5日時点ではサポートされていないようです。また、python2系はサポート対象外です)

より詳細についてはドキュメントをご参照いただきたいのですが、ざっくり機能の概要としては以上となります。

2.やってみた

ここからは実際に手を動かして理解を深めようと思います。
こちらの内容を実際に手を動かして確認してみました。

以下、スクリプトの内容を処理の段階ごとに区切って確認していきます。

2-1.用意されているスクリプト、デバッグのルールを使ってモデルの学習をする

サンプルとして用意されていたスクリプトやデータを使ってモデルの学習をします。
また、デバッグのルールについては「SageMaker側で事前に用意されているもの(15種類)」と「自前で定義するカスタムルール」の2パターンを利用できるのですが、今回はSageMaker側で事前に用意されているルールを利用しました。

デバッグをするために「entry_point」に指定するスクリプトを修正する必要はなく、下記の通りルールを指定してあげる数行を追加すればOKです。

from sagemaker.debugger import Rule, rule_configs
from sagemaker.tensorflow import TensorFlow
import sagemaker


# define the entrypoint script
entrypoint_script='src/mnist_zerocodechange.py'

# hyper parameter
# sampleのコードよりも増やしてます
hyperparameters = {
    "num_epochs": 12
}

# デバッグに適用するルールをList形式で指定
# SageMaker側で事前に用意されているルール、自前で用意するルール、の2パターンを利用可能
# 今回はSageMaker側で事前に用意されているルールを利用する
rules = [
    Rule.sagemaker(rule_configs.vanishing_gradient()), 
    Rule.sagemaker(rule_configs.loss_not_decreasing())
]

## estimatorに「rules」変数として指定する
estimator = TensorFlow(
    role=sagemaker.get_execution_role(),
    base_job_name='smdebugger-demo-mnist-tensorflow',
    train_instance_count=1,
    train_instance_type='ml.m5.12xlarge',
    train_volume_size=400,
    entry_point=entrypoint_script,
    framework_version='1.15',
    py_version='py3',
    train_max_run=3600,
    script_mode=True,
    hyperparameters=hyperparameters,
    ## New parameter
    rules = rules
)

estimator.fit(wait=True)

さて、これで「勾配消失」、「ロスが減らない」といった事象をチェックするルールを指定した状態で学習を実行できました。
(自前で定義したルールを利用する場合は、PythonのスクリプトをS3に格納してパスを指定する必要があります)

以上の形で学習を開始したところ、早速下記のようなログが出力されました。
これから検証する段階なのですが、「InProgress」という表記でした。

学習の終了直後、下記のようなログが出力されました。

ロスを確認したところ、確かにロスが下がらない場合がありました。
また、勾配消失はまだ発生していないので「InProgress」のままです。

2-2.ルールの検証結果を確認

学習が終わったので、今回指定したルールの検証結果を確認しましょう。
勾配消失は発生していないので「NoIssuesFound」、ロスが下がらない現象は発生していたので「IssuesFound」となっています。

estimator.latest_training_job.rule_job_summary()
[{'RuleConfigurationName': 'VanishingGradient',
  'RuleEvaluationJobArn': 'arn:aws:sagemaker:us-east-1:123456789012:processing-job/smdebugger-demo-mnist-tens-vanishinggradient-f14fd901',
  'RuleEvaluationStatus': 'NoIssuesFound',
  'LastModifiedTime': datetime.datetime(2019, 12, 5, 22, 49, 12, 899000, tzinfo=tzlocal())},
 {'RuleConfigurationName': 'LossNotDecreasing',
  'RuleEvaluationJobArn': 'arn:aws:sagemaker:us-east-1:123456789012:processing-job/smdebugger-demo-mnist-tens-lossnotdecreasing-33cbb36f',
  'RuleEvaluationStatus': 'IssuesFound',
  'StatusDetails': 'RuleEvaluationConditionMet: Evaluation of the rule LossNotDecreasing at step 4500 resulted in the condition being met\n',
  'LastModifiedTime': datetime.datetime(2019, 12, 5, 22, 49, 12, 899000, tzinfo=tzlocal())}]

また、ログがCloudWatchLogsに出力されているので、その参照先URLを出力して、確認できます。
何か問題があった場合はこのリンク先から詳細を確認できます。

def _get_rule_job_name(training_job_name, rule_configuration_name, rule_job_arn):
        """Helper function to get the rule job name with correct casing"""
        return "{}-{}-{}".format(
            training_job_name[:26], rule_configuration_name[:26], rule_job_arn[-8:]
        )
    
def _get_cw_url_for_rule_job(rule_job_name, region):
    return "https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/ProcessingJobs;prefix={};streamFilter=typeLogStreamPrefix".format(region, region, rule_job_name)


def get_rule_jobs_cw_urls(estimator):
    region = boto3.Session().region_name
    training_job = estimator.latest_training_job
    training_job_name = training_job.describe()["TrainingJobName"]
    rule_eval_statuses = training_job.describe()["DebugRuleEvaluationStatuses"]
    
    result={}
    for status in rule_eval_statuses:
        if status.get("RuleEvaluationJobArn", None) is not None:
            rule_job_name = _get_rule_job_name(training_job_name, status["RuleConfigurationName"], status["RuleEvaluationJobArn"])
            result[status["RuleConfigurationName"]] = _get_cw_url_for_rule_job(rule_job_name, region)
    return result

get_rule_jobs_cw_urls(estimator)
{'VanishingGradient': 'https://us-east-1.console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/aws/sagemaker/ProcessingJobs;prefix=smdebugger-demo-mnist-tens-VanishingGradient-f14fd901;streamFilter=typeLogStreamPrefix',
 'LossNotDecreasing': 'https://us-east-1.console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/aws/sagemaker/ProcessingJobs;prefix=smdebugger-demo-mnist-tens-LossNotDecreasing-33cbb36f;streamFilter=typeLogStreamPrefix'}

2-3.探索的に確認する

また、より探索的に結果を確認することもできます。

from smdebug.trials import create_trial
trial = create_trial(estimator.latest_job_debugger_artifacts_path())

# 保持している全てのテンソルをリストとして表示する
# どれを可視化するか、というのを選ぶ際にまずこれで確認
print(trial.tensor_names())
%matplotlib inline

import matplotlib.pyplot as plt
import re

# 今回はロスを可視化する
trial.tensor_names(collection="losses")

# Define a function that, for the given tensor name, walks through all 
# the iterations for which we have data and fetches the value.
# Returns the set of steps and the values
def get_data(trial, tname):
    tensor = trial.tensor(tname)
    steps = tensor.steps()
    vals = [tensor.value(s) for s in steps]
    return steps, vals

def plot_tensors(trial, collection_name, ylabel=''):
    """
    Takes a `trial` and plots all tensors that match the given regex.
    """
    plt.figure(
        num=1, figsize=(8, 8), dpi=80,
        facecolor='w', edgecolor='k')

    tensors = trial.tensor_names(collection=collection_name)

    for tensor_name in sorted(tensors):
        steps, data = get_data(trial, tensor_name)
        plt.plot(steps, data, label=tensor_name)

    plt.legend(bbox_to_anchor=(1.04,1), loc='upper left')
    plt.xlabel('Iteration')
    plt.ylabel(ylabel)
    plt.show()
    
plot_tensors(trial, "losses", ylabel="Loss")

ロスが可視化できました。
確かに、ロスが下がっていないところがあるのが確認できます。

3.まとめ

自分でスクリプトを組んでいると、どうしても「今までの経験からの勘」とかで対応してしまいがちですが、こういった部分をツールがサポートしてくれるのは嬉しいですね。
スクリプトの変更も不要で、「Amazon SageMaker Debugger」が自前で用意しているルールを使うだけなら本当に簡単に適用できるので、開発の初期段階から積極的に使っていきたいですね。

また、今回はノートブックインスタンス上からノートブックファイルを実行したのですが、「Amazon SageMaker Studio」上から実行した時の見え方も今後確認してみたいと思います。

4.参考