Amazon SageMaker RLでCartPole(倒立振子)を強化学習してみる – Amazon SageMaker Advent Calendar 2018

こんにちは、大阪DI部の大澤です。 この記事は「クラスメソッド Amazon SageMaker Advent Calendar」の12日目の記事となります。

今回はAmazon SageMaker RLを使ってCartPole(倒立振子)で強化学習を試してみようと思います。

CartPole(倒立振子)

まずは今回扱う、CartPoleって何なのかということからです。CartPoleは強化学習の入門として一般的な問題です。

棒がカートの上に載っていて、棒が倒れないように、カートを左右に動かします。カートをどのように動かすのが良いのかを学習させるというものです。

強化学習

機械学習には教師あり学習と教師なし学習と強化学習という、大きく3つのカテゴリがあります。強化学習は教師(答え)を与えるわけでもなく、与えないという訳でもないと... じゃあ、どうやって学習するのかというと行動等に応じて報酬を与えます。区切りごとに与えられていく報酬の累積が最大になるように学習を進めていくものを強化学習といいます。
例えば、CartPoleの場合であれば0.5秒ごとに上を向いていれば1という報酬がもらえて、CartPoleが倒れるまでの累積報酬が最大になるように学習を進めるといった感じです。

強化学習について理解するのに良さそうな資料を参考資料で紹介しています。よければご参照ください。

やってみる

理解を進めるために実際にやってみます!

セットアップ

パッケージ/モジュールを読み込みます。

import sagemaker
import boto3
import sys
import os
import glob
import re
import numpy as np
import subprocess
from IPython.display import HTML
import time
from time import gmtime, strftime
sys.path.append("common")
from misc import get_execution_role, wait_for_s3_object
from sagemaker.rl import RLEstimator, RLToolkit, RLFramework

データを保存するS3の場所を設定しておきます。

sage_session = sagemaker.session.Session()

# 必要に応じて書き換える
s3_bucket = sage_session.default_bucket()
s3_output_path = 's3://{}/'.format(s3_bucket)
print("S3 bucket path: {}".format(s3_output_path))

学習ジョブの名前に付ける接頭辞を決めます。

# create unique job name
job_name_prefix = 'rl-cart-pole'

ローカルで学習を行うことができます。その場合はローカルにDockerコンテナを作成して、そこで学習を行うことになります。 今回はローカルではなく、SageMaker上で行いたいと思います。

# run in local mode?
local_mode = False

if local_mode:
    instance_type = 'local'
else:
    instance_type = "ml.m4.4xlarge"

モデルの学習やエンドポイントを作成する際に利用するIAMロールを取得します。

try:
    role = sagemaker.get_execution_role()
except:
    role = get_execution_role()

print("Using IAM role arn: {}".format(role))

ローカルモードであればsetup.shを実行してDocker等の設定を行います。このsetup.shはSageMakerのノートブックインスタンスにしか対応していないようなので注意が必要です。

# only run from SageMaker notebook instance
if local_mode:
    !/bin/bash ./common/setup.sh

学習用コードの確認

強化学習ではエージェントと環境という大きな2つの要素があります。エージェントは行動を選択し、学習する頭脳的なものです。環境はどういったものがあり、どういうことをするとどうなるみたいなことを定義したものです。

今回の環境には強化学習ツールキットであるGymCartPole-v0を使います。

  • カートは摩擦のない二次元の線の上にあり、棒がその上に載っています。
  • ステップごとに+1か-1の力をカートに与えることができます。
  • 各ステップで棒が上を向いていれば+1の報酬を得ます
  • 棒が垂直線から15度以上になるか、カートが中心から2.4個分離れると終了になります。

今回は強化学習のフレームワークであるCoachを使って強化学習を行います。エージェントや環境、可視化に関するパラメータに関して事前に定義されているプリセットpreset-cartpole-clippedppo.pyを使います。このファイルはハイパーパラメータのRLCOACH_PRESETに指定することで、学習時に読み込まれます。

以下のコードを実行することでプリセットファイルをシンタックスハイライト付きで表示します。

!pygmentize src/preset-cartpole-clippedppo.py

以下の画像ではグラフスケジュール(評価間隔などの学習全体の設定)の部分のみ紹介してますが、他にもエージェントや環境、可視化などの設定があります。

次は学習の最初に読み込まれるスクリプトファイル、train-coach.pyです。学習時にentry_pointとして指定します。

!pygmentize src/train-coach.py

学習

CartPoleの強化学習ジョブを実行します。 強化学習をハンドルするRLEstimatorに先ほど紹介したスクリプトファイルやインスタンスデータ、学習結果の出力先、ハイパーパラメータなどを指定します。

estimator = RLEstimator(entry_point="train-coach.py",
                        source_dir='src',
                        dependencies=["common/sagemaker_rl"],
                        toolkit=RLToolkit.COACH,
                        toolkit_version='0.11.0',
                        framework=RLFramework.MXNET,
                        role=role,
                        train_instance_type=instance_type,
                        train_instance_count=1,
                        output_path=s3_output_path,
                        base_job_name=job_name_prefix,
                        hyperparameters = {
                          "RLCOACH_PRESET": "preset-cartpole-clippedppo",
                          "rl.agent_params.algorithm.discount": 0.9,
                          "rl.evaluation_steps:EnvironmentEpisodes": 8,
                          "improve_steps": 10000,
                          "save_model": 1
                        }
                    )
# 学習を開始
estimator.fit(wait=local_mode)

学習は10分弱くらいで終了します。 マネジメントコンソールの学習ジョブの詳細ページからアルゴリズムメトリクス(学習過程における各メトリクスの変化)を表示できます。 エピソードごとの累積報酬量を1分ごとに平均を取ったものです。

学習結果の出力先の確認

以下のスクリプトで学習によって生成された中間(Intermediate)フォルダの場所を確認できます。中間フォルダには学習過程のgifや学習のメタデータが入っています。

job_name=estimator._current_job_name
print("Job name: {}".format(job_name))

s3_url = "s3://{}/{}".format(s3_bucket,job_name)

if local_mode:
    output_tar_key = "{}/output.tar.gz".format(job_name)
else:
    output_tar_key = "{}/output/output.tar.gz".format(job_name)

intermediate_folder_key = "{}/output/intermediate/".format(job_name)
output_url = "s3://{}/{}".format(s3_bucket, output_tar_key)
intermediate_url = "s3://{}/{}".format(s3_bucket, intermediate_folder_key)

print("S3 job path: {}".format(s3_url))
print("Output.tar.gz location: {}".format(output_url))
print("Intermediate folder path: {}".format(intermediate_url))

tmp_dir = "/tmp/{}".format(job_name)
os.system("mkdir {}".format(tmp_dir))
print("Create local folder {}".format(tmp_dir))

可視化

学習結果をダウンロードしてきて、各エピソードごとの累積報酬量の変化を見てみます。

%matplotlib inline
import pandas as pd

csv_file_name = "worker_0.simple_rl_graph.main_level.main_level.agent_0.csv"
key = os.path.join(intermediate_folder_key, csv_file_name)
wait_for_s3_object(s3_bucket, key, tmp_dir, training_job_name=job_name)

csv_file = "{}/{}".format(tmp_dir, csv_file_name)
df = pd.read_csv(csv_file)
df = df.dropna(subset=['Training Reward'])
x_axis = 'Episode #'
y_axis = 'Training Reward'

plt = df.plot(x=x_axis,y=y_axis, figsize=(12,5), legend=True, style='b-')
plt.set_ylabel(y_axis);
plt.set_xlabel(x_axis);

次に最後のエピソード(試行)の様子をgifで確認します。

key = os.path.join(intermediate_folder_key, 'gifs')
wait_for_s3_object(s3_bucket, key, tmp_dir, training_job_name=job_name)
print("Copied gifs files to {}".format(tmp_dir))

glob_pattern = os.path.join("{}/*.gif".format(tmp_dir))
gifs = [file for file in glob.iglob(glob_pattern, recursive=True)]
extract_episode = lambda string: int(re.search('.*episode-(\d*)_.*', string, re.IGNORECASE).group(1))
gifs.sort(key=extract_episode)
print("GIFs found:\n{}".format("\n".join([os.path.basename(gif) for gif in gifs])))

# visualize a specific episode
gif_index = -1 # 最後のgifをみるために-1にする
gif_filepath = gifs[gif_index]
gif_filename = os.path.basename(gif_filepath)
print("Selected GIF: {}".format(gif_filename))
os.system("mkdir -p ./src/tmp/ && cp {} ./src/tmp/{}.gif".format(gif_filepath, gif_filename))
HTML('<img src="./src/tmp/{}.gif">'.format(gif_filename))

gifをみることでどんな動きをしてるのかが分かりますし、いいですね。便利!

モデルの評価

先ほど学習して作成したモデルの評価を行います。評価とはいっても先ほど作成したモデルを読み込んで、続きから動かすというだけです。

チェックポイントファイルの読み込み

先ほどの学習のチェックポイントファイルをダウンロードします。


wait_for_s3_object(s3_bucket, output_tar_key, tmp_dir, training_job_name=job_name)

if not os.path.isfile("{}/output.tar.gz".format(tmp_dir)):
    raise FileNotFoundError("File output.tar.gz not found")
os.system("tar -xvzf {}/output.tar.gz -C {}".format(tmp_dir, tmp_dir))

if local_mode:
    checkpoint_dir = "{}/data/checkpoint".format(tmp_dir)
else:
    checkpoint_dir = "{}/checkpoint".format(tmp_dir)

print("Checkpoint directory {}".format(checkpoint_dir))

この後に行う評価とモデルのホスティングに向けてアップロードします。

if local_mode:
    checkpoint_path = 'file://{}'.format(checkpoint_dir)
    print("Local checkpoint file path: {}".format(checkpoint_path))
else:
    checkpoint_path = "s3://{}/{}/checkpoint/".format(s3_bucket, job_name)
    if not os.listdir(checkpoint_dir):
        raise FileNotFoundError("Checkpoint files not found under the path")
    os.system("aws s3 cp --recursive {} {}".format(checkpoint_dir, checkpoint_path))
    print("S3 checkpoint file path: {}".format(checkpoint_path))

estimator_eval = RLEstimator(role=role,
                             source_dir='src/',
                             dependencies=["common/sagemaker_rl"],
                             toolkit=RLToolkit.COACH,
                             toolkit_version='0.11.0',
                             framework=RLFramework.MXNET,
                             entry_point="evaluate-coach.py",
                             train_instance_count=1,
                             train_instance_type=instance_type,
                             hyperparameters = {
                                 "RLCOACH_PRESET": "preset-cartpole-clippedppo",
                                 "evaluate_steps": 2000
                             }
                            )

estimator_eval.fit({'checkpoint': checkpoint_path})

数分で処理は終わります。累積報酬は168程度あったようです。

モデルのデプロイ

エンドポイントの作成

エンドポイントを作成し、学習させたモデルをデプロイします。 entry_pointで推論時にモデルを読み込む処理等を書いたdeploy-mxnet-coach.pyを指定します。

predictor = estimator.deploy(initial_instance_count=1,
                             instance_type=instance_type,
                             entry_point='deploy-mxnet-coach.py')

推論

デプロイが完了したら、推論してみます。 渡すデータは[カートの位置, カートの速度, 棒の角度, 棒の速度]という形式です。

value, action = predictor.predict(np.array([0., 0., 2., 2.]))
action

先ほどと逆方向に向いている値を渡してみます。

value, action = predictor.predict(np.array([0., 0., -2., -2.]))
action

エンドポイントの削除

最後に余分な課金が発生しないようにエンドポイントを削除しましょう。

predictor.delete_endpoint()

さいごに

今回は 「クラスメソッド Amazon SageMaker Advent Calendar」 の12日目として、Amazon SageMaker RLを使ってCartPole問題に強化学習を試してみました。SageMaker Python SDKを使うことで、SageMakerでの教師あり学習や教師なし学習を行う際と似た感じで、強化学習を行うことができました。AWS RoboMakerAWS DeepRacerといったものも新しく出ましたし、これを機に強化学習を試してみてはいかがでしょうか。

お読みくださりありがとうございました〜!明日もお楽しみに〜!

参考資料