MXNetのスクリプトでハイパーパラメータチューニングをする:Amazon SageMaker Advent Calendar 2018
概要
こんにちは、データインテグレーション部のyoshimです。
この記事は「クラスメソッド Amazon SageMaker Advent Calendar」の24日目の記事となります。
今回は「MXNet」で記述したスクリプトで、「ハイパーパラメータチューニング」をする方法について紹介します。また、今回紹介する記述方法は「MXNet」のバージョンが「1.2.0」,「1.3.0」の2パターンでの紹介になります。
目次
1.最初に
今回は、MXNetで記述したスクリプトをSageMakerから実行し、その際にSageMakerSDKのハイパーパラメータチューニング機能を使って、ハイパーパラメータのチューニングをします。MXNetのスクリプトを殆ど変更することなくチューニングできるので、とてもお手軽です。
参考にしたスクリプトはこちらです。
また、SageMaker上からMXNetのスクリプトを実行する際は「MXNetのフレームワークのバージョンが1.3.0以降か否か」で記述方法が変わるので、今回はフレームワークのバージョンが「1.2.0」,「1.3.0」の2パターンでの書き方を見てみようと思います。
(指定していない場合は、デフォルトで「1.2.0」で実行されます)
各フレームワークバージョンごとのエントリーポイントファイルの記述方法については、下記が参考になります。
1.2以前の書き方
1.3.0以降の場合の書き方
SageMakerでMXNetを使う方法
2.MXNetのフレームワークバージョンが「1.2」以前での記述
スクリプトを見る前に、一回エントリーポイントファイルを使う際のスクリプトの記述方法の概要について述べておきます。
- エントリーポイントとなるPythonファイルを用意しておく。
- エントリーポイントとなるPythonファイルには「train」という名前の関数を用意しておく必要がある。
- エントリーポイントファイルが実行されると、この「train」関数が実行される。この関数にはモデルの学習を実行するための処理を記述する。
- エントリーポイントを呼び出すスクリプトから、エントリーポイントファイルに変数を渡すことができる
以上から、とりあえず「エントリーポイントファイル」と「エントリーポイントファイルを呼び出すスクリプト」の2つが重要ということがわかるかと思います。 今回は「ハイパーパラメータチューニングをする」ことが目的なので、「ハイパーパラメータをどのようにエントリーポイントファイルに渡しているか」、について着目したいと思います。
とりあえず、「チュートリアル」の「エントリーポイントファイル」と「エントリーポイントファイルを呼び出すスクリプト」を順番に見ていきましょう。
2-1.エントリーポイントファイル
重要っぽいところについて説明します。
詳細についてはチュートリアルをご参照ください。
エントリーポイントファイルを呼び出した際は、下記の「train」関数が実行されています。
def train(current_host, channel_input_dirs, hyperparameters, hosts, num_cpus, num_gpus):
「train」関数以外の関数も全て実行はされるのですが、「train」関数を実行するために記述されたものです。
今回注目したい点は、この「train」関数に渡す引数はどこで指定されているか、という点です。
「hyperparameters」が正にハイパーパラメータチューニングで指定している部分なのですが、エントリーポイントファイル内には特に記述していません。
エントリーポイントの呼び出し元のファイルが怪しそうなので、一回そちらを確認してみましょう。
2-2.ノートブックファイル
エントリーポイントファイルを呼び出しているのは、hpo_mxnet_mnist.ipynbです。
具体的には、下記の部分ですね。
estimator = MXNet(entry_point='mnist.py', role=role, train_instance_count=1, train_instance_type='ml.m4.xlarge', sagemaker_session=sagemaker.Session(), base_job_name='DEMO-hpo-mxnet', hyperparameters={'batch_size': 100})
エントリーポイントとして「mnist.py」を指定しています。
そのあとは、このestimatorの「train」関数に対してハイパーパラメータの自動チューニング機能を使うために、下記のように処理を記述しています。
# チューニングするハイパーパラメータのレンジを指定 hyperparameter_ranges = {'optimizer': CategoricalParameter(['sgd', 'Adam']), 'learning_rate': ContinuousParameter(0.01, 0.2), 'num_epoch': IntegerParameter(10, 50)} # 評価指標 objective_metric_name = 'Validation-accuracy' metric_definitions = [{'Name': 'Validation-accuracy', 'Regex': 'Validation-accuracy=([0-9\\.]+)'}] # tunerの作成 tuner = HyperparameterTuner(estimator, objective_metric_name, hyperparameter_ranges, metric_definitions, max_jobs=9, max_parallel_jobs=3) # 実行 tuner.fit({'train': train_data_location, 'test': test_data_location})
ここで、先程確認した「エントリーポイントファイル」の「train」関数を思い出してください。
def train(current_host, channel_input_dirs, hyperparameters, hosts, num_cpus, num_gpus):
「current_host」、「channel_input_dirs」等、やはり指定していない引数が利用されています。
これらの引数はどこで指定しているのでしょうか?
2-3.エントリーポイントファイル実行時に渡される変数
一体どういうことなのかと悩んだのですが、「For versions 1.2 and lower」に答えが書いてありました。
どうやら、エントリーポイントのファイルを実行したタイミングで、いくつかの変数がエントリーポイントファイル内で利用できるようです。
ここまでのことを整理すると、「hyperparameter_ranges」にて指定したハイパーパラメータが「hyperparameters」変数としてエントリーファイル内に引き渡されて利用された、ということがわかりました。
(下記の「hyperparameters」についての説明をご参照下さい)
When you run your script on SageMaker via the MXNet Estimator, SageMaker injects information about the training environment into your training function via Python keyword arguments. You can choose to take advantage of these by including them as keyword arguments in your train function. The full list of arguments is:
hyperparameters (dict[string,string]): The hyperparameters passed to SageMaker TrainingJob that runs your MXNet training script. You can use this to pass hyperparameters to your training script.
input_data_config (dict[string,dict]): The SageMaker TrainingJob InputDataConfig object, that's set when the SageMaker TrainingJob is created. This is discussed in more detail below.
channel_input_dirs (dict[string,string]): A collection of directories containing training data. When you run training, you can partition your training data into different logical "channels". Depending on your problem, some common channel ideas are: "train", "test", "evaluation" or "images',"labels".
output_data_dir (str): A directory where your training script can write data that will be moved to s3 after training is complete.
num_gpus (int): The number of GPU devices available on your training instance.
num_cpus (int): The number of CPU devices available on your training instance.
hosts (list[str]): The list of host names running in the SageMaker Training Job cluster.
current_host (str): The name of the host executing the script. When you use SageMaker for MXNet training, the script is run on each host in the cluster.
3.MXNetのフレームワークバージョンが「1.3.0」以降での記述
さて、ではMXNetのフレームワークのバージョンが「1.3.0」以降になると何が変わるのでしょうか?
大きなところとしては、下記の2点を修正する必要があります。
それぞれ、順に見ていきましょう。
3-1.(if name=='main':)配下に色々と処理を記述する。
「1.3.0」以降では、「if name=='main':」配下に下記の3つの処理を実装する必要があります。
First, add a main guard (if name == 'main':). The code executed from your main guard needs to:
1.Set hyperparameters and directory locations
2.Initiate training
3.Save the model
引用:updating-your-mxnet-training-script
つまり、各種パラメータの引き渡し、学習、パラメータの保存、を記述する必要があるということですね。
学習処理については既存と変わりませんが、注意するのは「各種パラメータの引き渡し」と「パラメータの保存」の2点です。各パラメータの引き渡しについては後ほど説明しますが、argument parserを使うことが推奨されています。また、「パラメータの保存」については、下記のような処理を追加するだけでOKなパターンもありますが、場合によってはカスタムで処理を記述する必要があります。
パラメータを保存する関数を作成する際は「model」、「model_dir」の2つの引数を渡す必要がある点にご注意ください。
from sagemaker_mxnet_container.training_utils import save if __name__ == '__main__': # arg parsing (shown above) goes here model = train(args.batch_size, args.epochs, args.learning_rate, args.train, args.test) save(args.model_dir, model)
引用:updating-your-mxnet-training-script
「model」、「model_dir」の2つの引数を渡す必要がある。
To provide your own save function, define a save function in your training script. The function should take two arguments:
model: This is the object that was returned from your train function. If your train function does not return an object, it will be None. You are free to return an object of any type from train, you do not have to return Module or Gluon API specific objects.
model_dir: This is the string path on the SageMaker training host where you save your model. Files created in this directory will be accessible in S3 after your SageMaker Training Job completes.
引用:writing-a-custom-save-function
3-2.各パラメータの引き渡し
今回参考にしたチュートリアルでは、MXNetのフレームワークバージョンが「1.2.0」だったために、エントリーポイントファイルに、「hyperparameters」引数に関する記述をしなくてもハイパーパラメータを受け取ることができていました。
しかし、「1.3.0」以降では、引数を受け取る処理をエントリーポイントファイル内に明示的に記述する必要があります。
Hyperparameters are passed to your script as arguments and can be retrieved with an argparse.ArgumentParser instance. For example, a training script might start with the following:
引用:For versions 1.3 and higher
この記述方法については、argument parserを使うことが推奨されており、またupdating-your-mxnet-training-scriptに具体的な記述方法が紹介されています。
import argparse import os if __name__ == '__main__': parser = argparse.ArgumentParser() # hyperparameters sent by the client are passed as command-line arguments to the script. parser.add_argument('--epochs', type=int, default=10) parser.add_argument('--batch-size', type=int, default=100) parser.add_argument('--learning-rate', type=float, default=0.1) # input data and model directories parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR']) parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN']) parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST']) args, _ = parser.parse_known_args()
このようにしてハイパーパラメータ等の変数を渡すことができます。
ここまでを「ハイパーパラメータをどのようにエントリーポイントファイルに渡しているか」といった観点から整理すると、エントリーポイントファイル実行時に「argparse」で明示的にハイパーパラメータを渡しているといったことがわかるかと思います。
最後に細かいのですが、エントリーポイントファイル実行時に参照できる「環境変数」についてはlist-of-provided-environment-variables-by-sagemaker-containersをご参照ください。
自分でMXNetでエントリーポイントファイルを記述する方には必須の知識となります。
4.「1.3.0」以降でのスクリプト例
ここまでの内容を反映して、チュートリアルと同じような処理を「MXNet 1.3.0」以降のバージョン用に書き直したスクリプトを紹介します。
スクリプト自体はエントリーポイントファイル,呼び出しファイルの通りですが、修正前との差分についてざっくり説明しようと思います。
4-1.エントリーポイントファイル
変数の受け取りや、パラメータの保存用にモジュールを追加しています。
import argparse from sagemaker_mxnet_container.training_utils import save import json
受け取ったハイパーパラメータも、以前は「hyperparameters」という辞書型の変数に受け取っていましたが、それぞれの引数として受け取っているので「train」関数に修正が入っています。
def train(current_host, train_dir, test_dir, hosts, num_cpus, num_gpus, optimizer, num_epoch, learning_rate, batch_size):
そして、最後に「if name =='main':」配下に「変数の受け取り」、「学習の実行」、「パラメータファイルの出力」を実行しています。
受け取るハイパーパラメータもそれぞれ別々に記述して受け取っています。
if __name__ =='__main__': parser = argparse.ArgumentParser() # hyperparameters sent by the client are passed as command-line arguments to the script. parser.add_argument('--optimizer', type=str) parser.add_argument('--num_epoch', type=int, default=100) parser.add_argument('--learning_rate', type=float, default=0.1) # input data and model directories parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN']) parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST']) parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR']) parser.add_argument('--batch_size', type=int) # system_param parser.add_argument('--num_gpus', type=int, default=os.environ['SM_NUM_GPUS']) parser.add_argument('--num_cpus', type=int, default=os.environ['SM_NUM_CPUS']) parser.add_argument('--current_host', type=str, default=os.environ['SM_CURRENT_HOST']) parser.add_argument('--hosts', type=str, default=json.loads(os.environ['SM_HOSTS'])) args, _ = parser.parse_known_args() model = train(args.current_host, args.train, args.test, args.hosts, args.num_cpus, args.num_gpus, args.optimizer, args.num_epoch,args.learning_rate,args.batch_size) save(args.model_dir, model)
4-2.呼び出しファイル
本筋とはずれてしまいますが、学習結果やソースコードを保存するS3パスを指定しています。 それ以外は特に変更ありません。
bucket = 's3://my-bucket' prefix = 'mxnet_hpo' model_artifacts_location = '{0}/{1}/model'.format(bucket,prefix) custom_code_upload_location = '{0}/{1}/code'.format(bucket,prefix) estimator = MXNet(# entry_point='mnist.py', entry_point='mnist_130.py', # after 「MXNet 1.3.0」 role=role, train_instance_count=1, train_instance_type='ml.m4.xlarge', sagemaker_session=sagemaker.Session(), output_path=model_artifacts_location, # add!!! code_location=custom_code_upload_location, # add! base_job_name='DEMO-hpo-mxnet', framework_version='1.3.0', hyperparameters={'batch_size': 100} )
5.まとめ
「MXNet 1.3.0」以降のフレームワークバージョンでのハイパーパラメータチューニング処理を実行するための記述内容について整理できました。
学習結果を推論用エンドポイントとして利用するためには、他にも「モデルをロードする関数」や「推論処理の関数」等を記述する必要があるのですが、とりあえず今回はここまでにしておこうかと思います。
MXNetを利用しようとしている方の参考になれば幸いです。