Amazon SageMakerで複数のネジの検出と分類をやってみた

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

こんにちは、大澤です。こちらのエントリで行われたネジ画像の分類を応用して、複数のネジの検出と分類をやってみました。

目次

概要

今回は「複数のネジが写った画像からネジを検出し、検出したネジの画像を画像分類器に入れることでネジの種類を特定する」ということをAmazon SageMakerを使ってやってみました。

物体検出でも画像分類のアルゴリズムを使っているので、物体検出だけでも実現できそうですが、画像分類を使った方が精度良く、細かい分類ができるのではないかという考えのもとで、試してみました。

やってみた

学習用データの作成

今回学習データに使う元となる画像はなべネジ、皿ネジ、蝶ネジがそれぞれ単体で写った画像3枚ずつ、3種類のネジが同時に写った画像が4枚の計13枚です。 そこからこちらのエントリと同じ方法で、画像増幅し、アノテーションを作成し、アノテーションをjson化するといった順序で学習用データを作成し、S3に保存します。

今回は13枚の画像から画像増幅によって130枚分の学習データを作成しました。

物体検出のモデルを学習

インスタンスの持っているロールの取得と、 学習データを保存しているバケット名とオブジェクトの接頭辞の設定を行います。

%%time
import sagemaker
from sagemaker import get_execution_role

role = get_execution_role()

sess = sagemaker.Session()
bucket = '<bucket-name>' # custom bucket name.
prefix = '<object-prefix>'

物体検出の学習用コンテナ名を取得します。

from sagemaker.amazon.amazon_estimator import get_image_uri

training_image = get_image_uri(sess.boto_region_name, 'object-detection', repo_version="latest")

S3に保存しているファイルのパスを指定します。

%%time

train_channel = prefix + '/train'
validation_channel = prefix + '/validation'
train_annotation_channel = prefix + '/train_annotation'
validation_annotation_channel = prefix + '/validation_annotation'


s3_train_data = 's3://{}/{}'.format(bucket, train_channel)
s3_validation_data = 's3://{}/{}'.format(bucket, validation_channel)
s3_train_annotation = 's3://{}/{}'.format(bucket, train_annotation_channel)
s3_validation_annotation = 's3://{}/{}'.format(bucket, validation_annotation_channel)

学習時に出力されるアーティファクトの保存先を指定。

s3_output_location = 's3://{}/{}/output'.format(bucket, prefix)

学習時に必要な設定を行う。

od_model = sagemaker.estimator.Estimator(training_image,
                                         role, 
                                         train_instance_count=1, 
                                         train_instance_type='ml.p3.2xlarge',
                                         train_volume_size = 50,
                                         train_max_run = 1800,
                                         input_mode = 'File',
                                         output_path=s3_output_location,
                                         sagemaker_session=sess)

ハイパーパラメータを設定します。 各ハイパーパラメータの内容についてはドキュメントをご確認ください。

od_model.set_hyperparameters(base_network='resnet-50',
                             use_pretrained_model=1,
                             num_classes=1,
                             mini_batch_size=10,
                             epochs=50,
                             learning_rate=0.001,
                             lr_scheduler_step='10',
                             lr_scheduler_factor=0.1,
                             optimizer='sgd',
                             momentum=0.9,
                             weight_decay=0.0005,
                             overlap_threshold=0.5,
                             nms_threshold=0.45,
                             image_shape=512,
                             label_width=350,
                             num_training_samples=117)

学習用データ、validation用データを指定。

train_data = sagemaker.session.s3_input(s3_train_data, distribution='FullyReplicated', 
                        content_type='image/jpeg', s3_data_type='S3Prefix')
validation_data = sagemaker.session.s3_input(s3_validation_data, distribution='FullyReplicated', 
                             content_type='image/jpeg', s3_data_type='S3Prefix')
train_annotation = sagemaker.session.s3_input(s3_train_annotation, distribution='FullyReplicated', 
                             content_type='image/jpeg', s3_data_type='S3Prefix')
validation_annotation = sagemaker.session.s3_input(s3_validation_annotation, distribution='FullyReplicated', 
                             content_type='image/jpeg', s3_data_type='S3Prefix')

data_channels = {'train': train_data, 'validation': validation_data, 
                 'train_annotation': train_annotation, 'validation_annotation':validation_annotation}

学習!

od_model.fit(inputs=data_channels, logs=True)

下のような感じでログが出力されます。

モデルの展開

#先ほど学習したモデルを展開
object_detector = od_model.deploy(initial_instance_count = 1, instance_type = 'ml.m4.xlarge')

# 事前に学習してある画像分類のモデルを既存のエンドポイントコンフィグを使って展開
endpoint_name = "endpoint_name" 
endpoint_config_name = "endpoint_config_name"
screw_classifier_end = sess.create_endpoint(endpoint_name, endpoint_config_name)

エンドポイントへのリクエスト時のデータ型とレスポンスのデシリアライザーを指定します。

from sagemaker.predictor import json_deserializer
object_detector.content_type = 'image/jpeg'
object_detector.deserializer = json_deserializer

screw_classifier = sagemaker.predictor.RealTimePredictor(screw_classifier_end, sess, content_type='image/jpeg', deserializer=json_deserializer)

ネジの検出と分類

確認用の画像を読み込みます。


file_name = 'test.jpg'

with open(file_name, 'rb') as image:
    f = image.read()
    b = bytearray(f)
    ne = open('n.txt','wb')
    ne.write(b)

まずはネジの検出を行います。

detections = object_detector.predict(b)

ネジの分類と結果の可視化のための処理を作成します。

def resize_img(pil_img, width, height):
    """
    画像サイズを変更します。足りない領域は黒塗りされます。

    Parameters:
    ----------
    pil_img : PIL.Image
        screw image
    width : int
    height : int
    """
    pil_img.thumbnail((width, height))
    pil_img = pil_img.crop((0, 0, width, height))
    return pil_img

def classify_screw(pil_img, screw_classes, screw_classifier):
    """
    ネジを分類します
    Parameters:
    ----------
    pil_img : PIL.Image
        screw image
    screw_classes : tuple or list of str
        class names
    """
    import io
    import numpy as np
    from PIL import Image
    import time
    

    pil_img = resize_img(pil_img, 224, 224)
    pil_img.save("classify_img/"+ str(time.time()) + ".jpg")
    
    # byte形式に変換
    img_byte = io.BytesIO()
    pil_img.save(img_byte, format='JPEG')
    img_byte = img_byte.getvalue()
    
    # 種類判別をするために画像分類器に投げる
    response = screw_classifier.predict(img_byte)
    
    # 確率が一番高いものをその種類とする
    screw_id = np.argmax(response)
    screw_name = str(screw_id)
    if screw_id < len(screw_classes):
        screw_name = screw_classes[screw_id]
    return screw_name, response[screw_id]





def classify_and_visualize_detection(img_file, dets, object_classes, screw_classes, screw_classifier, thresh=0.6):
        """
        検出したものを分類し、画像に描画します
        Parameters:
        ----------
        img : numpy.array
            image, in bgr format
        dets : numpy.array
            ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...])
            each row is one object
        object_classes : tuple or list of str
            class names
        screw_classes : tuple or list of str
            class names
        thresh : float
            score threshold
        """
        import random
        import matplotlib.pyplot as plt
        import matplotlib.image as mpimg

        
        img=mpimg.imread(img_file)
        plt.imshow(img)
        height = img.shape[0]
        width = img.shape[1]
        colors = dict()
        
        # load image for image classification
        from PIL import Image
        pilImg = Image.open(file_name)
        
        for i, det in enumerate(dets):
            (klass, score, x0, y0, x1, y1) = det
            if score < thresh:
                continue
            cls_id = int(klass)
            if cls_id not in colors:
                colors[cls_id] = (random.random(), random.random(), random.random())
            xmin = int(x0 * width)
            ymin = int(y0 * height)
            xmax = int(x1 * width)
            ymax = int(y1 * height)
            

            class_name = str(cls_id)
            screw_prob = -1
            if object_classes and len(object_classes) > cls_id:
                class_name = object_classes[cls_id]
                
                # ネジの場合はネジの種類を分類
                if class_name == 'screw':
                    cropImg = pilImg.crop((xmin, ymin, xmax, ymax))
                    class_name, screw_prob = classify_screw(cropImg, screw_classes, screw_classifier)
            
            #  分類によって色を分ける
            if class_name not in colors:
                colors[class_name] = (random.random(), random.random(), random.random())

            # 枠を表示
            rect = plt.Rectangle((xmin, ymin), xmax - xmin,
                                 ymax - ymin, fill=False,
                                 edgecolor=colors[cls_id],
                                 linewidth=3.5)
            plt.gca().add_patch(rect)
            
            # 名称と確率を表示 (名称 該当する物体の種類に判定された確率 ネジの場合はその種類に分類された確率)
            label = '{:s} {:.3f}'.format(class_name, score)
            if screw_prob > -1:
                label += ' {:.3f}'.format(screw_prob)
                
            plt.gca().text(xmin, ymin - 2,
                            label,
                            bbox=dict(facecolor=colors[class_name], alpha=0.5),
                                    fontsize=12, color='white')
            
        plt.show()

ネジの分類と結果の描画を行います。

# 物体検出時のラベル
object_categories = ['screw']

# 画像分類時のラベル(ネジの種類)
screw_categories = ['saraneji', 'tyouneji', 'nabeneji']

# thresholdより低い判定率のものは表示しない
threshold = 0.4

# 検出結果を描画(内部でネジの種類の特定も行う)
classify_and_visualize_detection(file_name, detections['prediction'], object_categories, screw_categories, screw_classifier, threshold)

こんな感じで出力されます。 検出されたネジの周りを枠で覆い、枠の上にネジの種類 検出時のネジだという判定確率 分類時のその種類だとする確率を表示しています。

ネジの検出はできています。蝶ネジは分類されましたが、皿ネジとなべネジがうまく分けられてないですね。

もう一枚。

今回もネジの検出自体はうまくっていますが、皿ネジやなべネジまで蝶ネジだと判定されてしまっています。 画像分類がうまくっていないようです。

次はちょっとネジが固まっている場合です。

今度はネジの検出がうまくいっていません。

最後に一枚。

背景に少し柄があっても、ネジの検出自体はできました。ただ、蝶ネジにも関わらず皿ネジとして分類されています。

エンドポイントを削除

不要になったエンドポイントは削除しましょう。

sagemaker.Session().delete_endpoint(object_detector.endpoint)
sagemaker.Session().delete_endpoint(screw_classifier.endpoint)

まとめ

複数のネジの映った画像からネジの種類を判定することを目標に進めました。結果は次の通りです。

ネジの検出

概ね検出できたが、ネジが固まっていると正しく検出できなかった。

ネジが固まっている場合の画像を学習データに入れることで、改善できるかもしれません。

ネジの分類

一部は上手く分類できたが、多くの場合は正しく分類できなかった。

分類器の学習時に、同じ背景の画像のみを学習データとして使っていたので、「背景とネジと余白」がネジの特徴として学習されてしまったんではないかと思います。
結果、今回少し異なる背景であったり、ネジの周りの空間を検出時に切り取ったことによる余白の減少によって、分類が上手くいかなかったと考えられます。
ゆえに、分類器のモデルの学習データに異なる背景のネジや拡大したネジなどを含ませることで幾分かの改善を見込めると思います。

最後に

結果としては、成功とは言い難い結果でしたが、ネジの検出は意外と上手くいくことや分類器の欠点などが分かり、良い実験だったと言えます。

改善の余地はあると思うので、また試してみたいと思います。

参考