【小ネタ】[Amazon SageMaker] ハイパーパラメータ調整の結果を一覧してみました

2020.05.01

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

1 はじめに

CX事業本部の平内(SIN)です。

Amazon SageMaker(以下、SageMaker)では、ハイパーパラメータ調整(以下、HPO)を使用することで、指定範囲のパラメータで多数のトレーニングを実行し、最適なモデルを得ることができます。

HPOについては、下記に詳しく紹介されています。

本稿は、組み込みアルゴリズムの物体検出(object-detection)で、HPOを使用した結果を一覧してみた記録です。

2 コンソール

HPOの結果は、コンソールから確認可能です。

SageMakerHyperparameter tuning jobsメニューから、該当するトレーニングを選択すると、Traning jobsのタブで、実行された各ジョブで、どのような精度だったかを一覧できます。

しかし、それぞれのジョブで、どのようなパラメータが使用されたのかを確認するためには、各ジョブを開いて、Hyperparametersまでスクロールする必要があります。

最適なジョブだけに興味があるのであれば、特に問題ないのですが、どのようなパラメータでどんな結果になったのかを確認するのは、ちょっと大変です。

3 コード

ということで、SDKから、ハイパーパラメータ調整トレーニングにアクセスして、各ジョブで利用されたパラメータを一覧するプログラムを作成してみました。

from boto3.session import Session

session = Session(profile_name='developer')
client = session.client('sagemaker')

# Job名
hyperParameterTuningJobName = 'test-003'

# 単一のJOBを表現するクラス
class Job():
def __init__(self, summarie):
self.__name = summarie["TrainingJobName"]
self.__status = summarie["TrainingJobStatus"]
self.__mAP = summarie["FinalHyperParameterTuningJobObjectiveMetric"]["Value"]
self.__params = summarie["TunedHyperParameters"]

def header(self):
params = ''
for key in self.__params.keys():
params += "\t{:>15}".format(key)

return "{:>25}\t{:>8}\t{}{}".format("name", "mAP", "", params)

def string(self):
#for param in self.__params:
params = ''
for key in self.__params.keys():
val = self.__params[key]
try:
val = float(val)
num = round(val, 4)
params += "\t{:7.3f}".format(num)
except:
params += "\t{:>10}".format(val)
status = ''
if(self.__status != 'Completed'):
status = self.__status

return "{:>25}\t{:1.3f}\t{:>8}\t{}".format(self.__name, self.__mAP, status, params)

def main():

jobs = []

# 全てのJOBを列挙する
max = 100 # 最大100までしか指定できない
response = client.list_training_jobs_for_hyper_parameter_tuning_job(
HyperParameterTuningJobName = hyperParameterTuningJobName,
MaxResults = max,
)
while(True):
for summarie in response["TrainingJobSummaries"]:
jobs.append(Job(summarie))

if('NextToken' in response):
response = client.list_training_jobs_for_hyper_parameter_tuning_job(
HyperParameterTuningJobName = hyperParameterTuningJobName,
MaxResults = max,
NextToken = response["NextToken"]
)
else:
break

print(jobs[0].header())
for job in jobs:
print(job.string())

main()

実行結果は、以下のようになります。

name mAP learning_rate mini_batch_size optimizer
test-003-010-580805be 0.000 Stopped 0.086 13.000 adam
test-003-009-b9f74f88 0.686 0.000 9.000 adam
test-003-008-a4a20894 0.244 0.006 12.000 rmsprop
test-003-007-ef5dc64e 0.663 0.000 9.000 adam
test-003-006-95de60e6 0.000 0.004 21.000 sgd
test-003-005-4bde65d8 0.000 Stopped 0.255 21.000 sgd
test-003-004-41aba8a7 0.000 Stopped 0.001 20.000 adadelta
test-003-003-e682fb45 0.979 0.005 21.000 sgd
test-003-002-c76f6290 0.000 0.000 15.000 sgd
test-003-001-adf15b9d 0.000 0.000 15.000 sgd

4 最後に

今回は、組み込みアルゴリズムの物体検出(object-detection)で、HPOを使用した結果を一覧してみました。このツールを頼りに、HPOの動作を、色々確認してみたいと思います。