[Amazon SageMaker] Amazon SageMaker Ground Truth で作成したデータを使用してオブジェクト検出でモデルを作成してみました

2020.04.13

1 はじめに

CX事業本部の平内(SIN)です。

Amazon SageMaker の組み込みの物体検出アルゴリズム(object-detection)では、Amazon SageMaker Ground Truth(以下、Ground Truth)で作成したデータを使用することができます。

作業としては、Ground Truthで生成されたoutput.manifestを、学習用(train)と検証用(validation)に分割するだけです。

作成されたデータ内のラベルが、ある程度均一に、そして大量にある場合は、気にすることでは無いかも知れないでしょうが、今回は、ラベル数が少数で、単純に分割できない場合も考慮して、分割してみました。

2 Ground Truth

Ground Truthのデータは、下記で使用したものです。

分割の結果を確認しやすいように、あえて非常に少数のデータとなっています。

  • 画像数 50枚
  • ラベル(AHIRU×50 (DOG)×10

アウトプットは、s3://sagemaker-working-bucket-001/GroundTruth-output/AHIRU-Project/manifests/output/output.manifestになっています。

3 学習及び、検証データ

学習用及び、検証用に分割しているのは、以下のコード(Jupyter Notebook)です。

対象となる全てのデータから、含まれるラベルの数をカウントし、サンプル数の少ないラベルから順に、分割しています。

bucket_name = "sagemaker-working-bucket-001" 
prefix = 'object-detection-with-ground-truth'
#  Ground Truthの出力(output.manifest)をダウンロードする

inputManifestPath = 's3://sagemaker-working-bucket-001/GroundTruth-output/AHIRU-Project/manifests/output/output.manifest'
!aws s3 cp $inputManifestPath "./output.manifest"
# output.manifestをtrain及びvalidationに分割する

import json 

# 1件のデータの表現するクラス(定義されているラベルを把握するために使用する)
class Data():
    def __init__(self, src):
        self.src = src
        # プロジェクト名の取得
        for key in src.keys():
            index = key.rfind("-metadata")
            if(index!=-1):
                self.projectName = key[0:index]

        cls_map = src[self.projectName + "-metadata"]["class-map"]

        # アノテーション一覧からクラスIDの指定を取得する
        self.annotations = []
        for annotation in src[self.projectName]["annotations"]:
            id = annotation['class_id']
            self.annotations.append({
                "label":cls_map[str(id)]
            })

    # 指定されたラベルを含むかどうか
    def exsists(self, label):
        for annotation in self.annotations:
            if(annotation["label"] == label):
                return True
        return False

# 全てのJSONデータを読み込む
def getDataList(inputPath):
    dataList = []
    with open(inputPath, 'r') as f:
        srcList = f.read().split('\n')
        for src in srcList:
            if(src != ''):
                dataList.append(Data(json.loads(src)))
    return dataList

# ラベルの件数の少ない順に並べ替える(配列のインデックスが、クラスIDとなる)
def getLabel(dataList):
    labels = {}
    for data in dataList:
        for annotation in data.annotations:
            label = annotation["label"]
            if(label in labels):
                labels[label] += 1
            else:
                labels[label] = 1
    # ラベルの件数の少ない順に並べ替える(配列のインデックスが、クラスIDとなる)
    labels = sorted(labels.items(), key=lambda x:x[1])
    return labels

# dataListをラベルを含むものと、含まないものに分割する
def deviedDataList(dataList, label):
    targetList = []
    unTargetList = []
    for data in dataList:
        if(data.exsists(label)):
            targetList.append(data)
        else:
            unTargetList.append(data)
    return (targetList, unTargetList)


# 学習用と検証用の分割比率
ratio = 0.8  # 80%対、20%に分割する
# GroundTruthの出力
inputManifestFile = './output.manifest'
# SageMaker用の出力
outputTrainFile = './train'
outputValidationFile = './validation'

dataList = getDataList(inputManifestFile)
projectName = dataList[0].projectName
print("全データ: {}件 ".format(len(dataList)))

# ラベルの件数の少ない順に並べ替える(配列のインデックスが、クラスIDとなる)
labels = getLabel(dataList)
for i,label in enumerate(labels):
    print("[{}]{}: {}件 ".format(i, label[0], label[1]))

# 保存済みリスト
storedList = [] 

# 学習及び検証用の出力
train  = ''
validation = ''

# ラベルの数の少ないものから優先して分割する
for i,label in enumerate(labels):
    print("{} => ".format(label[0]))
    # dataListをラベルが含まれるものと、含まないものに分割する
    (targetList, unTargetList) = deviedDataList(dataList, label[0])
    # 保存済みリストから、当該ラベルで既に保存済の件数をカウントする
    (include, notInclude) = deviedDataList(storedList, label[0])
    storedCounst = len(include)
    # train用に必要な件数
    count = int(label[1] * ratio) - storedCounst
    print("train :{}".format(count))
    # train側への保存
    for i in range(count):
        data = targetList.pop()
        train += json.dumps(data.src) + '\n'
        storedList.append(data)
    # validation側への保存
    print("validation :{} ".format(len(targetList)))
    for data in targetList:
        validation += json.dumps(data.src) + '\n'
        storedList.append(data)

    dataList = unTargetList
    print("残り:{}件".format(len(dataList)))

with open(outputTrainFile, mode='w') as f:
    f.write(train)

with open(outputValidationFile, mode='w') as f:
    f.write(validation)

num_training_samples = len(train.split('\n'))
print("projectName:{} num_training_samples:{}".format(projectName, num_training_samples))    

コードを実行すると、下記の出力が得られます。

ラベルAHIRUとDOGは、それぞれ、30:10 8:2件に分けられています。 また、TraningJob作成時にパラメータとして使用される、プロジェクト名(AHIRU_Project)及び、学習データの件数(num_traning_samples:39)も、output.manifestから読み取られています。

4 S3へのアップロード

分割して作成された、2つのファイル(trainとvalidation)は、SageMakerから利用可能なように、S3にアップロードします。

# 分割した train及び、validationをS3にアップロードする
s3_train = "s3://{}/{}/train".format(bucket_name, prefix)
!aws s3 cp  $outputTrainFile $s3_train
s3_validation = "s3://{}/{}/validation".format(bucket_name, prefix)
!aws s3 cp $outputValidationFile $s3_validation

5 Traning

最初に、学習用のコンテナを準備します。

import boto3
import sagemaker
from sagemaker import get_execution_role

role = get_execution_role()
sess = sagemaker.Session()

# 学習用のコンテナ取得
training_image = sagemaker.amazon.amazon_estimator.get_image_uri(boto3.Session().region_name, 'object-detection', repo_version='latest')

# モデルを出力するバケットの指定
s3_output_path = "s3://{}/{}/output".format(bucket_name, prefix)

training_jobに設定するパラメータは、以下のとおりです。

作成した、train及び、validationは、InputDataConfig3DataType:AugmentedManifestFileとして渡されています。

epochsが、200と多いですが、データがごく少数なので、学習は、数分で完了します。

# パラメータの作成
import time
from time import gmtime, strftime

job_name_prefix = 'groundtruth-to-sagemaker'
timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())
job_name = job_name_prefix + timestamp

training_params = {
    "AlgorithmSpecification": {
        "TrainingImage": training_image, 
        "TrainingInputMode": "Pipe"
    },
    "RoleArn": role,
    "OutputDataConfig": {
        "S3OutputPath": s3_output_path
    },
    "ResourceConfig": {
        "InstanceCount": 1,   
        "InstanceType": "ml.p3.2xlarge",
        "VolumeSizeInGB": 50
    },
    "TrainingJobName": job_name,
    "HyperParameters": {
         "base_network": "resnet-50",
         "use_pretrained_model": "1",
         "num_classes": "2",
         "mini_batch_size": "10",
         "epochs": "200",
         "learning_rate": "0.001",
         "lr_scheduler_step": "3,6",
         "lr_scheduler_factor": "0.1",
         "optimizer": "rmsprop",
         "momentum": "0.9",
         "weight_decay": "0.0005",
         "overlap_threshold": "0.5",
         "nms_threshold": "0.45",
         "image_shape": "300",
         "label_width": "350",
         "num_training_samples": str(num_training_samples)
    },
    "StoppingCondition": {
        "MaxRuntimeInSeconds": 86400
    },
    "InputDataConfig": [
        {
            "ChannelName": "train",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "AugmentedManifestFile", 
                    "S3Uri": s3_train,
                    "S3DataDistributionType": "FullyReplicated",
                    "AttributeNames": ["source-ref",projectName]
                }
            },
            "ContentType": "application/x-recordio",
            "RecordWrapperType": "RecordIO",
            "CompressionType": "None"
        },
        {
            "ChannelName": "validation",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "AugmentedManifestFile", 
                    "S3Uri": s3_validation,
                    "S3DataDistributionType": "FullyReplicated",
                    "AttributeNames": ["source-ref",projectName]
                }
            },
            "ContentType": "application/x-recordio",
            "RecordWrapperType": "RecordIO",
            "CompressionType": "None"
        }
    ]
}

print('Training job name: {}'.format(job_name))
print('\nInput Data Location: {}'.format(training_params['InputDataConfig'][0]['DataSource']['S3DataSource']))

学習の開始は、以下です。

# Traning
client = boto3.client(service_name='sagemaker')
client.create_training_job(**training_params)

ジョブの状態が、Training job current status: Completedとなれは、学習は完了です。

# 学習ジョブの状態取得
status = client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']
print('Training job current status: {}'.format(status))

トレーニングの状態は、AWSコンソールからも確認可能です。

ログは、CloudWatch Logsで確認でき、最終的に、mAP score は、0.51となていました。

quality_metric: host=algo-1, epoch=199, validation mAP <score>=(0.5106382978723404)
Updating the best model with validation-mAP=0.5106382978723404

6 最後に

今回は、Ground Truthで作成したデータを使用して、SageMakerのビルトイン(オブジェクト検出)を使用してみました。特に変換などは必要無いので、サンプル数が潤沢であれば、下記のように単純に分割するだけでも大丈夫でしょう。

with jsonlines.open('output.manifest','r') as reader:
    lines = list(reader)
    np.random.shuffle(lines) # ランダム化
# 全データ件数
dataset_size = len(lines)
# 学習用の件数
num_traning_samples = round(dataset_size*0.8)
# 分割
train_data = lines[:num_traning_sample]
validation_data = [lines[num_traning_samples:]

ごく少数のサンプルで不十分ですが、一応、動作していることを確認できました。

使用したNotebookは、下記に置きました。
GroundTruth_To_SageMaker.ipynb