この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。
最初に
こんにちはデータアナリティクス事業本部のyoshimです。
re:Invent2019にて「Amazon SageMaker Debugger」のリリースが発表されたので、早速試してみました。
「Amazon SageMaker Debugger」はその名前の通り「自分で書いたモデルのデバッグ」の際に利用します。
自分でモデルを書いていると、色々と意図しないバグが残ってしまいがちですが、開発の最初から「Amazon SageMaker Debugger」を使っていると、バグを減らして、開発時間も短縮できるかもしれません。
(一部記述内容に誤りがあったため、2019年12月6日に修正しました)
目次
1.Amazon SageMaker Debuggerとは
最初に、「Amazon SageMaker Debugger」とはなんのか、についてですがざっくり言うと「トレーニング中のロスや精度等のスカラー値をモニタリングすることで、モデルの開発や評価をサポートするサービス」と言えそうです。
具体的には、「指定したルール」で「トレーニング中のロスや精度等のスカラー値をモニタリング」し、「その結果からアラート(CloudWatch Event)をあげる」、「後から結果の詳細を確認する」といったことが可能です。
また、一般的なフレームワーク(TensorFlow, PyTorch, Apache MXNet, XGBoost)がサポートされており、既存のスクリプトの修正は不要です。
(「TensorFlow Horovod」を使って分散学習をさせたものについては、2019年12月5日時点ではサポートされていないようです。また、python2系はサポート対象外です)
より詳細についてはドキュメントをご参照いただきたいのですが、ざっくり機能の概要としては以上となります。
2.やってみた
ここからは実際に手を動かして理解を深めようと思います。
こちらの内容を実際に手を動かして確認してみました。
以下、スクリプトの内容を処理の段階ごとに区切って確認していきます。
2-1.用意されているスクリプト、デバッグのルールを使ってモデルの学習をする
サンプルとして用意されていたスクリプトやデータを使ってモデルの学習をします。
また、デバッグのルールについては「SageMaker側で事前に用意されているもの(15種類)」と「自前で定義するカスタムルール」の2パターンを利用できるのですが、今回はSageMaker側で事前に用意されているルールを利用しました。
デバッグをするために「entry_point」に指定するスクリプトを修正する必要はなく、下記の通りルールを指定してあげる数行を追加すればOKです。
from sagemaker.debugger import Rule, rule_configs
from sagemaker.tensorflow import TensorFlow
import sagemaker
# define the entrypoint script
entrypoint_script='src/mnist_zerocodechange.py'
# hyper parameter
# sampleのコードよりも増やしてます
hyperparameters = {
"num_epochs": 12
}
# デバッグに適用するルールをList形式で指定
# SageMaker側で事前に用意されているルール、自前で用意するルール、の2パターンを利用可能
# 今回はSageMaker側で事前に用意されているルールを利用する
rules = [
Rule.sagemaker(rule_configs.vanishing_gradient()),
Rule.sagemaker(rule_configs.loss_not_decreasing())
]
## estimatorに「rules」変数として指定する
estimator = TensorFlow(
role=sagemaker.get_execution_role(),
base_job_name='smdebugger-demo-mnist-tensorflow',
train_instance_count=1,
train_instance_type='ml.m5.12xlarge',
train_volume_size=400,
entry_point=entrypoint_script,
framework_version='1.15',
py_version='py3',
train_max_run=3600,
script_mode=True,
hyperparameters=hyperparameters,
## New parameter
rules = rules
)
estimator.fit(wait=True)
さて、これで「勾配消失」、「ロスが減らない」といった事象をチェックするルールを指定した状態で学習を実行できました。
(自前で定義したルールを利用する場合は、PythonのスクリプトをS3に格納してパスを指定する必要があります)
以上の形で学習を開始したところ、早速下記のようなログが出力されました。
これから検証する段階なのですが、「InProgress」という表記でした。
学習の終了直後、下記のようなログが出力されました。
ロスを確認したところ、確かにロスが下がらない場合がありました。
また、勾配消失はまだ発生していないので「InProgress」のままです。
2-2.ルールの検証結果を確認
学習が終わったので、今回指定したルールの検証結果を確認しましょう。
勾配消失は発生していないので「NoIssuesFound」、ロスが下がらない現象は発生していたので「IssuesFound」となっています。
estimator.latest_training_job.rule_job_summary()
[{'RuleConfigurationName': 'VanishingGradient',
'RuleEvaluationJobArn': 'arn:aws:sagemaker:us-east-1:123456789012:processing-job/smdebugger-demo-mnist-tens-vanishinggradient-f14fd901',
'RuleEvaluationStatus': 'NoIssuesFound',
'LastModifiedTime': datetime.datetime(2019, 12, 5, 22, 49, 12, 899000, tzinfo=tzlocal())},
{'RuleConfigurationName': 'LossNotDecreasing',
'RuleEvaluationJobArn': 'arn:aws:sagemaker:us-east-1:123456789012:processing-job/smdebugger-demo-mnist-tens-lossnotdecreasing-33cbb36f',
'RuleEvaluationStatus': 'IssuesFound',
'StatusDetails': 'RuleEvaluationConditionMet: Evaluation of the rule LossNotDecreasing at step 4500 resulted in the condition being met\n',
'LastModifiedTime': datetime.datetime(2019, 12, 5, 22, 49, 12, 899000, tzinfo=tzlocal())}]
また、ログがCloudWatchLogsに出力されているので、その参照先URLを出力して、確認できます。
何か問題があった場合はこのリンク先から詳細を確認できます。
def _get_rule_job_name(training_job_name, rule_configuration_name, rule_job_arn):
"""Helper function to get the rule job name with correct casing"""
return "{}-{}-{}".format(
training_job_name[:26], rule_configuration_name[:26], rule_job_arn[-8:]
)
def _get_cw_url_for_rule_job(rule_job_name, region):
return "https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/ProcessingJobs;prefix={};streamFilter=typeLogStreamPrefix".format(region, region, rule_job_name)
def get_rule_jobs_cw_urls(estimator):
region = boto3.Session().region_name
training_job = estimator.latest_training_job
training_job_name = training_job.describe()["TrainingJobName"]
rule_eval_statuses = training_job.describe()["DebugRuleEvaluationStatuses"]
result={}
for status in rule_eval_statuses:
if status.get("RuleEvaluationJobArn", None) is not None:
rule_job_name = _get_rule_job_name(training_job_name, status["RuleConfigurationName"], status["RuleEvaluationJobArn"])
result[status["RuleConfigurationName"]] = _get_cw_url_for_rule_job(rule_job_name, region)
return result
get_rule_jobs_cw_urls(estimator)
{'VanishingGradient': 'https://us-east-1.console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/aws/sagemaker/ProcessingJobs;prefix=smdebugger-demo-mnist-tens-VanishingGradient-f14fd901;streamFilter=typeLogStreamPrefix',
'LossNotDecreasing': 'https://us-east-1.console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/aws/sagemaker/ProcessingJobs;prefix=smdebugger-demo-mnist-tens-LossNotDecreasing-33cbb36f;streamFilter=typeLogStreamPrefix'}
2-3.探索的に確認する
また、より探索的に結果を確認することもできます。
from smdebug.trials import create_trial
trial = create_trial(estimator.latest_job_debugger_artifacts_path())
# 保持している全てのテンソルをリストとして表示する
# どれを可視化するか、というのを選ぶ際にまずこれで確認
print(trial.tensor_names())
%matplotlib inline
import matplotlib.pyplot as plt
import re
# 今回はロスを可視化する
trial.tensor_names(collection="losses")
# Define a function that, for the given tensor name, walks through all
# the iterations for which we have data and fetches the value.
# Returns the set of steps and the values
def get_data(trial, tname):
tensor = trial.tensor(tname)
steps = tensor.steps()
vals = [tensor.value(s) for s in steps]
return steps, vals
def plot_tensors(trial, collection_name, ylabel=''):
"""
Takes a `trial` and plots all tensors that match the given regex.
"""
plt.figure(
num=1, figsize=(8, 8), dpi=80,
facecolor='w', edgecolor='k')
tensors = trial.tensor_names(collection=collection_name)
for tensor_name in sorted(tensors):
steps, data = get_data(trial, tensor_name)
plt.plot(steps, data, label=tensor_name)
plt.legend(bbox_to_anchor=(1.04,1), loc='upper left')
plt.xlabel('Iteration')
plt.ylabel(ylabel)
plt.show()
plot_tensors(trial, "losses", ylabel="Loss")
ロスが可視化できました。
確かに、ロスが下がっていないところがあるのが確認できます。
3.まとめ
自分でスクリプトを組んでいると、どうしても「今までの経験からの勘」とかで対応してしまいがちですが、こういった部分をツールがサポートしてくれるのは嬉しいですね。
スクリプトの変更も不要で、「Amazon SageMaker Debugger」が自前で用意しているルールを使うだけなら本当に簡単に適用できるので、開発の初期段階から積極的に使っていきたいですね。
また、今回はノートブックインスタンス上からノートブックファイルを実行したのですが、「Amazon SageMaker Studio」上から実行した時の見え方も今後確認してみたいと思います。