[Amazon SageMaker] オブジェクト検出のサンプルで使用されているデータセットを確認してみました。

2020.04.06

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

1 はじめに

CX事業本部の平内(SIN)です。

Amazon SageMaker(以下、SageMaker)では、簡単に動作確認ができるように、サンプルコードが豊富に提供されています。

組み込みアルゴリズムである「オブジェクト検出」についても、利用するデータの形式に応じて、以下の3種類がありました。

  • object_detection_image_json_format.ipynb
  • object_detection_incremental_training.ipynb
  • object_detection_recordio_format.ipynb

このうち、「イメージ形式を使用してトレーニングする」のサンプルをそのまま(バケット名のみ設定)実行すると、下記のような結果を得ることができました。

※表示内容(信頼度)は、学習の結果によって変わります。
参考:Amazon SageMaker Object Detection using the Image and JSON format

今回は、このサンプルが、どのようなデータセットを使用して、このモデルを生成しているのかを確認してみました。

2 データセットのダウンロードと変換

使用されているデータセットは、MS COCOによる、オブジェクト検出、セグメンテーション、キャプション用の大規模なデータセットです。

(1) ダウンロード

サンプルでは、2017検証データセットをダウンロードして使用しています。

download('http://images.cocodataset.org/zips/val2017.zip')
download('http://images.cocodataset.org/annotations/annotations_trainval2017.zip')

(2) 変換

ダウンロード・解凍後、instances_val2017.jsonを元に、イメージ形式で利用できる形のJSONを生成し、一旦、generatedフォルダに置かれます。 下記のログから、この時点でJSONの件数は、4952となっていることが分かります。

jsons = os.listdir('generated')
print ('There are {} images have annotation files'.format(len(jsons)))

There are 4952 images have annotation files

(3) 学習用と検証用に分割

generatedフォルダに置かれたJSONデータは、先頭から4452件までと、それ以降に分けられます。

train_jsons = jsons[:4452]
val_jsons = jsons[4452:]

また、それぞれ該当する画像と併せて、以下の4つのフォルダに移動されています。

  • train(画像[学習用])
  • train_annotation(ラベル[学習用])
  • validation(画像[検証用])
  • validation_annotation(ラベル[検証用])

3 データセットの内容

このデータセットには、多数のラベルで構成されていますが、内容を確認できるよう、画像と対応するラベルを表示する下記のような簡単なプログラグラムを作成してみました。

(1) すべてのデータ

import random
import json
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import glob
import os

# カテゴリー名
categories = ['person', 'bicycle', 'car',  'motorbike', 'aeroplane', 'bus', 'train', 'truck', 'boat', 
    'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
    'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
    'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
    'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    'hot dog', 'pizza', 'donut', 'cake', 'chair', 'sofa', 'pottedplant', 'bed', 'diningtable',
    'toilet', 'tvmonitor', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush']

# 検索対象ディレクトリ
dirs = ["train","validation"]

for dir in dirs:
    for file in glob.glob("./{}/*.jpg".format(dir)):
        ##########################################################
        # # 画像
        ##########################################################
        img_file = file #'./train/{}.jpg'.format(name)
        # 画像の描画
        img = mpimg.imread(img_file)
        plt.imshow(img)

        ##########################################################
        # アノテーション
        ##########################################################
        # ターゲット名
        basename = os.path.splitext(os.path.basename(file))[0]

        json_file = './{}_annotation/{}.json'.format(dir, basename)
        json_open = open(json_file, 'r')
        json_load = json.load(json_open)

        colors = dict()
        # 各アノテーション取得
        for annotation in json_load["annotations"]:
            cls_id = int(annotation["class_id"])
            y = annotation["top"]
            x = annotation["left"]
            w = annotation["width"]
            h = annotation["height"]
            # ラベルの表示色
            if cls_id not in colors:
                colors[cls_id] = (random.random(), random.random(), random.random())
            # ラベルの矩形描画
            rect = plt.Rectangle([x,y], w, h, fill=False, edgecolor=colors[cls_id], linewidth=3.5)
            plt.gca().add_patch(rect)
            # ラベル名の描画
            class_name = str(cls_id)
            if categories and len(categories) > cls_id:
                class_name = categories[cls_id]
                plt.gca().text(x, y - 2,
                    '{:s}'.format(class_name),
                    bbox=dict(facecolor=colors[cls_id], alpha=0.5),
                    fontsize=12, color='white')
        # 表示            
        plt.show()

その一部です。

(2) motorbikeデータの列挙

ちょっと、ラベルが多岐に渡っていて、イメージしにくいので、最初に検出されていたmotorbikeが、どのようなデータになっているのかを確認してみました。

import json
import glob
import os

# オートバイのクラス番号
index = categories.index("motorbike")
# 検索対象ディレクトリ
dirs = ["train_annotation","validation_annotation"]

list = []
for dir in dirs:
    img_count = 0 # 画像数のカウンタ
    label_count = 0 # ラベル数のカウンタ
    for file in glob.glob("./{}/*.json".format(dir)):
        # JSONファイルを開く        
        json_open = open(file, 'r')
        json_load = json.load(json_open)
        find = False
        # 各アノテーションの列挙
        for annotation in json_load["annotations"]:
            cls_id = int(annotation["class_id"])
            # クラス番号が一致している場合のみカウントする
            if(cls_id == index):
                label_count+=1
                find = True
        # 当該画像でクラス一致が有った場合にカウントする
        if(find == True):
            list.append(os.path.splitext(os.path.basename(file))[0])
            img_count+=1
    # 結果表示
    print("[{}] image:{} label:{}".format(dir, img_count, label_count))

print(list)

結果、motorbikeが含まれているのは、以下の通り、学習用として、画像が147枚、ラベルが348件、検証用として、画像が12枚、ラベルが23件でした。

[train_annotation] image:147 label:348
[validation_annotation] image:12 label:23

motorbikeに絞って、データセットを列挙してみました。

list[]は、上のプログラムで得た出力であり、motorbikeがラベルされている画像の名称になります。

import random
import json
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

list = ['000000411938', '000000408830', '000000491090', '000000007386', '000000491213', '000000284698', '000000119516', '000000396338', '000000011511', '000000579902', '000000426372', '000000454978', '000000356387', '000000230008', '000000224724', '000000141671', '000000102356', '000000149406', '000000037751', '000000321557', '000000534827', '000000048924', '000000214192', '000000578871', '000000513580', '000000553776', '000000054593', '000000524108', '000000481567', '000000338219', '000000345027', '000000480985', '000000291634', '000000430073', '000000194875', '000000270066', '000000177934', '000000499622', '000000069213', '000000427338', '000000140420', '000000376625', '000000312421', '000000527220', '000000206487', '000000433204', '000000490936', '000000488385', '000000399205', '000000459634', '000000147740', '000000165518', '000000007816', '000000144300', '000000158548', '000000350023', '000000192699', '000000226154', '000000299887', '000000102331', '000000178982', '000000380706', '000000308631', '000000017207', '000000394206', '000000139099', '000000467511', '000000226417', '000000018737', '000000153217', '000000551794', '000000293245', '000000446651', '000000031817', '000000085376', '000000181542', '000000246963', '000000013177', '000000142324', '000000457884', '000000472375', '000000226662', '000000213605', '000000102411', '000000068093', '000000423617', '000000480021', '000000011149', '000000441586', '000000572956', '000000266400', '000000145620', '000000227399', '000000210394', '000000146667', '000000276284', '000000410878', '000000081394', '000000135410', '000000256192', '000000078748', '000000226802', '000000457848', '000000254814', '000000279887', '000000578792', '000000019109', '000000114770', '000000230450', '000000571264', '000000568290', '000000203389', '000000460147', '000000338905', '000000143556', '000000402433', '000000463730', '000000190756', '000000273551', '000000292456', '000000394199', '000000538364', '000000357737', '000000534605', '000000336232', '000000461751', '000000314251', '000000142585', '000000343934', '000000136715', '000000259830', '000000306893', '000000507667', '000000179765', '000000152120', '000000463849', '000000291791', '000000297147', '000000165681', '000000456394', '000000122166', '000000044590', '000000033854', '000000008211', '000000245448', '000000125572', '000000244833', '000000574702', '000000363875', '000000070774', '000000204186', '000000255536', '000000455716', '000000546976', '000000296649', '000000038829', '000000455624', '000000498709', '000000462756']

# 表示対象のクラス
class_name = "motorbike"
index = categories.index(class_name)
color = (1, 0, 1)

for name in list:

    ##########################################################
    # # 画像
    ##########################################################
    img_file = './train/{}.jpg'.format(name)
    # 画像の描画
    img = mpimg.imread(img_file)
    plt.imshow(img)

    ##########################################################
    # アノテーション
    ##########################################################
    json_file = './train_annotation/{}.json'.format(name)
    json_open = open(json_file, 'r')
    json_load = json.load(json_open)

    # 各アノテーション取得
    for annotation in json_load["annotations"]:
        cls_id = int(annotation["class_id"])
        if(cls_id != index):
            continue
        y = annotation["top"]
        x = annotation["left"]
        w = annotation["width"]
        h = annotation["height"]
        # ラベルの矩形描画
        rect = plt.Rectangle([x,y], w, h, fill=False, edgecolor=color, linewidth=3.5)
        plt.gca().add_patch(rect)
        # ラベル名の描画
        plt.gca().text(x, y - 2,
            '{:s}'.format(class_name),
            bbox=dict(facecolor=color, alpha=0.5),
            fontsize=12, color='white')
    # 表示            
    plt.show()

結果の一部です。

4 ローカルでの検証

データセットに直接関係ないのですが、今回、モデルの作成までをJupyter Notebookで行い、推論は、ローカルから行いましたので、そこで使用したシェルです。

(1) エンドポイントの生成

model名を指定して実行することで、Configを作成し、エンドポイントを作成します。 configName及び、endPointNameは、事後のシェルでそのまま使えるように固定名称にしています。

from boto3.session import Session
import boto3

session = Session(profile_name='developer', region_name='ap-northeast-1')
client = session.client('sagemaker')

# model名
modelName = 'object-detection-2020-04-06-00-36-52-751'
# Endpoint configurations 
configName = 'sampleConfig'
# Endpoint
endPointName = 'sampleEndPoint'

response = client.create_endpoint_config(
    EndpointConfigName = configName,
    ProductionVariants=[
        {
            'VariantName': 'VariantName',
            'ModelName': modelName,
            'InitialInstanceCount': 1,
            'InstanceType': 'ml.m4.4xlarge'
        },
    ]
)

print(response)

response = client.create_endpoint(
    EndpointName=endPointName,
    EndpointConfigName=configName
)

print(response)

(2) 推論

サンプルで実行されている推論と同じ処理内容です。ローカルに有るtest.jpgを推論します。

import boto3
import json
from boto3.session import Session

session = Session(profile_name='developer')
client = session.client('sagemaker-runtime')

file_name = 'test.jpg'
endPoint = 'sampleEndPoint';

categories = ['person', 'bicycle', 'car',  'motorbike', 'aeroplane', 'bus', 'train', 'truck', 'boat', 
    'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
    'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
    'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
    'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    'hot dog', 'pizza', 'donut', 'cake', 'chair', 'sofa', 'pottedplant', 'bed', 'diningtable',
    'toilet', 'tvmonitor', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush']


def visualize_detection(img_file, dets, classes=[], thresh=0.6):
    import random
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg

    img=mpimg.imread(img_file)
    plt.imshow(img)
    height = img.shape[0]
    width = img.shape[1]
    colors = dict()
    for det in dets:
        (klass, score, x0, y0, x1, y1) = det
        print("{} {}".format(klass,score))
        if score < thresh:
            continue
        cls_id = int(klass)
        if cls_id not in colors:
            colors[cls_id] = (random.random(), random.random(), random.random())
        xmin = int(x0 * width)
        ymin = int(y0 * height)
        xmax = int(x1 * width)
        ymax = int(y1 * height)
        rect = plt.Rectangle((xmin, ymin), xmax - xmin,
                                ymax - ymin, fill=False,
                                edgecolor=colors[cls_id],
                                linewidth=3.5)
        plt.gca().add_patch(rect)
        class_name = str(cls_id)
        if classes and len(classes) > cls_id:
            class_name = classes[cls_id]
        plt.gca().text(xmin, ymin - 2,
                        '{:s} {:.3f}'.format(class_name, score),
                        bbox=dict(facecolor=colors[cls_id], alpha=0.5),
                                fontsize=12, color='white')
        #plt.rcParams['figure.figsize'] = (50 ,50)
    plt.show()


with open(file_name, 'rb') as image:
    f = image.read()
    b = bytearray(f)

endpoint_response = client.invoke_endpoint(
    EndpointName=endPoint,
    Body=b,
    ContentType='image/jpeg'
)
results = endpoint_response['Body'].read()
detections = json.loads(results)

thresh = 0.2
visualize_detection(file_name, detections["prediction"], categories, thresh)

(3) エンドポイントの削除

固定名称とした、configNameendPointNameが、変更されていなければ、これを実行することでエンドポイントは削除されます。

from boto3.session import Session
import boto3

session = Session(profile_name='developer', region_name='ap-northeast-1')
client = session.client('sagemaker')

configName = 'sampleConfig'
endPointName = 'sampleEndPoint'


response = client.delete_endpoint(
    EndpointName=endPointName
)
print(response)

response = client.delete_endpoint_config(
    EndpointConfigName=configName
)
print(response)

5 最後に

何も気にしないで、サンプルを実行したときに、結構高い精度でオートバイが検出されているように感じたのですが、思ったより、オートバイのデータは、少ない印象でした。 単なるサンプルの実行ですが、どのようなデータセットで構築されているのかというイメージが少し湧いた気がしました。