SageMakerで「うまい棒検出モデル」を作ってみた

SageMakerで「うまい棒検出モデル」を作ってみた

Clock Icon2018.08.06

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

うまい棒好き「この写真にはうまい棒が何個写っている??」 AI「コンポタ味が3本、チーズ味が2本、めんたい味が1本写っています。」

このようなシーンを実現したい人がいるかどうかはわかりませんが、SageMakerの物体検出アルゴリズムを使って、このシーンを実現するための学習モデル(うまい棒検出モデル)を作ってみました。 少し長くなりますが「物体検出を手軽に始めたい方」は是非見ていってください。

目次

  • 商品画像取得
  • 画像増幅
  • アノテーション
  • S3へのアップロード
  • SageMakerで学習モデル構築/エンドポイント作成
  • 推論
  • エンドポイント削除

開発環境

  • macOS High Sierra 10.13.6
  • Python 3.6.3
  • Pillow 5.2.0
  • awscli 1.15

商品画像取得

3種類のうまい棒を検出します。そのために今回は40枚の画像を用意しました。

  • うまい棒が1〜2本写っている画像×10枚(角度や向きを変えたのも)/種類毎
  • 複数の種類のうまい棒が写っている画像×10枚

こんな画像です。

この画像をumaibo/inディレクトリに保存します。

umaibo $ tree
.
└── in
    ├── 0001.jpg
    ├── ・・
    └── 0040.jpg

画像ファイルは連番になっていた方が都合が良いので、なっていない場合は以下のコマンドで一括変換します。

ls *.jpg | awk '{ printf "mv %s %04d.jpg\n", $0, NR }' | sh

画像増幅

元画像が少ないため画像を増幅します。umaibo/outディレクトリを作成し、umaibo配下で以下のスクリプトを実行します。

import os
from PIL import Image, ImageFilter

def main():
    data_dir_path = './out/'
    data_dir_path_in = './in/'
    file_list = os.listdir('./in/')

    count = 1
    for file_name in file_list:
        root, ext = os.path.splitext(file_name)
        if ext == '.png' or '.jpeg' or '.jpg':
            img = Image.open(data_dir_path_in + '/' + file_name) 
            tmp = img.transpose(Image.FLIP_LEFT_RIGHT)
            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
            count+=1
            tmp = img.transpose(Image.FLIP_TOP_BOTTOM)
            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
            count+=1
            tmp = img.transpose(Image.ROTATE_90)
            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
            count+=1
            tmp = img.transpose(Image.ROTATE_180)
            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
            count+=1
            tmp = img.transpose(Image.ROTATE_270)
            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
            count+=1
            tmp = img.rotate(15)
            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
            count+=1
            tmp = img.rotate(75)
            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
            count+=1
            tmp = img.rotate(135)
            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
            count+=1          
            tmp = img.rotate(195)
            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
            count+=1
            tmp = img.rotate(255)
            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
            count+=1

if __name__ == '__main__':
    main()

このスクリプトを実行するとumaibo/outディレクトリに元の画像をベースとした400枚の商品画像が作成されます。

umaibo $ tree
.
├── in
│   ├── 0001.jpg
│   ├── ・・
│   └── 0040.jpg
└── out
    ├── 0001.jpg
    ├── ・・
    └── 0400.jpg

参考

アノテーション

次に商品画像の”どこ”に”なに”があるのかラベル付けしていきます。今回はそのためのツールとしてVoTT(Visual Object Tagging Tool)を利用しました。

VoTTをインストールしアプリケーション起動後、umaibo/outディレクトリを指定します。そして、Labelsにconpotamentaicheeseを入力してContinueを選択します。

すると商品画像が表示されるので、”どこ”に”なに”があるかをラベル付けしていきます。

この作業を400回繰り返すとout.jsonが出力されます。(もっと簡単な方法知ってる方教えてください。)出力されたout.jsonはVoTTのフォーマットとなっているため、これをSageMakerが読み込めるフォーマットに変換します。umaibo/jsonディレクトリを作成し、umaibo配下で以下のスクリプトを実行します。

import json

file_name = './out.json'
class_list = {'conpota':0, 'mentai':1, 'cheese':2}

with open(file_name) as f:
    js = json.load(f)

    for k, v in js['frames'].items():

        k = int(k)
        line = {}
        line['file'] = '{0:04d}'.format(k+1) + '.jpg'
        line['image_size'] = [{
            'width':int(v[0]['width']),
            'height':int(v[0]['height']),
            'depth':3
        }]

        line['annotations'] = []

        for annotation in v:

            line['annotations'].append(
                {
                    'class_id':class_list[annotation['tags'][0]],
                    'top':int(annotation['y1']),
                    'left':int(annotation['x1']),
                    'width':int(annotation['x2'])-int(annotation['x1']),
                    'height':int(annotation['y2']-int(annotation['y1']))
                }
            )

        line['categories'] = []
        
        for name, class_id in class_list.items():

            line['categories'].append(
                {
                    'class_id':class_id,
                    'name':name
                }
            )

        f = open('./json/'+'{0:04d}'.format(k+1) + '.json', 'w')
        json.dump(line, f)

このスクリプトを実行するとjsonフォルダに400枚の商品画像に対するjsonファイルが作成されます。

umaibo $ tree
.
├── in
│   ├── 0001.jpg
│   ├── ・・
│   └── 0040.jpg
├── out
│   ├── 0001.jpg
│   ├── ・・
│   └── 0400.jpg
├── out.json
└── json
    ├── 0001.json
    ├── ・・
    └── 0400.json

S3へのアップロード

開発環境にS3アップロード用フォルダを作成します。そして商品画像を学習用データと検証用データに分けそれぞれのフォルダに移動します。

umaibo $ mkdir -p s3/DEMO-ObjectDetection/validation/ s3/DEMO-ObjectDetection/train/ s3/DEMO-ObjectDetection/validation_annotation/ s3/DEMO-ObjectDetection/train_annotation/
umaibo $ mv out/*1.jpg s3/DEMO-ObjectDetection/validation/
umaibo $ mv out/*.jpg s3/DEMO-ObjectDetection/train/
umaibo $ mv json/*1.json s3/DEMO-ObjectDetection/validation_annotation/
umaibo $ mv json/*.json s3/DEMO-ObjectDetection/train_annotation/

次にS3を作成し、画像をS3にアップロードします。S3の名前は任意のものに変更しておいてください。

umaibo $ aws s3 mb s3://hogehoge-bucket --region ap-northeast-1
umaibo $ aws s3 cp ./s3/ s3://hogehoge-bucket/ --recursive

これでSageMakerを実行のための事前準備完了です。

SageMakerで学習モデル構築/エンドポイント作成

SageMakerでノートブックインスタンスを起動します。まだ、SageMakerを利用したことがない方は、以下のステップ2を参考にノートブックインスタンスを作成/起動してください。「指定するS3バケット」には上記で作成したバケットを指定します。

次にJupyter notebook上で学習モデル構築のための処理を実施していきます。基本的には以下のサンプルより必要な箇所だけ抜き出して実行していきます。

まずはJupyter notebookを開きNewよりconda_mxnet_p36を作成します。

次に以下のコマンドを順に実行していきます。

セットアップ

%%time
import sagemaker
from sagemaker import get_execution_role

role = get_execution_role()
print(role)
sess = sagemaker.Session()

バケット指定

作成したバケット名を指定します。

bucket = 'hogehoge-bucket' # custom bucket name.
# bucket = sess.default_bucket()
prefix = 'DEMO-ObjectDetection'

学習用イメージURL取得

from sagemaker.amazon.amazon_estimator import get_image_uri

training_image = get_image_uri(sess.boto_region_name, 'object-detection', repo_version="latest")
print (training_image)

※イメージのバージョンがlatestなので、実行するタイミングによって結果が大きく変わる可能性があります。

入力データ設定

%%time

train_channel = prefix + '/train'
validation_channel = prefix + '/validation'
train_annotation_channel = prefix + '/train_annotation'
validation_annotation_channel = prefix + '/validation_annotation'

s3_train_data = 's3://{}/{}'.format(bucket, train_channel)
s3_validation_data = 's3://{}/{}'.format(bucket, validation_channel)
s3_train_annotation = 's3://{}/{}'.format(bucket, train_annotation_channel)
s3_validation_annotation = 's3://{}/{}'.format(bucket, validation_annotation_channel)
s3_output_location = 's3://{}/{}/output'.format(bucket, prefix)

アルゴリズム設定

od_model = sagemaker.estimator.Estimator(training_image,
                                         role, 
                                         train_instance_count=1, 
                                         train_instance_type='ml.p3.2xlarge',
                                         train_volume_size = 50,
                                         train_max_run = 360000,
                                         input_mode = 'File',
                                         output_path=s3_output_location,
                                         sagemaker_session=sess)

ハイパーパラメータ設定

以下を参考にハイパーパラメータを設定します。

od_model.set_hyperparameters(base_network='resnet-50',
                             use_pretrained_model=1,
                             num_classes=3,
                             mini_batch_size=16,
                             epochs=50,
                             learning_rate=0.001,
                             lr_scheduler_step='10',
                             lr_scheduler_factor=0.1,
                             optimizer='sgd',
                             momentum=0.9,
                             weight_decay=0.0005,
                             overlap_threshold=0.5,
                             nms_threshold=0.45,
                             image_shape=512,
                             label_width=350,
                             num_training_samples=360)

データチャネルとアルゴリズムの間でハンドシェイク

train_data = sagemaker.session.s3_input(s3_train_data, distribution='FullyReplicated', 
                        content_type='image/jpeg', s3_data_type='S3Prefix')
validation_data = sagemaker.session.s3_input(s3_validation_data, distribution='FullyReplicated', 
                             content_type='image/jpeg', s3_data_type='S3Prefix')
train_annotation = sagemaker.session.s3_input(s3_train_annotation, distribution='FullyReplicated', 
                             content_type='image/jpeg', s3_data_type='S3Prefix')
validation_annotation = sagemaker.session.s3_input(s3_validation_annotation, distribution='FullyReplicated', 
                             content_type='image/jpeg', s3_data_type='S3Prefix')

data_channels = {'train': train_data, 'validation': validation_data, 
                 'train_annotation': train_annotation, 'validation_annotation':validation_annotation}

モデル作成

od_model.fit(inputs=data_channels, logs=True)

エンドポイント作成

object_detector = od_model.deploy(initial_instance_count = 1,
                                 instance_type = 'ml.m4.xlarge')

推論

最後に学習データではない画像を使って精度を確認します。検証したいデータをJupyter notebook上にアップロードします。

画像読み込み

画像を読み込みます。file_nameはアップロードしたファイル名に変更してください。

file_name = 'XXXX.jpg'

with open(file_name, 'rb') as image:
    f = image.read()
    b = bytearray(f)
    ne = open('n.txt','wb')
    ne.write(b)

推論

import json

object_detector.content_type = 'image/jpeg'
results = object_detector.predict(b)
detections = json.loads(results)
print (detections)

可視化関数作成

def visualize_detection(img_file, dets, classes=[], thresh=0.6):
        """
        visualize detections in one image
        Parameters:
        ----------
        img : numpy.array
            image, in bgr format
        dets : numpy.array
            ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...])
            each row is one object
        classes : tuple or list of str
            class names
        thresh : float
            score threshold
        """
        import random
        import matplotlib.pyplot as plt
        import matplotlib.image as mpimg

        img=mpimg.imread(img_file)
        plt.imshow(img)
        height = img.shape[0]
        width = img.shape[1]
        colors = dict()
        for det in dets['prediction']:
            (klass, score, x0, y0, x1, y1) = det
            if score < thresh:
                continue
            cls_id = int(klass)
            if cls_id not in colors:
                colors[cls_id] = (random.random(), random.random(), random.random())
            xmin = int(x0 * width)
            ymin = int(y0 * height)
            xmax = int(x1 * width)
            ymax = int(y1 * height)
            rect = plt.Rectangle((xmin, ymin), xmax - xmin,
                                 ymax - ymin, fill=False,
                                 edgecolor=colors[cls_id],
                                 linewidth=3.5)
            plt.gca().add_patch(rect)
            class_name = str(cls_id)
            if classes and len(classes) > cls_id:
                class_name = classes[cls_id]
            plt.gca().text(xmin, ymin - 2,
                            '{:s} {:.3f}'.format(class_name, score),
                            bbox=dict(facecolor=colors[cls_id], alpha=0.5),
                                    fontsize=12, color='white')
        plt.show()

可視化

object_categories = ['conpota', 'mentai', 'cheese']
# Setting a threshold 0.20 will only plot detection results that have a confidence score greater than 0.20.
threshold = 0.2

# Visualize the detections.
visualize_detection(file_name, detections, object_categories, threshold)

概ね正しい結果となっています!!

エンドポイント削除

確認が終わったらエンドポイントを忘れずに削除しておきましょう。

sagemaker.Session().delete_endpoint(object_detector.endpoint)

まとめ

SageMakerのビルトインアルゴリズムを利用すれば簡単に物体検出が始められます。アプリケーションエンジニアやインフラエンジニアで「物体検出でとりあえず動くものを作りたいなー」と思っている方は試してみてください。

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.