この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。
1 はじめに
CX事業本部の平内(SIN)です
Amazon SageMaker(以下、SageMaker)では、既存のモデルを元に学習を開始する増分学習がサポートされており、ここDevelopers.IOでも既に紹介されています。
上記は、コンソールから物体検出の増分学習の要領が、紹介されていますが、これを、単に、Jupyter Notebookでやってみた記録です。
Jupyter Notebookには、物体検出の増分学習のサンプルとして、Amazon SageMaker Object Detection Incremental Trainingがあり、データ形式がRecordIOとなっていますが、今回試したのは、JSON形式のデータセットです。
参考:Now easily perform incremental learning on Amazon SageMaker
2 Jupyter Notebook
Jupyter Notebookの内容は、以下の通りです。
(1) Setup
最初に、ロールの取得、データの入出力用S3バケット(プレフィックス)の定義、object-detectionのDockerイメージの取得を行います。これは、増分学習に限らず共通です。
%%time
import sagemaker
from sagemaker import get_execution_role
# Role
role = get_execution_role()
print(role)
sess = sagemaker.Session()
# S3
bucket = 'sagemaker-bucket'
prefix = 'my-sample'
# Iraning Image
from sagemaker.amazon.amazon_estimator import get_image_uri
training_image = get_image_uri(sess.boto_region_name, 'object-detection', repo_version="latest")
print (training_image)
(2) Data Preparation
続いて、データセットの準備です。 学習データと検証データのS3バケットを設定していますします。こちらも、通常の学習と同様です。
import os
import urllib.request
# DataSet
train_channel = prefix + '/train'
validation_channel = prefix + '/validation'
train_annotation_channel = prefix + '/train_annotation'
validation_annotation_channel = prefix + '/validation_annotation'
s3_train_data = 's3://{}/{}'.format(bucket, train_channel)
s3_validation_data = 's3://{}/{}'.format(bucket, validation_channel)
s3_train_annotation = 's3://{}/{}'.format(bucket, train_annotation_channel)
s3_validation_annotation = 's3://{}/{}'.format(bucket, validation_annotation_channel)
train_data = sagemaker.session.s3_input(s3_train_data, distribution='FullyReplicated',
content_type='image/jpeg', s3_data_type='S3Prefix')
validation_data = sagemaker.session.s3_input(s3_validation_data, distribution='FullyReplicated',
content_type='image/jpeg', s3_data_type='S3Prefix')
train_annotation = sagemaker.session.s3_input(s3_train_annotation, distribution='FullyReplicated',
content_type='image/jpeg', s3_data_type='S3Prefix')
validation_annotation = sagemaker.session.s3_input(s3_validation_annotation, distribution='FullyReplicated',
content_type='image/jpeg', s3_data_type='S3Prefix')
こちらは、継承する元となるモデルの指定です。S3上に配置された、tar.gz形式のファイルを指定します。
# Model
s3_model_data = "s3://sagemaker-bucket/my-sample/output/model.tar.gz"
model_data = sagemaker.session.s3_input(s3_model_data, distribution='FullyReplicated',
content_type='application/x-sagemaker-model', s3_data_type='S3Prefix')
fit()のパラメータとなる、data_channelsを定義します。学習データ、検証データに併せて、modelという名前で、元となるモデルが指定されます。
# データチャンネル
data_channels = {'train': train_data, 'validation': validation_data, 'train_annotation': train_annotation, 'validation_annotation':validation_annotation,'model': model_data}
# 出力先
s3_output_location = 's3://{}/{}/output'.format(bucket, prefix)
(3) Traning
増分学習では、num_layers、image_shape、num_classesなどのネットワークを定義するハイパーパラメーターは、既存のモデルのトレーニングに使用されたものと同じである必要があります。
od_model = sagemaker.estimator.Estimator(training_image,
role,
train_instance_count=1,
train_instance_type='ml.p3.2xlarge',
train_volume_size = 50,
train_max_run = 360000,
input_mode = 'File',
output_path=s3_output_location,
sagemaker_session=sess)
od_model.set_hyperparameters(base_network='resnet-50',
#use_pretrained_model=1,
num_classes=3, ### label count ###
mini_batch_size=16,
epochs=10, ### epoch count ###
learning_rate=0.001,
lr_scheduler_step='10',
lr_scheduler_factor=0.1,
optimizer='sgd',
momentum=0.9,
weight_decay=0.0005,
overlap_threshold=0.5,
nms_threshold=0.45,
image_shape=512,
label_width=600,
num_training_samples=1808) ### data count ###
学習を開始します。
od_model.fit(inputs=data_channels, logs=True)
ログを見ると、epoch=0で既にscore=0.987となっており、既存のモデルを利用して学習を開始していることが分かります。
#quality_metric: host=algo-1, epoch=0, batch=113 train cross_entropy =(0.19252898495207674)
#quality_metric: host=algo-1, epoch=0, batch=113 train smooth_l1 =(0.0477017551396642)
#quality_metric: host=algo-1, epoch=0, validation mAP =(0.9876995446829199)
3 最後に
今回は、既存のモデルを使用して学習を追加出来るように、Jupyter Notebookを準備してみました。 正直な所、Epochを何回に指定して学習すれば良いのか、良く分かってないので、無駄に回しすぎたりしないように、増分学習で少しすづ確認しながら進めています。
すべてのコードは、下記に起きました。