Amazon SageMakerで特定のGitリポジトリに含まれるTensorFlowモデルを学習させる

どうも、DA事業本部の大澤です。

SageMaker Python SDKのスクリプトを見ている時にFramework用Estimatorにgit_configなる引数があることを見つけました。 TensorFlowやMXNet等といったSageMakerがデフォルトで対応しているフレームワークに限られますが、git_configにリポジトリ情報を指定することで、そのリポジトリのスクリプトをエントリポイントとして指定し、学習に用いることができます。

SageMaker ExamplesにあるTensorFlow用のサンプルノートブックでgit_configが利用されていたので、試してみました。今回はその内容をお伝えします。

やってみる

amazon-sagemaker-examples/char-rnn-tensorflowでSherlock Holmesのテキストデータを学習させます。

※ サンプルノートブックに記載されている内容のうち、ローカルモードの部分は飛ばして、SageMakerで学習させる部分のみ紹介します

データ準備

まず学習に使用するSherlock Holmesのテキストデータをダウンロードします。

import os
data_dir = os.path.join(os.getcwd(), 'sherlock')

os.makedirs(data_dir, exist_ok=True)
!wget https://sherlock-holm.es/stories/plain-text/cnus.txt --force-directories --output-document=sherlock/input.txt

git_config

続いて学習に使用するスクリプトが含まれるリポジトリをクローンするためのgit_configを設定します。 git_configを設定することで、学習開始時に自動的に設定に従ってリポジトリをクローンし、source_dirに指定したディレクトリがS3にアップロードされ、学習に利用されます。 git_configには次のように辞書形式で設定を格納します。

git_config = {'repo': 'https://github.com/awslabs/amazon-sagemaker-examples.git', 'branch': 'training-scripts'}

git_configでは次のようなパラメータを利用できます。

名前 必須 デフォルト 説明 備考
repo o   リポジトリのURI HTTPS/SSHどちらでもOK
branch   master ブランチ名  
commit     コミットのハッシュ 未指定の場合は対象ブランチの最新コミットがクローンされる
2FA_enabled   false 2要素認証が有効かどうか HTTPSの場合のみ有効
username     ユーザ名 HTTPSの場合のみ有効
password     パスワード HTTPSの場合のみ有効
token     アクセストークン HTTPSの場合のみ有効

git_configで設定したリポジトリのクローンはローカル環境で実行されます。従って、SSH接続時の設定などはローカル環境に設定されているものが利用されます。

詳細な解説や使用例についてはドキュメントをご覧ください。

ハイパーパラメータ

学習に使用するはハイパーパラメータです。

hyperparameters = {'num_epochs': 1, 'data_dir': '/opt/ml/input/data/training'}

ハイパーパラメータとして設定した内容は学習実行時に次のようにスクリプトに渡されます。model_dirは設定の必要がなく、自動的に付与されます。

python train.py --num-epochs 1 --data_dir /opt/ml/input/data/training --model_dir /opt/ml/model

データアップロード

先ほどダウンロードしたデータをS3にアップロードします。inputsにはアップロード先のS3 URIが格納されます。

sagemaker.Session.upload_dataでバケット名を指定しない場合、sagemaker-{region}-{accountid}という名前のバケットが自動作成されて利用されます。

import sagemaker

inputs = sagemaker.Session().upload_data(path='sherlock', key_prefix='datasets/sherlock')

学習

TensorFlow用のEstimatorを用いて、学習の設定を行います。 先ほど設定したgit_confighyperparametersに加えて、使用するインスタンスや使用するIAMロールなども設定します。必要に応じて設定内容は変更してください。 今回はgit_configを設定するので、source_dirは相対パスで設定する必要があります。次のような設定内容であればgit_configに設定したリポジトリのディレクトリchar-rnn-tensorflowが圧縮されてS3にアップロードされます。

引数や詳細な使い方についてはドキュメントをご覧ください。

estimator = TensorFlow(entry_point='train.py',
                       source_dir='char-rnn-tensorflow',
                       git_config=git_config,
                       train_instance_type='ml.c4.xlarge', # Executes training in a ml.c4.xlarge instance
                       train_instance_count=1,
                       hyperparameters=hyperparameters,
                       role=sagemaker.get_execution_role(),  # 必要に応じて使用するIAMロールのARNを記載する
                       framework_version='1.14',
                       py_version='py3',
                       script_mode=True)
             

estimator.fit({'training': inputs})

次のように学習が開始されます。(リポジトリのクローン時のログは表示されません。)

...

S3にアップロードされた圧縮ファイルのパスは次のようにハイパーパラメータとして自動的に設定されているため、後から確認可能です。

estimator.hyperparameters()['sagemaker_submit_directory']

さいごに

SageMaker Python SDKのFramework用Estimatorのgit_configを使ってみた様子を紹介しました。今回はTensorFlowでの使い方を紹介しましたが、MXNetやPyTorchなどSageMakerがデフォルトで対応している他のフレームワークでも同様に使用することができます。オープンソースのリポジトリや別リポジトリで管理している学習用スクリプトを利用する場合には便利そうです。