[Amazon SageMaker] JumpStartでObjectDetectionモデル(SSD MobileNet 1.0)をFine-tuneしてみました。

2021.07.27

1 はじめに

IoT事業部の平内(SIN)です。

Amazon SageMake(以下、SageMaker)のJumpStartでは、TensorFlow HubPyTorch Hubのモデルをクリックのみでデプロイして利用できます。

また、一部のモデルでは、Fine-tuneによって、独自に用意したデータセットでモデルを再構築することも可能です。

これまで、Vision関連のモデルで Fine-tune に対応していたのは、Image Classification のものだけでしたが、今回、Object Detectionでも、対応しているモデルができたので、とりあえず試してみました。

2 データセット

使用したモデルはSSD MobileNet 1.0です。

モデルを選択すると、一番下にFine-tuneのためのデータセットについて記載があります。

ドキュメントによると、データセットは、S3バケットに配置し、その構造は以下のとおりです。

input_directory
    |--images
        |--abc.png
        |--def.png
    |--annotations.json

annotations.json

"images":[
    {
        "file_name": image_name,
        "height": height,
        "width": width,
        "id": image_id
    }
],
"annotations":[
    {
        "image_id": image_id,
        "bbox": [xmin, ymin, xmax, ymax],
        "category_id": bbox_label
    }
]

データをS3に配置している様子です。

なお、配置するS3バケットは、SageMaker Studioが動作するリージョンのものしか利用できないので注意が必要です。

3 学習

学習は、データセットを配置したバケットを指定するだけで開始できます。

学習用のインスタンスや、ハイパーパラメータは、デフォルトのまま作業を進めました。

4 デプロイ

学習が終わると、Deployのボタンが選択可能になります。

5 推論

画像をアップロードし推論してみました。

コードは以下の通りです。

import matplotlib.patches as patches
from matplotlib import pyplot as plt
from PIL import Image
from PIL import ImageColor

img_jpg = "test001.png"
with open(img_jpg, 'rb') as file: input_img = file.read()
best_results_per_input = query_endpoint(input_img)  

colors = ['#00ffff','#ff8800']
image_np = np.array(Image.open(img_jpg))
plt.figure(figsize=(20,20))
ax = plt.axes()
ax.imshow(image_np)
bboxes, classes, confidences = best_results_per_input
for idx in range(len(bboxes)):
    if(confidences[idx]>0.5):
        left, bot, right, top = bboxes[idx]
        x, w = [val * image_np.shape[1] for val in [left, right - left]]
        y, h = [val * image_np.shape[0] for val in [bot, top - bot]]
        color = colors[hash(classes[idx]) % len(colors)]
        rect = patches.Rectangle((x, y), w, h, linewidth=3, edgecolor=color, facecolor='none')
        ax.add_patch(rect)
        ax.text(x, y, "{} {:.0f}%".format(classes[idx], confidences[idx]*100), bbox=dict(facecolor='white', alpha=1))

6 データセットのコンバート

Amazon SageMaker Ground Truthで、物体検出用のデータセットを作成すると、アノテーションデータは、output.manifestというファイルに出力されますが、今回、これをannotations.jsonに変換して使用しました。

雑で恐縮ですが、参考に、変換したコードを置きます。

import json
class Converter:
    def __init__(self) -> None:
        self.__images = []
        self.__annotations = []

    @property
    def data(self):
        return {
            "images":self.__images,
            "annotations": self.__annotations
        }

    def append(self, line):
        dic = json.loads(line) 

        tmp = dic["source-ref"].split('/')
        file_name = tmp[len(tmp)-1]
        file_name = file_name.replace("jpg","png")
        
        boxlabel = dic["boxlabel"]
        size = boxlabel["image_size"][0]
        width = size['width']
        height = size['height']
        
        image_id = len(self.__images)
        self.__images.append({
                "file_name": file_name, 
                "height": height, 
                "width": width,
                "id": image_id
            }
        )
        
        for an in boxlabel["annotations"]:
            category_id = an["class_id"]
            width = an["width"]
            top = an["top"]
            height = an["height"]
            left = an["left"]
            self.__annotations.append(
                {
                    "image_id": image_id,
                    "bbox":[left, top, left+width, top+height],
                    "category_id": category_id
                }
            )
        pass

converter = Converter()

f = open('output.manifest', 'r')
line = f.readline()

while line:
    line = f.readline()
    if(line!=''):
        converter.append(line)
f.close()

with open('annotations.json', mode='wt', encoding='utf-8') as file:
  json.dump(converter.data, file)

7 Delete Endpoint

推論 用に立ち上げたエンドポイントは、自動的には止まりませんので、ご注意ください。

8 最後に

今回、JumpStartでObject DetectionのモデルをFine-tuneしてみました。今後、色々なモデルが対応されるのを楽しみにしています。

追伸:Epochs 3なのに、やけに学習時間がかかるなぁとログを見ていたら、OpenCVのコンパイルの時間でした。