Amazon SageMaker Debuggerのビルトインルールについて調べてみた

『機械学習 on AWS Advent Calendar 2019』の2日目

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

機械学習モデルを学習させる際には過学習や勾配消失等といった多種多様な問題の発生は避けられません。 Amazon SageMaker Debugger を使うことで、TensorFlow や MXNet、PyTorch、XGBoost を使ったモデル学習時の異常や問題を検出することができます。Amazon SageMaker Debuggerはルールにマッチしたかどうかによって評価し、学習時の異常や問題を検出します。

各ルールはトライアル(学習)の各種テンソルデータを参照し、ルールにマッチした場合にTrueを返します。ルールがTrueを返した場合には、ルールの評価処理によって例外が投げられます。CloudWatch EventsのSageMaker Training Job State Changeイベントを用いることで評価処理の結果をトリガーとして、通知などニアリアルタイムでの処理を実行することも可能です。

今回は Amazon SageMaker Debugger で事前に用意されているビルトインルール(Built-in Rules)にどういったものがあるか見てみます。

深層学習フレームワーク用

対応フレームワーク: TensorFlow/Apache MXNet/PyTorch

DeadRelu

ReLUが死んでる(dead)かどうかを検出します。 レイヤーの中で死んでいるReLUの割合が閾値以上の場合にTrueを返します。

パラメータ

tensor_regex, threshold_inactivity, threshold_layer

ExplodingTensor

発散を検出します。対象のテンソルにNaNやinfを検出した場合にTrueを返します。

パラメータ

collection_names, tensor_regex, only_nan

PoorWeightInitialization

学習の最初の数ステップの間の挙動から各重みの初期化が不十分かどうかを検出します。

  • レイヤーごとの重みの分散の最小値と最大値の比が閾値を越えるとTrueを返します
  • レイヤーごとの勾配分布の5パーセンタイルと95パーセンタイルの差の最小値が閾値以下ならTrueを返します
  • ロスが指定したステップの間減らなければTrueを返します

パラメータ

activation_inputs_regex, threshold, distribution_range, patience, steps

SaturatedActivation

アークタンジェントもしくはシグモイドを用いたアクティベーションレイヤーが飽和しているかどうかを検出します。 飽和しているアクティベーションノードの割合が閾値を超えている場合にTrueを返します

パラメータ

collection_names, tensor_regex, threshold_tanh, threshold_sigmoid, threshold_inactivity, threshold_layer

VanishingGradient

勾配の消失を検出します。勾配の絶対値の平均が閾値より低くなった場合にTrueを返します。

パラメータ

threshold

WeightUpdateRatio

重み更新時の変化率の異常を検出します。重み更新時の変化率が閾値として設定した最大値もしくは最小値を超えた場合にTrueを返します。

パラメータ

num_steps, large_threshold, small_threshold, epsilon

深層学習フレームワークとXGBoost用

対応フレームワーク: TensorFlow/Apache MXNet/PyTorch/XGBoost

AllZero

テンソルの値が0になっている割合が高過ぎないかを検出します。 テンソルの値の0の割合が閾値を超えた場合にTrueを返します。

パラメータ

collection_names, tensor_regex, threshold

ClassImbalance

分類モデルにおいて、クラスの偏りを検出します。

  • クラスごとのサンプル数が最大のものと最小のものの比が閾値を超えた場合にTrueを返します。
  • 各クラスで誤って推論したデータの割合が閾値を超えた場合にTrueを返します。
    • 例えば、Aというクラスでの誤推論の割合が閾値以下で、Bというクラスにおける誤推論の割合が閾値を超えた場合はTrueが返されます。

パラメータ

threshold_imbalance, threshold_misprediction, samples, argmax, labels_regex, predictions_regex

Confusion

分類モデルの混同行列の良さを評価し、問題を検出します。

  • 混同行列の対角要素/対角要素の合計が閾値より小さい場合にTrueを返します。
  • 混同行列の非対角要素/列要素の合計が閾値より大きい場合にTrueを返します。

パラメータ

category_no, labels, predictions, labels_collection, predictions_collection, min_diag, max_off_diag

LossNotDecreasing

ロスが減少しなくなったことを検出します。 指定したステップ間でのロスの減少割合が閾値より小さい場合にTrueを返します。

パラメータ

collection_names, tensor_regex, use_losses_collection, num_steps, diff_percent, mode

Overfit

学習データ(training)と検証データ(validation)のロスを比較することで、モデルの過剰適合(overfitting)を検出します。 学習データのロスの平均に対する、検証データのロスの平均学習データのロスの平均との差の割合が閾値を一定ステップ間超え続けた場合にTrueを返します。

パラメータ

tensor_regex, start_step, patience, ratio_threshold,

Overtraining

学習データ(training)や検証データ(validation)のロスの減少具合からモデルの過学習(overtraining)を検出します。

パラメータ

patience_train, patience_validation, delta

SimilarAcrossRuns

対象のトライアルと他のトライアルが似ているかどうかを検出します。

パラメータ

other_trial, collection_names, tensor_regex

TensorVariance

特定のテンソルの分散が高すぎたり、低すぎないかを検出します。

パラメータ

collection_names, tensor_regex, max_threshold, min_threshold

UnchangedTensor

テンソルがステップごとに変化がないことを検出します。テンソル同士の比較にはnumpy.allcloseが使われます。

パラメータ

collection_names, tensor_regex, num_steps, rtol, atol, equal_nan

深層学習アプリケーション用

CheckInputImages

サンプリングした入力画像の平均が0から閾値以上に離れているかどうかで、入力画像が正しく正規化されているかを検証します。

パラメータ

threshold_mean, threshold_samples, regex, channel

NLPSequenceRatio

自然言語処理において、特定のトークン(EOSやunknownなど)が入力トークンの中で占める割合が多すぎないかを検証します。

パラメータ

tensor_regex, token_values, token_thresholds_percent

XGBoost用

TreeDepth

学習によって作成された木の深さを測定します。

※ 具体的な記述がドキュメントにないので詳細は不明ですが、恐らく学習中の木の深さが閾値より浅いかどうかを検出するルールだと思います。

パラメータ

depth

さいごに

Amazon SageMaker Debugger の各種ビルトインルールの概要について紹介しました。用途に応じて適切なルールとパラメータを選択することで、モデルを学習する際に起こりうる様々な異常や問題を検出し、モデルの開発を効率化させることができそうです。