[Amazon SageMaker] 画像合成によるデータセット作成時における背景の扱いについて

2021.01.23

1 はじめに

CX事業本部の平内(SIN)です。

下記では、冷蔵庫内を上からのアングルで見た映像で、商品を検出しています。

今回は、ここで作成したモデルについての紹介です。

最初に動画をご確認下さい。前段は、データセットを作成しているようす、そして、後段は、冷蔵庫内のカメラで商品を検出している場面です。

2 単一商品のデータセット

実は、当初、データ作成も容易なので、1枚の商品画像を背景と合成することでデータセットを作成しました。

しかし、このデータセットで作成したモデルは、商品の間隔が空いている時は、問題なく検出できるのですが

商品の間隔が狭くなると、途端に検出できなくなる問題がありました。

3 複数商品のデータセット

先のデータセットでは、常に商品の周りに背景が写っており、これを一緒に学習してしまっています。実際の場面で、商品の間隔が詰まった時、商品の周りの背景は確保されなくなるため、これが検出できない原因だと考えられます。

そこで、複数の商品を背景に合成してデータセットを作成するようにしました。 ランダムに商品画像を取り出し、完全に重ならないように配置して1つのデータとしています。

4 コード

1画像ずつ作成する段階で、データを保持するクラス(Data)に商品を追加する際、既に貼られている画像との重複度合いを確認し、一定以上の重複がある場合は、これを排除するようになっています。

また、データセットの画像に含まれる商品の分布がある程度平均化するように、Counterクラスでインデックスを調整しています。

"""
下記を参考にさせて頂きました。
https://github.com/aws-samples/smart-cooler/blob/master/_ml_model_package/synthetic-dataset.ipynb
"""

import json
import glob
import random
import os
import shutil
import math
import numpy as np
import cv2
from PIL import Image

MAX = 3000 # 生成する画像数

CLASS_NAME=["jagabee","chipstar","butamen","kyo_udon","koara","curry"]
COLORS = [(0,0,175),(175,0,0),(0,175,0),(175,175,0),(0,175,175),(175,175,175)]
SIZE=[150, 150, 150 ,185, 150, 185]

BACK_PATH = "./backgrounds"
PRODUCT_PATH = "./products"
OUTPUT_PATH = "./output"
S3Bucket = "s3://ground_truth_dataset"
manifestFile = "output.manifest"

def euler_to_mat(yaw, pitch, roll):
    c, s = math.cos(yaw), math.sin(yaw)
    M = np.matrix([[  c, 0.,  s], [ 0., 1., 0.], [ -s, 0.,  c]])

    c, s = math.cos(pitch), math.sin(pitch)
    M = np.matrix([[ 1., 0., 0.], [ 0.,  c, -s], [ 0.,  s,  c]]) * M

    c, s = math.cos(roll), math.sin(roll)
    M = np.matrix([[  c, -s, 0.], [  s,  c, 0.], [ 0., 0., 1.]]) * M
    return M

def make_affine_transform(from_shape, to_shape, 
                        min_scale, max_scale,
                        scale_variation=1.0,
                        rotation_variation=1.0,
                        translation_variation=1.0):

    from_size = np.array([[from_shape[1], from_shape[0]]]).T
    to_size = np.array([[to_shape[1], to_shape[0]]]).T

    scale = random.uniform((min_scale + max_scale) * 0.5 -
                            (max_scale - min_scale) * 0.5 * scale_variation,
                            (min_scale + max_scale) * 0.5 +
                            (max_scale - min_scale) * 0.5 * scale_variation)
    roll = random.uniform(-1.0, 1.0) * rotation_variation
    pitch = random.uniform(-0.15, 0.15) * rotation_variation
    yaw = random.uniform(-0.15, 0.15) * rotation_variation

    M = euler_to_mat(yaw, pitch, roll)[:2, :2]
    h = from_shape[0]
    w = from_shape[1]
    corners = np.matrix([[-w, +w, -w, +w],
                            [-h, -h, +h, +h]]) * 0.5
    skewed_size = np.array(np.max(M * corners, axis=1) -
                                np.min(M * corners, axis=1))

    scale *= np.min(to_size / skewed_size)

    trans = (np.random.random((2,1)) - 0.5) * translation_variation
    trans = ((2.0 * trans) ** 5.0) / 2.0
    trans = (to_size - skewed_size * scale) * trans

    center_to = to_size / 2.
    center_from = from_size / 2.

    M = euler_to_mat(yaw, pitch, roll)[:2, :2]
    M *= scale
    M = np.hstack([M, trans + center_to - M * center_from])
    return M

# アフィン変換
def transform(backImage, productImage, productSize):
    M = make_affine_transform(
                        from_shape=productImage.shape,
                        to_shape=backImage.shape,
                        min_scale=1.0,
                        max_scale=1.0,
                        rotation_variation=3.5,
                        # scale_variation=2.0,
                        scale_variation=1.0,
                        translation_variation=0.98)
    object_topleft = tuple(M.dot(np.array((0, 0) + (1, ))).tolist()[0])
    object_topright = tuple(M.dot(np.array((productSize, 0) + (1,))).tolist()[0])
    object_bottomleft = tuple(M.dot(np.array((0,productSize) + (1,))).tolist()[0])
    object_bottomright = tuple(M.dot(np.array((productSize, productSize) + (1,))).tolist()[0])

    object_tups = (object_topleft, object_topright, object_bottomleft, object_bottomright)
    object_xmin = (min(object_tups, key=lambda item:item[0])[0])
    object_xmax = (max(object_tups, key=lambda item:item[0])[0]) 
    object_ymin = (min(object_tups, key=lambda item:item[1])[1])
    object_ymax = (max(object_tups, key=lambda item:item[1])[1])
    rect = ((int(object_xmin),int(object_ymin)),(int(object_xmax),int(object_ymax)))    

    productImage =  cv2.warpAffine(productImage, M, (backImage.shape[1], backImage.shape[0]))
    return productImage, rect


# 背景と商品の合成
def margeImage(backImg, productImg):
    # PIL形式で重ねる
    back_pil = Image.fromarray(backImg)
    product_pil = Image.fromarray(productImg)
    back_pil.paste(product_pil, (0, 0), product_pil)
    return np.array(back_pil)

# エフェクト(Gauss)
def addGauss(img, level):
    return cv2.blur(img, (level * 2 + 1, level * 2 + 1))

# エフェクト(Noise)
def addNoiseSingleChannel(single):
    diff = 255 - single.max()
    noise = np.random.normal(0, random.randint(1, 100), single.shape)
    noise = (noise - noise.min())/(noise.max()-noise.min())
    noise= diff*noise
    noise= noise.astype(np.uint8)
    dst = single + noise
    return dst

# エフェクト(Noise)
def addNoise(img):
    img = img.astype('float64')
    img[:,:,0] = addNoiseSingleChannel(img[:,:,0])
    img[:,:,1] = addNoiseSingleChannel(img[:,:,1])
    img[:,:,2] = addNoiseSingleChannel(img[:,:,2])
    return img.astype('uint8')

# バウンディングボックス描画
def box(frame, rect, class_id):
    ((x1,y1),(x2,y2)) = rect
    label = "{}".format(CLASS_NAME[class_id])
    img = cv2.rectangle(frame,(x1, y1), (x2, y2), COLORS[class_id],2)
    img = cv2.rectangle(img,(x1, y1), (x1 + 150,y1-20), COLORS[class_id], -1)
    cv2.putText(img,label,(x1+2, y1-2), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 1, cv2.LINE_AA)
    return img

# Manifest生成クラス
class Manifest:
    def __init__(self, class_name):
        self.__lines = ''
        self.__class_map={}
        for i in range(len(class_name)):
            self.__class_map[str(i)] = class_name[i]

    def appned(self, fileName, data, height, width):

        date = "0000-00-00T00:00:00.000000"
        line = {
            "source-ref": "{}/{}".format(S3Bucket, fileName),
            "boxlabel": {
                "image_size": [
                    {
                        "width": width,
                        "height": height,
                        "depth": 3
                    }
                ],
                "annotations": []
            },
            "boxlabel-metadata": {
                "job-name": "xxxxxxx",
                "class-map": self.__class_map,
                "human-annotated": "yes",
                "objects": {
                    "confidence": 1
                },
                "creation-date": date,
                "type": "groundtruth/object-detection"
            }
        }
        for i in range(data.max()):
            (_, rect, class_id) = data.get(i)
            ((x1,y1),(x2,y2)) = rect
            line["boxlabel"]["annotations"].append({
                "class_id": class_id,
                "width": x2 - x1,
                "top": y1,
                "height": y2 - y1,
                "left": x1
            })
        self.__lines += json.dumps(line) + '\n'

    def get(self):
        return self.__lines

# 背景画像生成クラス
class Backgrounds:
    def __init__(self, backPath):
        self.__backPath = backPath

    def get(self):
        imagePath = random.choice(glob.glob(self.__backPath + '/*.png'))
        return cv2.imread(imagePath, cv2.IMREAD_UNCHANGED) 

# 商品画像生成クラス
class Products:
    def __init__(self, productPath, class_name, size):
        self.__productPath = productPath
        self.__class_name = class_name
        self.__size = size

    def get(self, class_id):
        # 商品画像
        class_name = self.__class_name[class_id]
        image_path = random.choice(glob.glob(self.__productPath + '/' + class_name + '/*.png'))
        product_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) 

        # 商品画像のサイズ
        size = self.__size[class_id]
        return (self.__resize(product_image, size), size, class_id)

    # 商品画像のサイズ調整
    def __resize(self, img, size):
        org_h, org_w = img.shape[:2]
        imageArray = np.zeros((org_h, org_w, 4), np.uint8)
        img = cv2.resize(img, (size, size))
        imageArray[0:size, 0:size] = img
        return imageArray

# 1画像分のデータを保持するクラス
class Data:
    def __init__(self, rate):
        self.__rects = []
        self.__images = []
        self.__class_ids = []
        self.__rate = rate

    def get_class_ids(self):
        return self.__class_ids

    def max(self):
        return len(self.__rects)

    def get(self, i):
        return (self.__images[i], self.__rects[i], self.__class_ids[i])

    # 追加(重複率が指定値以上の場合は失敗する)
    def append(self, productImage, rect, class_id):
        conflict = False
        for i in range(len(self.__rects)):
            iou = self.__multiplicity(self.__rects[i], rect)
            if(iou > self.__rate):
                conflict = True
                break
        if(conflict == False):  
            self.__rects.append(rect)
            self.__images.append(productImage)
            self.__class_ids.append(class_id)
            return True
        return False

    # 重複率
    def __multiplicity(self, a, b):
        (ax_mn, ay_mn) = a[0]
        (ax_mx, ay_mx) = a[1]
        (bx_mn, by_mn) = b[0]
        (bx_mx, by_mx) = b[1]
        a_area = (ax_mx - ax_mn + 1) * (ay_mx - ay_mn + 1)
        b_area = (bx_mx - bx_mn + 1) * (by_mx - by_mn + 1)
        abx_mn = max(ax_mn, bx_mn)
        aby_mn = max(ay_mn, by_mn)
        abx_mx = min(ax_mx, bx_mx)
        aby_mx = min(ay_mx, by_mx)
        w = max(0, abx_mx - abx_mn + 1)
        h = max(0, aby_mx - aby_mn + 1)
        intersect = w*h
        return intersect / (a_area + b_area - intersect)

# 各クラスのデータ数が同一になるようにカウントする
class Counter():
    def __init__(self, max):
        self.__counter = np.zeros(max)

    def get(self):
        n = np.argmin(self.__counter)
        self.__counter[n] += 1
        return int(n)

    def print(self):
        print(self.__counter)

def main():

    # 出力先の初期化
    if os.path.exists(OUTPUT_PATH):
        shutil.rmtree(OUTPUT_PATH)
    os.mkdir(OUTPUT_PATH)

    backgrounds = Backgrounds(BACK_PATH)
    products = Products(PRODUCT_PATH, CLASS_NAME, SIZE)
    manifest = Manifest(CLASS_NAME)

    counter = Counter(len(CLASS_NAME))
    no = 0

    while(True):
        # 背景画像の取得
        backImage = backgrounds.get()

        # 商品データ
        data = Data(0.15)
        for _ in range(20):
            class_id = counter.get()
            # class_id = random.randint(0, len(CLASS_NAME)-1)
            # 商品画像の取得
            product_image, product_size, class_id  = products.get(class_id)
            # アフィン変換
            product_image, rect =  transform(backImage, product_image, product_size)
            # 商品の追加(重複した場合は、失敗する)
            data.append(product_image, rect, class_id)

        frame = backImage
        for index in range(data.max()):
            (product_image, _, _) = data.get(index)
            # 合成
            frame = margeImage(frame, product_image)

        # アルファチャンネル削除
        frame = cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)

        # エフェクト(gauss)
        frame = addGauss(frame, random.randint(0, 2))

        # エフェクト(Noise)
        frame = addNoise(frame)

        # 画像名
        fileName = "{:04d}.png".format(no)
        no+=1

        # 画像保存            
        cv2.imwrite("{}/{}".format(OUTPUT_PATH, fileName), frame)
        # manifest追加
        manifest.appned(fileName, data, frame.shape[0], frame.shape[1])

        for i in range(data.max()):
            (_, rect, class_id) = data.get(i)
            # バウンディングボックス描画(確認用)
            frame = box(frame, rect, class_id)

        counter.print()
        if(MAX < no):
            break

        # 表示(確認用)
        cv2.imshow("frame", frame)
        cv2.waitKey(1)

    # manifest 保存
    with open('{}/{}'.format(OUTPUT_PATH, manifestFile), 'w') as f:
        f.write(manifest.get())

main()

5 RecordIO形式

上記で作成したAmazon SageMaker Ground Truth形式のデータをim2rec.pyを利用してRecordIO形式に変換しているコードです。

curl -O https://raw.githubusercontent.com/apache/incubator-mxnet/master/tools/im2rec.py

データセットに含まれる、商品の分布が予め平均化されているため、単純に学習用と検証用に分割しています。

import json
import os
import subprocess

# 定義
inputPath = '/tmp/input'
outputPath = '/tmp/output'
manifest = 'output.manifest'

CLASS_NAME=["jagabee","chipstar","butamen","kyo_udon","koara","curry"]

# 分割の比率
ratio = 0.8 # 8:2に分割する

# 1件のデータを表現するクラス
class Data():
    def __init__(self, src):
        # プロジェクト名の取得
        for key in src.keys():
            index = key.rfind("-metadata")
            if(index!=-1):
                projectName = key[0:index]

        # メタデータの取得
        metadata = src[projectName + '-metadata']
        class_map = metadata["class-map"]

        # 画像名の取得
        self.__imgFileName = os.path.basename(src["source-ref"])

        # 画像サイズの取得
        project = src[projectName]
        image_size = project["image_size"]
        self.__img_width = image_size[0]["width"]
        self.__img_height = image_size[0]["height"]

        self.__annotations = []
        # アノテーションの取得
        for annotation in project["annotations"]:
            class_id = annotation["class_id"]
            top = annotation["top"]
            left = annotation["left"]
            width = annotation["width"]
            height = annotation["height"]

            self.__annotations.append({
                "label": class_map[str(class_id)],
                "width": width,
                "top": top,
                "height": height,
                "left": left
            })

    @property
    def annotations(self):
        return self.__annotations

    # 指定されたラベルを含むかどうか
    def exsists(self, label):
        for annotation in self.__annotations:
            if(annotation["label"] == label):
                return True
        return False

    # .lstを生成して追加する
    def appendLst(self, lst, cls_list):

        index = len(lst.split('\n'))
        headerSize = 4 # hederSize,dataSize,imageWidth,imageHeight
        dataSize = 5
        str = "{}\t{}\t{}\t{}\t{}".format(index, headerSize, dataSize, self.__img_width, self.__img_height)

        for annotation in self.__annotations:
            cls_id = cls_list.index(annotation["label"])
            left = annotation["left"]
            right = left + annotation["width"]
            top = annotation["top"]
            bottom = top + annotation["height"]

            left = round(left / self.__img_width, 3)
            right = round(right / self.__img_width, 3)
            top = round(top / self.__img_height, 3)
            bottom = round(bottom / self.__img_height, 3)

            str += "\t{}\t{}\t{}\t{}\t{}".format(cls_id, left, top, right, bottom)          
        fileName = self.__imgFileName
        str += "\t{}".format(fileName)
        lst += str + "\n"
        return lst

# 全てのJSONデータを読み込む
def getDataList(inputPath, manifest):
    dataList = []
    with open("{}/{}".format(inputPath, manifest), 'r') as f:
        srcList = f.read().split('\n')
        for src in srcList:
            if(src != ''):
                json_src = json.loads(src)
                dataList.append(Data(json.loads(src)))
    return dataList

def main():

    # 出力先フォルダ生成
    os.makedirs(outputPath, exist_ok=True)

    # 全てのJSONデータを読み込む
    dataList = getDataList(inputPath, manifest)
    total_len = len(dataList)
    train_len = int(total_len * ratio)
    print("全データ: {}件 train: {} validation: {}".format(total_len, train_len, total_len-train_len))

    # .lst形式
    train = ''
    validation = ''

    for i in range(train_len):
        data = dataList.pop()
        train = data.appendLst(train, CLASS_NAME)

    for data in dataList:
        validation = data.appendLst(validation, CLASS_NAME)

    # .lstファイルの生成
    trainLst = "{}/train.lst".format(outputPath)
    validationLst = "{}/validation.lst".format(outputPath)
    with open(trainLst, mode='w') as f:
        f.write(train)
    with open(validationLst, mode='w') as f:
        f.write(validation)

    # im2rec.pyによるRecordIOファイル生成
    # python im2rec.py --pack-label <your_lst_file_name> <your_image_folder>
    im2rec = "{}/im2rec.py".format(os.getcwd())

    cmd = ["python3", im2rec, "--pack-label", "validation.lst", inputPath]
    result = subprocess.run(cmd, cwd=outputPath)
    print(result)

    cmd = ["python3", im2rec, "--pack-label", "train.lst", inputPath]
    result = subprocess.run(cmd, cwd=outputPath)
    print(result)

main()

6 学習

学習は、SageMakerの組み込みアルゴリズム(物体検出)を使用していますが、参考のため、使用したパラメータを残しておきます。

epochs = 10
num_classes = 6
num_training_samples = 2400
od_model.set_hyperparameters(base_network='resnet-50',
                             use_pretrained_model=1,
                             num_classes=num_classes,
                             mini_batch_size=32,
                             epochs=epochs,
                             learning_rate=0.001,
                             lr_scheduler_step='3,6',
                             lr_scheduler_factor=0.1,
                             optimizer='sgd',
                             momentum=0.9,
                             weight_decay=0.0005,
                             overlap_threshold=0.5,
                             nms_threshold=0.45,
                             image_shape=512,
                             label_width=350,
                             num_training_samples=num_training_samples)

7 最後に

今回は、合成によるデータセット作成での背景の扱いについての紹介でした。 合成も、実際にモデルを使用する場面で、どのような状況になるかを考えて行うことが重要だと思います。