Amazon SageMakerで人の検出を試してみた

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

こんにちは、大澤です。

当エントリではAmazon SageMakerの組み込みアルゴリズムの1つ、「物体検出」についてご紹介していきたいと思います。 別のエントリと内容が被っていますが、当エントリではコードを追いながら紹介していきます。

「組み込みアルゴリズム」の解説については下記エントリをご参照ください。

目次

概要説明:物体検出とは

画像の中に含まれている「物体」を検出するためのアルゴリズムです。 似たアルゴリズムとして画像分類もありますが、以下のような違いがあります。

画像分類
画像全体に対して、学習した人や動物などを識別し、ラベル(分類名的なもの)を出力します。
物体検出
画像の中から人や動物などを複数検出することができ、検出された「物体」それぞれのラベルと位置を出力します。

組み込みアルゴリズム:物体検出の実践

Amazon SageMaker Examplesの物体検出の例に沿って進めていきます。

Pascal VOCのデータセットを利用して、いくつかの「物体」を学習させて、学習したモデルで画像から人とバイクを検出すると言う内容です。 データ量が多いので、全体的に時間がかかるので注意してください。 (ダウンロードするデータセットの容量が全部で約2.7GB位あります。)

ノートブックの作成

SageMakerのノートブックインスタンスを立ち上げて、 SageMaker Examples ↓ Introduction to Amazon algorithms ↓ object_detection_recordio_format.ipynb ↓ use でサンプルからノートブックをコピーして、開きます。 ノートブックインスタンスの作成についてはこちらをご参照ください。

環境変数とロールの確認

ロールを取得しておき、学習データ等を保存するS3のバケット名と保存オブジェクト名の接頭辞を決めます。

%%time
import sagemaker
from sagemaker import get_execution_role

role = get_execution_role()
print(role)
sess = sagemaker.Session()
bucket = '<your_s3_bucket_name_here>' # custom bucket name.
# bucket = sess.default_bucket() 
prefix = 'DEMO-ObjectDetection'

学習用のコンテナイメージを取得します。

from sagemaker.amazon.amazon_estimator import get_image_uri

training_image = get_image_uri(sess.boto_region_name, 'object-detection', repo_version="latest")
print (training_image)

データ取得

Pascal VOCのデータセットをダウンロードします。

%%time

# Download the dataset
!wget -P /tmp http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
!wget -P /tmp http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
!wget -P /tmp http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
# # Extract the data.
!tar -xf /tmp/VOCtrainval_11-May-2012.tar && rm /tmp/VOCtrainval_11-May-2012.tar
!tar -xf /tmp/VOCtrainval_06-Nov-2007.tar && rm /tmp/VOCtrainval_06-Nov-2007.tar
!tar -xf /tmp/VOCtest_06-Nov-2007.tar && rm /tmp/VOCtest_06-Nov-2007.tar

データをRecordIO形式に変換

MXNetが提供しているツールを使って、lst形式のデータを作成し、そこからim2recと言うスクリプトでRecordIO形式に変換します。 詳細はMXNetの例をご覧ください。

!python tools/prepare_dataset.py --dataset pascal --year 2007,2012 --set trainval --target VOCdevkit/train.lst
!rm -rf VOCdevkit/VOC2012
!python tools/prepare_dataset.py --dataset pascal --year 2007 --set test --target VOCdevkit/val.lst --no-shuffle
!rm -rf VOCdevkit/VOC2007

確認のために、lst形式のデータを3行だけ見てみます。

!head -n 3 VOCdevkit/train.lst > example.lst
f = open('example.lst','r')
lst_content = f.read()
print(lst_content)

ちゃんと表示されました。 今回使用するlst形式の説明を補足しておくと、行ごとに一つの画像に対するアノテーションを表しており、一つの行内にタブ区切りで値を複数持っています。行内の各値は次のような意味を持ちます。

画像インデックス    ヘッダサイズ    一つのオブジェクトのラベルデータのデータ数  ラベルデータ * 画像内の学習させたい物体の数 画像ファイル名
画像インデックス
各画像ごとに一意に定められたID
ヘッダサイズ
ヘッダのデータ数。今回の場合だと2で`ヘッダサイズ`と`一つのオブジェクトのラベルデータのデータ数`のことを指しています。
一つのオブジェクトのラベルデータのデータ数
画像内のラベルを表現するためのデータの数です。今回は`[class_index, xmin, ymin, xmax, ymax]`と言うラベルデータになるため、5になります。
ラベルデータ
画像内の学習させたい物体の場所を示したものです。物体毎に何というラベル(分類,class_index)で、どこのどういった形の矩形(xmin, ymin, xmax, ymax)の中に物体はあるのかを示したものです。画像内の学習させたい物体の数だけ定義が必要です。
画像ファイル名
対象の画像は何というファイル名なのかを示したものです。

S3へデータをアップロード

trainチャネルとvalidationチャネルのデータをS3へアップロードします。

%%time

# Upload the RecordIO files to train and validation channels
train_channel = prefix + '/train'
validation_channel = prefix + '/validation'

sess.upload_data(path='VOCdevkit/train.rec', bucket=bucket, key_prefix=train_channel)
sess.upload_data(path='VOCdevkit/val.rec', bucket=bucket, key_prefix=validation_channel)

s3_train_data = 's3://{}/{}'.format(bucket, train_channel)
s3_validation_data = 's3://{}/{}'.format(bucket, validation_channel)

学習したモデルのアーティファクトの出力先を指定します。

s3_output_location = 's3://{}/{}/output'.format(bucket, prefix)

学習

学習のために使用する、Estimatorを設定します。

od_model = sagemaker.estimator.Estimator(training_image,
                                         role, 
                                         train_instance_count=1, 
                                         train_instance_type='ml.p3.2xlarge',
                                         train_volume_size = 50,
                                         train_max_run = 360000,
                                         input_mode= 'File',
                                         output_path=s3_output_location,
                                         sagemaker_session=sess)

ハイパーパラメータを設定します。結構いっぱいあります。 ハイパーパラメータに関する詳細はドキュメントをご確認ください。

エポック数で学習データを何回使い回すかを設定できます。デフォルトでは1になっていましたが、10と30のケースを試してみました。

od_model.set_hyperparameters(base_network='resnet-50',
                             use_pretrained_model=1,
                             num_classes=20,
                             mini_batch_size=32,
                             epochs=10,
                             learning_rate=0.001,
                             lr_scheduler_step='3,6',
                             lr_scheduler_factor=0.1,
                             optimizer='sgd',
                             momentum=0.9,
                             weight_decay=0.0005,
                             overlap_threshold=0.5,
                             nms_threshold=0.45,
                             image_shape=300,
                             label_width=350,
                             num_training_samples=16551)

学習処理を実行します。 エポック数=10で30分弱程度、エポック数=30で80分程度かかりました。GPUインスタンスはそれなりにお金がかかるので、回しすぎには気をつけましょう。

od_model.fit(inputs=data_channels, logs=True)

マネジメントコンソールのSageMakerの該当するトレーニングジョブの詳細ページから対応するログを見にいけます。 学習が進むごとにスコアが変化していくのが分かり、面白いです。

モデルの展開

エンドポイントを作成し、先ほど学習したモデルを展開します。

object_detector = od_model.deploy(initial_instance_count = 1,
                                 instance_type = 'ml.m4.xlarge')

モデルの確認

確認のために、画像を一つだけ落としてきます。

!wget -O test.jpg https://images.pexels.com/photos/980382/pexels-photo-980382.jpeg
file_name = 'test.jpg'

with open(file_name, 'rb') as image:
    f = image.read()
    b = bytearray(f)
    ne = open('n.txt','wb')
    ne.write(b)

エンドポイントにデータを投げて検出結果を受け取って表示してみます。

import json

object_detector.content_type = 'image/jpeg'
results = object_detector.predict(b)
detections = json.loads(results)
print (detections)

文字だけ見ても何もわからないので、画像を表示して見ます。

def visualize_detection(img_file, dets, classes=[], thresh=0.6):
        """
        visualize detections in one image
        Parameters:
        ----------
        img : numpy.array
            image, in bgr format
        dets : numpy.array
            ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...])
            each row is one object
        classes : tuple or list of str
            class names
        thresh : float
            score threshold
        """
        import random
        import matplotlib.pyplot as plt
        import matplotlib.image as mpimg

        img=mpimg.imread(img_file)
        plt.imshow(img)
        height = img.shape[0]
        width = img.shape[1]
        colors = dict()
        for det in dets:
            (klass, score, x0, y0, x1, y1) = det
            if score < thresh:
                continue
            cls_id = int(klass)
            if cls_id not in colors:
                colors[cls_id] = (random.random(), random.random(), random.random())
            xmin = int(x0 * width)
            ymin = int(y0 * height)
            xmax = int(x1 * width)
            ymax = int(y1 * height)
            rect = plt.Rectangle((xmin, ymin), xmax - xmin,
                                 ymax - ymin, fill=False,
                                 edgecolor=colors[cls_id],
                                 linewidth=3.5)
            plt.gca().add_patch(rect)
            class_name = str(cls_id)
            if classes and len(classes) > cls_id:
                class_name = classes[cls_id]
            plt.gca().text(xmin, ymin - 2,
                            '{:s} {:.3f}'.format(class_name, score),
                            bbox=dict(facecolor=colors[cls_id], alpha=0.5),
                                    fontsize=12, color='white')
        plt.show()

では先ほど定義した関数を使って画像上にラベル等を描画します。

object_categories = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 
                     'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 
                     'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']

# Setting a threshold 0.20 will only plot detection results that have a confidence score greater than 0.20.
threshold = 0.20

# Visualize the detections.
visualize_detection(file_name, detections['prediction'], object_categories, threshold)

まずはエポック数=10の場合です。

すでにバイクと人がだいたい検出できてますが、少し人とバイクの検出が漏れています。 次はエポック数=30の場合です。

少し右端のバイクの検出が失敗していますが、それ以外はだいたい検出できました。

エンドポイントの削除

余分なお金を使わないように、エンドポイントを削除します。

sagemaker.Session().delete_endpoint(object_detector.endpoint)

まとめ

Amazon SageMakerの組み込みアルゴリズムの一つである物体検出を用いることで、画像から人とバイクの検出ができました。 色々な物の検出を行おうとすると、それなりに時間はかかりますが、例となるコードを実行していくだけで、簡単に検出ができるのは良いですね。 ただ、エポック数を増やしすぎたり、エンドポイントを放置といったことはしないように気をつけましょう。お金が消えていってしまいます。

以下シリーズではAmazon SageMakerのその他の組み込みアルゴリズムについても解説しています。宜しければ御覧ください。