話題の記事

機械学習でスーパーマーケットの缶ビールを検出してみました 〜Segment Anythingでセグメンテーションして、YOLOv8の分類モデルで銘柄を判定〜

2023.09.13

1 はじめに

CX 事業本部 delivery部の平内(SIN)です。

Meta社による Segment Anything Model(以下、SAM)は、セグメンテーションのための汎用モデルで、ファインチューニングなしで、あらゆる物体がセグメンテーションできます。

今回は、こちらを使用して、スーパーマーケットで冷蔵庫の缶ビールを検出してみました。

最初に、動作確認している様子です。

缶ビールがセグメンテーションされていることと、それぞれの銘柄が判定できていることを確認できると思います。

2 学習済みモデルによる物体検出 (YOLOv8)

YOLOv8などの物体検出モデルでは、配布されている学習済みモデルで、ある程度の物体検出が可能です。

下記は、テーブルに置かれた、缶ビールが、chairdining tableとともに、cupとして検出されている様子が確認できます。

しかし、このような学習済みモデルでは、店舗に陳列されている缶ビールを検出するようなことは、出来ませんでした。

3 フィルタによる抽出 (Segment Anything Model)

そこで、使用したのが、SAMです。

パラメータにもよりますが、SAMで検出してみると。なんとなく、缶ビールが検出できそうな予感がしてきます。

上記では、缶ビールの検出もできているようですが、缶上のラベルや、値札などもセグメンテーションされていて、ちょっと、欲しいものとは違う感じです。

そこで、目的としている缶ビール以外のマスク情報をフィルタするために、画像上での缶ビールのサイズを調べて、そのサイズに基づいてフィルタしてみました。

まず、缶ビールのサイズの幅及び高さの最大、最小値を決めます。

# 缶ビールは、画像サイズの何%を占めるか
height, width, _ = image.shape

minH = int(height * 0.14)
maxH = int(height * 0.29)
minW = int(width * 0.04)
maxW = int(width * 0.09)

そして、SAMで検出したすべてのマスクの外接矩形を求め、サイズが範囲内に収まるものだけをフィルタします。

for contour in contours:
    # 外接の矩形を求める
    x, y, w, h = cv2.boundingRect(contour)

    if (  # 一定のサイズのみ検出対象にする
        minH <= h and h <= maxH and minW <= w and w <= maxW
    ):
        all_contours.append(contour)

フィルタしたマスクだけを表示したものです。フィルタ前のマスクは300個程度で、フィルタ後は30個程度となりました。

作業時の様子ですが、計算した、min/maxの矩形を画面の左端に表示して、缶ビールのサイズと比べながら、丁度よいフィルタのサイズを試行錯誤しました。

4 分類モデルによる銘柄判定

セグメンテーションされた缶ビールの銘柄を求めるためには、YOLOv8の分類モデルを使用しました。

分類モデルを作成した手順は、以下のとおりです。

  • 動画撮影
  • png画像の抽出
  • データセット
  • データ増幅
  • ファインチューニング

(1) 動画撮影

缶ビールは、回転台に乗せて、ちょうど1周分の動画を撮影しました。

各銘柄ごと通常缶とロング缶の2種類を撮影しています。

(2) png画像の抽出

撮影した動画から、データセット用のpng画像を抽出するには、SAMを使用しました。

上記のブログで紹介しているように、動画の1フレーム目で、缶ビールの範囲を指定すると、3フレームごとに缶ビールを切り取り、背景が「透過」と「白」の2種類の画像を生成できます。

データセット用に、背景が白のもの、そして、ブラウザ表示用に透過背景のものを使用しました。

MOVファイルからPNGを切り出したコードです。

mov_2_png.py
import os
import datetime
import numpy as np
import torch
import cv2
import glob
import matplotlib.pyplot as plt
from matplotlib import patches
from segment_anything import sam_model_registry, SamPredictor


# マウスで範囲指定する
class BoundingBox:
    def __init__(self, image):
        self.x1 = -1
        self.x2 = -1
        self.y1 = -1
        self.y2 = -1
        self.image = image.copy()
        plt.figure()
        plt.connect("motion_notify_event", self.motion)
        plt.connect("button_press_event", self.press)
        plt.connect("button_release_event", self.release)
        self.ln_v = plt.axvline(0)
        self.ln_h = plt.axhline(0)
        plt.imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB))
        plt.show()

    # 選択中のカーソル表示
    def motion(self, event):
        if event.xdata is not None and event.ydata is not None:
            self.ln_v.set_xdata(event.xdata)
            self.ln_h.set_ydata(event.ydata)
            self.x2 = event.xdata.astype("int16")
            self.y2 = event.ydata.astype("int16")
        if self.x1 != -1 and self.x2 != -1 and self.y1 != -1 and self.y2 != -1:
            plt.clf()
            plt.imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB))
            ax = plt.gca()
            rect = patches.Rectangle(
                (self.x1, self.y1),
                self.x2 - self.x1,
                self.y2 - self.y1,
                angle=0.0,
                fill=False,
                edgecolor="#00FFFF",
            )
            ax.add_patch(rect)
        plt.draw()

    # ドラッグ開始位置
    def press(self, event):
        self.x1 = event.xdata.astype("int16")
        self.y1 = event.ydata.astype("int16")

    # ドラッグ終了位置、表示終了
    def release(self, event):
        plt.clf()
        plt.close()

    def get_area(self):
        return int(self.x1), int(self.y1), int(self.x2), int(self.y2)


# SAM
class SegmentAnything:
    def __init__(self, device, model_type, sam_checkpoint):
        print("init Segment Anything")
        self.device = device
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        sam.to(self.device)
        self.predictor = SamPredictor(sam)

    @property
    def contours(self):
        return self._contours

    @property
    def transparent_image(self):
        return self._transparent_image

    @property
    def white_back_image(self):
        return self._white_back_image

    @property
    def box(self):
        return self._box

    # マスク取得
    def predict(self, frame, input_box):
        image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        self.predictor.set_image(image)
        masks, _, _ = self.predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_box[None, :],
            multimask_output=False,
        )

        # ノイズ除去
        self._mask = self._remove_noise(frame, masks[0])
        # 範囲取得
        self._box = self._get_box()
        # 部分画像取得
        self._white_back_image, self._transparent_image = self._get_extract_image(frame)

    # マスクの範囲取得
    def _get_box(self):
        mask_indexes = np.where(self._mask)
        y_min = np.min(mask_indexes[0])
        y_max = np.max(mask_indexes[0])
        x_min = np.min(mask_indexes[1])
        x_max = np.max(mask_indexes[1])
        return np.array([x_min, y_min, x_max, y_max])

    # ノイズ除去
    def _remove_noise(self, image, mask):
        # 2値画像(白及び黒)を生成する
        height, width, _ = image.shape
        tmp_black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8)
        tmp_white_image = np.full(np.array([height, width, 1]), 255, dtype=np.uint8)
        # マスクによって黒画像の上に白を描画する
        tmp_black_image[:] = np.where(
            mask[:height, :width, np.newaxis] == True, tmp_white_image, tmp_black_image
        )

        # 輪郭の取得
        contours, _ = cv2.findContours(
            tmp_black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        # 最も面積が大きい輪郭を選択
        max_contours = max(contours, key=lambda x: cv2.contourArea(x))
        # 黒画面に一番大きい輪郭だけ塗りつぶして描画する
        black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8)
        black_image = cv2.drawContours(
            black_image, [max_contours], -1, color=255, thickness=-1
        )
        # 輪郭を保存
        self._contours, _ = cv2.findContours(
            black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        # マスクを作り直す
        new_mask = np.full(np.array([height, width, 1]), False, dtype=np.bool_)
        new_mask[::] = np.where(black_image[:height, :width] == 0, False, True)
        new_mask = np.squeeze(new_mask)

        return new_mask

    # 部分イメージの取得
    def _get_extract_image(self, image):
        # boxの範囲でマスクを切り取る
        part_of_mask = self._mask[
            self._box[1] : self._box[3], self._box[0] : self._box[2]
        ]
        # boxの範囲で元画像を切り取る
        copy_image = image.copy()  # 個々の食品を切取るためのテンポラリ画像
        white_back_image = copy_image[
            self._box[1] : self._box[3], self._box[0] : self._box[2]
        ]
        # boxの範囲で白一色の2値画像を作成する
        h = self._box[3] - self._box[1]
        w = self._box[2] - self._box[0]
        white_image = np.full(np.array([h, w, 1]), 255, dtype=np.uint8)
        # マスクによって白画像の上に元画像を描画する
        white_back_image[:] = np.where(
            part_of_mask[:h, :w, np.newaxis] == False, white_image, white_back_image
        )

        transparent_image = cv2.cvtColor(white_back_image, cv2.COLOR_BGR2BGRA)
        transparent_image[np.logical_not(part_of_mask), 3] = 0
        return white_back_image, transparent_image


class Video:
    def __init__(self, filename):
        self.cap = cv2.VideoCapture(filename)
        if self.cap.isOpened() == False:
            print("Video open faild.")
        else:
            self._frame_max = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # 次のフレーム取得
    def next_frame(self):
        return self.cap.read()

    ## 総フレーム数
    @property
    def frame_max(self):
        return self._frame_max

    def destroy(self):
        print("video destroy.")
        self.cap.release()
        cv2.destroyAllWindows()


# 縦横それぞれ、0.15倍まで広げた、ボックスを取得する
def get_next_input(box):
    x1 = box[0]
    y1 = box[1]
    x2 = box[2]
    y2 = box[3]
    w = x2 - x1
    h = y2 - y1
    x_margen = int(w * 0.15)
    y_margen = int(h * 0.15)

    return np.array([x1 - x_margen, y1 - y_margen, x2 + x_margen, y2 + y_margen])


def main():
    print("PyTorch version:", torch.__version__)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using {} device".format(device))

    # Segment Anything
    sam = SegmentAnything(device, "vit_h", "sam_vit_h_4b8939.pth")

    output_path = "./png"

    files = glob.glob("mov/*.MOV")
    for filename in files:
        step = 3
        start_frame = 0  # 684から乱れる
        # start_frame = 525  # 684から乱れる
        # filename = "mov/hon_kirin.MOV"
        basename = os.path.splitext(os.path.basename(filename))[0]
        os.makedirs("{}/{}".format(output_path, basename), exist_ok=True)

        video = Video(filename)

        try:
            print("start")
            for i in range(video.frame_max):
                ret, frame = video.next_frame()
                if ret == False:
                    continue

                # 開始位置まで読み飛ばす
                if i < start_frame:
                    continue

                # フレーム省略
                if i % step != 0:
                    continue

                # 最初のフレームで、バウンディングボックスを取得する
                if i == start_frame:
                    bounding_box = BoundingBox(frame)
                    x1, y1, x2, y2 = bounding_box.get_area()
                    input_box = np.array([x1, y1, x2, y2])

                print(
                    "{} filename:{} shape:{} start_frame:{} input_box:{} frams:{}/{}".format(
                        datetime.datetime.now(),
                        filename,
                        frame.shape,
                        start_frame,
                        input_box,
                        i + 1,
                        video.frame_max,
                    )
                )

                # マスク生成
                sam.predict(frame, input_box)
                # 輪郭描画
                frame = cv2.drawContours(
                    frame, sam.contours, -1, color=[255, 255, 0], thickness=6
                )

                # バウンディングボックス描画
                frame = cv2.rectangle(
                    frame,
                    pt1=(input_box[0], input_box[1]),
                    pt2=(input_box[2], input_box[3]),
                    color=(255, 255, 255),
                    thickness=2,
                    lineType=cv2.LINE_4,
                )

                # データ保存
                cv2.imwrite(
                    "{}/{}/{:09}_t.png".format(output_path, basename, i),
                    sam.transparent_image,
                )
                cv2.imwrite(
                    "{}/{}/{:09}_w.png".format(output_path, basename, i),
                    sam.white_back_image,
                )

                # 表示
                cv2.imshow("Extract", sam.white_back_image)
                cv2.waitKey(1)
                cv2.imshow("Video", cv2.resize(frame, None, fx=0.3, fy=0.3))
                cv2.waitKey(1)

                # 次のFrameで、検出範囲よりひと回り大きい範囲をBOX指定する
                input_box = get_next_input(sam.box)
        except KeyboardInterrupt:
            video.destroy()

if __name__ == "__main__":
    main()

(3) データセット

切り取った缶ビールのpng画像は、YOLOv8の分類モデルでファインチューニングするためのデータセットとして使用できるように、定められた階層に配置します。

回転台で1週回した動画は、概ね600フレームとなり、そこから3フレーム毎に画像を切り取るので、それぞれ約200枚となっています。しかし、動画は、全てピタリ1周分とはなっておらず、いくらか長短が出てしまっています。各クラスのデータ数に偏りが出ないようにと、ここでは、上限を最も短かった動画に合わせて186枚としました。

通常缶 186枚、及び、ロング缶 186枚を合計した、372枚が1つのクラス(銘柄)のデータセットで、それを8対2に分割して、学習用と検証用としました。

$ tree dataset -L 2
dataset
├── train
│   ├── asahi_clear
│   ├── asahi_off
│   │   ├── 000_000000001_w.png
│   │   ├── 000_000000004_w.png
│   │   ├── 000_000000007_w.png
・・・略・・・
│   │   ├── 111_000000237_w.png
│   │   ├── 111_000000243_w.png
│   │   └── 111_000000249_w.png
│   ├── asahi_style_free
│   ├── asahi_the_rich
│   ├── asahi_zeitaku_zero
│   ├── kirin_hon_kirin
│   ├── kirin_nodogoshi_nama
│   ├── kirin_nodogoshi_zero
│   ├── kirin_tanrei
│   ├── sappoeo_mugi_to_hop_black
│   ├── sappoeo_mugi_to_hop_extra_rich
│   ├── sapporo_draft_one
│   ├── sapporo_goku_zero
│   ├── sapporo_gold_star
│   ├── sapporo_gold_star_red
│   ├── sapporo_nama_sibori
│   ├── suntory_kinmugi_blue
│   ├── suntory_kinmugi_gold
│   ├── suntory_kinmugi_red
│   └── suntpry_kinmugi_kohakunoaki
└── val
    ├── asahi_clear
    ├── asahi_off
    ├── asahi_style_free
    ├── asahi_the_rich
    ├── asahi_zeitaku_zero
    ├── kirin_hon_kirin
    ├── kirin_nodogoshi_nama
    ├── kirin_nodogoshi_zero
    ├── kirin_tanrei
    ├── sappoeo_mugi_to_hop_black
    ├── sappoeo_mugi_to_hop_extra_rich
    ├── sapporo_draft_one
    ├── sapporo_goku_zero
    ├── sapporo_gold_star
    ├── sapporo_gold_star_red
    ├── sapporo_nama_sibori
    ├── suntory_kinmugi_blue
    ├── suntory_kinmugi_gold
    ├── suntory_kinmugi_red
    └── suntpry_kinmugi_kohakunoaki

ファイル名の数値部分は、フレーム数です。また、プレフィックスとして000を付与されているものが、通常缶、111が、ロング缶の画像です。

PNG画像をデータセットに配置したコードです。

png_2_dataset.py
import os
import shutil
import glob

input_path = "./png"
dataset_path = "./dataset"

stage_list = ["train", "val"]
max = 186  # あまりデータ数に差が出ないように、多くあるものは、一定数以上を削除する

# pngディレクトリ内のディレクトリを列挙
class_name_list = []
dir_list = glob.glob("{}/*".format(input_path))
for dir in dir_list:
    basename = os.path.basename(dir)
    class_name_list.append(basename)

print("class_name_list:{}".format(class_name_list))

# dataset作成
os.makedirs(dataset_path, exist_ok=True)
for stage in stage_list:
    os.makedirs("{}/{}".format(dataset_path, stage), exist_ok=True)

for class_name in class_name_list:
    if class_name.endswith("_000"):
        kind = "000"  # レギュラー缶
    else:
        kind = "111"  # ロング缶

    cls_name = class_name.replace("_000", "").replace("_001", "")

    for stage in stage_list:
        os.makedirs("{}/{}/{}".format(dataset_path, stage, cls_name), exist_ok=True)

    file_list = glob.glob(
        "{}/{}/*_w.png".format(input_path, class_name)
    )  # 白バックのPNGのみ対象とする
    file_list.sort()

    file_len = len(file_list)
    train_len = int(file_len * 0.8)
    val_len = file_len - train_len

    print("{} {} train:{} val:{}".format(class_name, file_len, train_len, val_len))

    val_counter = 0
    for i, file in enumerate(file_list):
        if i > max:
            break
        stage = stage_list[i % 2]
        if i % 2 == 1:
            val_counter += 1
            if val_len < val_counter:
                stage = stage_list[0]  # trainに変更
        basename = os.path.basename(file)
        input_file = "{}/{}/{}".format(input_path, class_name, basename)

        # 出力時は、通常缶とロング缶を同じクラスにまとめる
        # output_file = "{}/{}/{}/{}".format(dataset_path, stage, class_name, basename)
        output_file = "{}/{}/{}/{}_{}".format(
            dataset_path, stage, cls_name, kind, basename
        )

        shutil.copyfile(input_file, output_file)

(4) データ増幅

データは、下記の手順で約10倍(各クラス 3720枚 train:2976 val:744)に増幅しています。

変換に使用したパラメータです。

# 変換リスト
convertList = [
    {"function": saturation, "param": 0.2},  # 彩度
    {"function": saturation, "param": 3.0},
    {"function": brightness, "param": 0.6},  # 明度
    {"function": contrast, "param": 0.8},  # コントラスト
    {"function": mosaic, "param": 0.05},  # モザイク
    {"function": mosaic, "param": 0.1},
    {"function": mosaic, "param": 0.2},
    {"function": gaussian, "param": 20.0},  # ガウスノイズ
    {"function": noise, "param": 100},  # ごま塩ノイズ
]

この作業で、学習用、検証用共に、1枚のpng画像が、10倍に増幅されることになります。

amplify.py
import os
import glob
import cv2
import numpy as np
from matplotlib import pyplot as plt


# 彩度
def saturation(src, saturation):
    # 一旦、BGRをHSVに変換して彩度を変換する
    img = cv2.cvtColor(src, cv2.COLOR_BGR2HSV)
    img[:, :, (1)] = img[:, :, (1)] * saturation
    img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
    return img


# 明度
def brightness(src, brightness):
    # 一旦、BGRをHSVに変換して明度を変換する
    img = cv2.cvtColor(src, cv2.COLOR_BGR2HSV)
    img[:, :, (2)] = img[:, :, (2)] * brightness
    img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
    return img


# コントラスト
def contrast(src, alpha):
    # 各色相をalphaで演算する
    img = alpha * src
    return np.clip(img, 0, 255).astype(np.uint8)


# モザイク
def mosaic(src, ratio):
    # ratio倍でリサイズして戻す
    img = cv2.resize(src, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST)
    return cv2.resize(img, src.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)


# ガウスノイズ
def gaussian(src, sigma):
    # ランダム値で生成した画像と合成する
    row, col, ch = src.shape
    mean = 0
    gauss = np.random.normal(mean, sigma, (row, col, ch)).astype("u1")
    gauss = gauss.reshape(row, col, ch)
    return src + gauss


# ごま塩ノイズ
def noise(src, sigma):
    img = src.copy()
    # 画素数xRGBのノイズを生成
    noise = np.random.normal(0, sigma, img.shape)
    # ノイズを付加して8bitの範囲にクリップ
    noisy_img = img.astype(np.float64) + noise
    return np.clip(noisy_img, 0, 255).astype(np.uint8)


def get_image_list(dataset_path, class_name, stage):
    image_list = []
    files = glob.glob("{}/{}/{}/*".format(dataset_path, stage, class_name))
    for file in files:
        image_list.append(file)
    return image_list


def main():
    dataset_path = "./dataset"

    # 変換リスト
    convertList = [
        {"function": saturation, "param": 0.2},  # 彩度
        {"function": saturation, "param": 3.0},
        {"function": brightness, "param": 0.6},  # 明度
        {"function": contrast, "param": 0.8},  # コントラスト
        {"function": mosaic, "param": 0.05},  # モザイク
        {"function": mosaic, "param": 0.1},
        {"function": mosaic, "param": 0.2},
        {"function": gaussian, "param": 20.0},  # ガウスノイズ
        {"function": noise, "param": 100},  # ごま塩ノイズ
    ]

    # クラス一覧
    class_name_list = []
    dir_list = glob.glob("{}/train/*".format(dataset_path))
    for dir in dir_list:
        basename = os.path.basename(dir)
        class_name_list.append(basename)

    # class_name_list = ["kirin_tanrei"]
    print(class_name_list)

    stage_list = ["train", "val"]
    for class_name in class_name_list:
        for stage in stage_list:
            image_list = get_image_list(dataset_path, class_name, stage)
            print("images: {}件 {} {}".format(len(image_list), class_name, stage))

            for image in image_list:
                basename = os.path.basename(image)
                dirname = os.path.dirname(image)
                filename = os.path.splitext(os.path.basename(image))[0]
                ext = os.path.splitext(os.path.basename(image))[1]
                ext = ext.replace(".", "")
                print(
                    "basename:{} dirname:{} basename_without_ext:{} ext:{}".format(
                        basename, dirname, filename, ext
                    )
                )
                # 変換リストに基づく増幅処理
                for i, convert in enumerate(convertList):
                    # 出力ファイル名
                    output_image = "{}/{}_{}.{}".format(dirname, filename, i, ext)
                    print("output_image: {}".format(output_image))
                    img = cv2.imread(image)
                    img = convert["function"](img, convert["param"])
                    # 変換後の画像のエクスポート
                    cv2.imwrite(output_image, img)

            # print("増幅後データ: {}件 ".format(len(outputManifest.split("\n"))))

    for class_name in class_name_list:
        for stage in stage_list:
            image_list = get_image_list(dataset_path, class_name, stage)
            print("images: {}件 {} {}".format(len(image_list), class_name, stage))

main()

(5) ファインチューニング

YOLOv8では、下記のように簡単にファインチューニングが出来ます。

model = YOLO("yolov8n-cls.pt")  # load a pretrained model 
results = model.train(data="./dataset", epochs=10, imgsz=640)

下記は、モデルの性能を確認するため、テスト的に推論してみた結果ですが、非常に高い正解率となっています。

🌝 0.89 asahi_clear_001 asahi_clear
🌝 0.88 asahi_clear_002 asahi_clear
🌝 1.00 asahi_off_001 asahi_off
🌝 1.00 asahi_off_002 asahi_off
🌝 1.00 asahi_style_free_001 asahi_style_free
🌝 1.00 asahi_style_free_002 asahi_style_free
❌ 0.76 asahi_the_rich_001 sappoeo_mugi_to_hop_black
🌝 0.97 asahi_the_rich_002 asahi_the_rich
🌝 0.69 asahi_zeitaku_zero_001 asahi_zeitaku_zero
🌝 0.84 asahi_zeitaku_zero_002 asahi_zeitaku_zero
🌝 0.85 kirin_hon_kirin_001 kirin_hon_kirin
🌝 0.76 kirin_hon_kirin_002 kirin_hon_kirin
🌝 0.64 kirin_nodogoshi_nama_001 kirin_nodogoshi_nama
🌝 0.80 kirin_nodogoshi_nama_002 kirin_nodogoshi_nama
🌝 0.99 kirin_nodogoshi_zero_001 kirin_nodogoshi_zero
🌝 0.86 kirin_nodogoshi_zero_002 kirin_nodogoshi_zero
🌝 0.98 kirin_tanrei_001 kirin_tanrei
🌝 0.96 kirin_tanrei_002 kirin_tanrei
🌝 0.98 sappoeo_mugi_to_hop_extra_rich_001 sappoeo_mugi_to_hop_extra_rich
🌝 0.95 sappoeo_mugi_to_hop_extra_rich_002 sappoeo_mugi_to_hop_extra_rich
🌝 0.98 sapporo_draft_one_001 sapporo_draft_one
🌝 1.00 sapporo_draft_one_002 sapporo_draft_one
🌝 0.71 sapporo_goku_zero_001 sapporo_goku_zero
🌝 0.83 sapporo_goku_zero_002 sapporo_goku_zero
🌝 1.00 sapporo_gold_star_red_001 sapporo_gold_star_red
🌝 0.96 sapporo_gold_star_red_002 sapporo_gold_star_red
🌝 0.87 sapporo_nama_sibori_001 sapporo_nama_sibori
🌝 0.69 sapporo_nama_sibori_002 sapporo_nama_sibori
🌝 1.00 suntory_kinmugi_blue_001 suntory_kinmugi_blue
🌝 1.00 suntory_kinmugi_blue_002 suntory_kinmugi_blue
🌝 1.00 suntory_kinmugi_gold_001 suntory_kinmugi_gold
🌝 0.99 suntory_kinmugi_gold_002 suntory_kinmugi_gold

これは、データセットの画像と、推論対象の画像が、非常に似ているからだと考えています。

下図は一例ですが、上段が、セグメンテーションで切り取られた推論対象の画像、そして、下段が、データセットの画像です。どちらも、概ね横からみた画像で背景が白となっているため、類似性はかなり高いと思います。

5 検出及び判定

SAMによるセグメンテーション、サイズによるフィルタ、そして、分類モデルによる銘柄判定の一連のコードは、下記となります。

beer_detection.py
import torch
import torchvision
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from ultralytics import YOLO
import numpy as np
import cv2
import shutil
import os
import glob

print("PyTorch version: {}".format(torch.__version__))
print("Torchvision version: {}".format(torchvision.__version__))
print("CUDA is available: {}".format(torch.cuda.is_available()))


def create_json(basename, all_contours, all_indexes):
    json_file_name = "{}.json".format(basename)
    print("create json :{}".format(json_file_name))

    str = '{\n\t"data": [\n'

    for n, contour in enumerate(all_contours):
        str += "\t\t{\n"

        str += '\t\t\t"contour": ['
        for i, c in enumerate(contour):
            if i != 0:
                str += ","
            str += "[{},{}]".format(c[0][0], c[0][1])

        str += "],\n"

        str += '\t\t\t"classIndex": {}\n'.format(all_indexes[n])
        str += "\t\t}"
        if n != len(all_contours) - 1:
            str += ","
        str += "\n"

    str += "\t]\n}\n"

    with open(json_file_name, "w") as f:
        f.write(str)


def main():
    train = "train5"
    model_path = "runs/classify/{}/weights/best.pt".format(train)
    print("YOLOv8 Image Classification initialize. mode_path:{}".format(model_path))
    model = YOLO(model_path)

    sam_checkpoint = "sam_vit_h_4b8939.pth"
    model_type = "vit_h"
    device = "cuda"
    print(
        "SAM initialize. checkpoint:{} model_type:{}".format(sam_checkpoint, model_type)
    )

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)

    mask_generator_ = SamAutomaticMaskGenerator(
        model=sam,
        points_per_side=32,
        pred_iou_thresh=0.9,
        stability_score_thresh=0.96,
        crop_n_layers=1,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=100,  # Requires open-cv to run post-processing
    )

    print("job start.")
    files = glob.glob("./shop_images/*")
    for file in files:
        basename = os.path.splitext(os.path.basename(file))[0]
        org_image = cv2.imread(file)
        copy_image = org_image.copy()
        image = cv2.cvtColor(org_image, cv2.COLOR_BGR2RGB)

        print("file:{} basename:{} shape:{}".format(file, basename, image.shape))

        # 検出されたマスクかを、サイズでフィルタするための値
        height, width, _ = image.shape
          minH = int(height * 0.14)
          maxH = int(height * 0.29)
          minW = int(width * 0.04)
          maxW = int(width * 0.09)

        cv2.rectangle(copy_image, (0, 0), (minW, minH), (255, 255, 0), 2)
        cv2.rectangle(copy_image, (0, 0), (maxW, maxH), (255, 255, 0), 2)

        # ブログ用(切り取った画像を保存する)
        beer_images = "./beer_images"
        if os.path.isdir(beer_images):
            shutil.rmtree(beer_images)
        os.makedirs(beer_images, exist_ok=True)

        print("start generate")
        masks = mask_generator_.generate(image)
        print("masks: {} ".format(len(masks)))

        np.random.seed(seed=32)

        all_contours = []  # 1画像に含まれるすべての缶ビール分
        all_indexes = []

        for i, mask_data in enumerate(masks):
            mask = mask_data["segmentation"]
            mask = mask.astype(np.uint8)

            # 2値画像(白及び黒)を生成する
            # height, width, _ = image.shape
            tmp_black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8)
            tmp_white_image = np.full(np.array([height, width, 1]), 255, dtype=np.uint8)
            # マスクによって黒画像の上に白を描画する
            tmp_black_image[:] = np.where(
                mask[:height, :width, np.newaxis] == True,
                tmp_white_image,
                tmp_black_image,
            )
            # 輪郭の取得
            contours, _ = cv2.findContours(
                tmp_black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )
            contours2 = []  # 缶ビール1個分(フィルタ後)
            for contour in contours:
                # 外接の矩形を求める
                x, y, w, h = cv2.boundingRect(contour)

                # 外接矩形
                if (  # 一定のサイズのみ検出対象にする
                    minH <= h and h <= maxH and minW <= w and w <= maxW
                ):
                    contours2.append(contour)
                    all_contours.append(contour)
                    # boxの範囲でマスクを切り取る
                    part_of_mask = mask[y : y + h, x : x + w]
                    # boxの範囲で元画像を切り取る
                    tmp_image = org_image.copy()  # 個々の食品を切取るためのテンポラリ画像
                    white_back_image = tmp_image[y : y + h, x : x + w]
                    # boxの範囲で白一色の2値画像を作成する
                    white_image = np.full(np.array([h, w, 3]), 255, dtype=np.uint8)
                    # マスクによって白画像の上に元画像を描画する
                    white_back_image[:] = np.where(
                        part_of_mask[:h, :w, np.newaxis] == False,
                        white_image,
                        white_back_image,
                    )

                    # ブログ用(切り取った画像を保存する)
                    cv2.imwrite(
                        "./{}/{}_{}.png".format(
                            beer_images, basename, len(all_contours)
                        ),
                        white_back_image,
                    )

                    tmp_image = "./tmp.png"
                    cv2.imwrite(
                        tmp_image,
                        white_back_image,
                    )
                    # 各缶ビール画像を分類モデルで判定する
                    results = model(tmp_image)
                    names = results[0].names
                    probs = results[0].probs
                    i = probs.top5[0]
                    print(" {:.2f} {}".format(probs.data[i], names[i]))
                    all_indexes.append(i)

            # 輪郭を描画
            cv2.drawContours(copy_image, contours2, -1, (255, 0, 0), 2)

        # 1画像で検出されたマスク数
        print("contours: {}".format(len(all_contours)))
        print("all_indexes: {}".format(all_indexes))

        # 1画像分のデータファイルとして出力する
        create_json(basename, all_contours, all_indexes)

        cv2.imshow("Mask", copy_image)
        cv2.waitKey(0)  # 画像にフォーカスをあてて、キーを押すと、次に進む


main()

そして、ここで出力されるは、ブラウザで表示するためのJSONファイルです。 JSONファイルには、セグメンテーションの座標及び、銘柄を表すクラスインデックスが含まれています。

{
    "data": [
        {
            "contour": [[387,187],[386,188],[378,188],[377,189],[367,189],
                ・・・略・・・
                [426,189],[425,188],[419,188],[418,187]],
            "classIndex": 17
        },
        {
            "contour": [[315,186],[314,187],[313,187],[312,188],[311,187],
                ・・・略・・・
                [350,187],[349,188],[348,187],[323,187],[322,186]],
            "classIndex": 15
        },
                ・・・略・・・
}

6 ブラウザ表示

上記で作成されたJSONを使用して、Reactで検出結果を表示しています。

検出対象の画像(shop001〜shop004.png)、検出結果のJSON(shop001〜shop004.json)などは、assetsに配置して使用しました。

sample-app % tree -L 2
.
├── node_modules
├── public
├── src
│   ├── App.css
│   ├── App.test.tsx
│   ├── App.tsx
│   ├── assets
│   │   └── data
│   │       ├── beer000.png
│   │       ├── beer001.png
・・・略・・・
│   │       ├── beer018.png
│   │       ├── beer019.png
│   │       ├── names.json
│   │       ├── shop001.json
│   │       ├── shop001.png
・・・略・・・
│   │       ├── shop004.json
│   │       └── shop004.png
│   ├── components
│   ├── index.css
│   ├── index.tsx
│   ├── react-app-env.d.ts
│   ├── reportWebVitals.ts
│   └── setupTests.ts
├── tsconfig.json
└── yarn.lock

メインとなるコンポーネントは、以下のようになっています。

Stage.tsx
import React, { useEffect, useState } from "react";
import targetImage001 from "../assets/data/shop001.png";
import targetData001 from "../assets/data/shop001.json";
import targetImage002 from "../assets/data/shop002.png";
import targetData002 from "../assets/data/shop002.json";
import targetImage003 from "../assets/data/shop003.png";
import targetData003 from "../assets/data/shop003.json";
import targetImage004 from "../assets/data/shop004.png";
import targetData004 from "../assets/data/shop004.json";

import beerImage999 from "../assets/data/beer999.png";
import beerImage000 from "../assets/data/beer000.png";
import beerImage001 from "../assets/data/beer001.png";
import beerImage002 from "../assets/data/beer002.png";
import beerImage003 from "../assets/data/beer003.png";
import beerImage004 from "../assets/data/beer004.png";
import beerImage005 from "../assets/data/beer005.png";
import beerImage006 from "../assets/data/beer006.png";
import beerImage007 from "../assets/data/beer007.png";
import beerImage008 from "../assets/data/beer008.png";
import beerImage009 from "../assets/data/beer009.png";
import beerImage010 from "../assets/data/beer010.png";
import beerImage011 from "../assets/data/beer011.png";
import beerImage012 from "../assets/data/beer012.png";
import beerImage013 from "../assets/data/beer013.png";
import beerImage014 from "../assets/data/beer014.png";
import beerImage015 from "../assets/data/beer015.png";
import beerImage016 from "../assets/data/beer016.png";
import beerImage017 from "../assets/data/beer017.png";
import beerImage018 from "../assets/data/beer018.png";
import beerImage019 from "../assets/data/beer019.png";

import classNames from "../assets/data/names.json";
import "./Stage.css";
import { Button } from "@material-ui/core";

export const Stage: React.FC = () => {
  const [info, setInfo] = useState({
    x: -1,
    y: -1,
    mouseIndex: -1,
    classIndex: -1,
    className: "",
  });
  const [targetIndex, setTargetIndex] = useState(0);
  // shop001 shop002 shop003
  const canvasWidth = 1300;
  const canvasHeight = 800;
  const canvasColor = "#000000";
  const canvasId = "canvas";

  //const targetIndex = 0;
  const targetImageList = [
    targetImage001,
    targetImage002,
    targetImage003,
    targetImage004,
  ];
  const targetDataList = [
    targetData001,
    targetData002,
    targetData003,
    targetData004,
  ];
  const beerImageList = [
    beerImage999,
    beerImage000,
    beerImage001,
    beerImage002,
    beerImage003,
    beerImage004,
    beerImage005,
    beerImage006,
    beerImage007,
    beerImage008,
    beerImage009,
    beerImage010,
    beerImage011,
    beerImage012,
    beerImage013,
    beerImage014,
    beerImage015,
    beerImage016,
    beerImage017,
    beerImage018,
    beerImage019,
  ];

  const getContext = (): CanvasRenderingContext2D => {
    return getCavasElement().getContext("2d")!;
  };

  const getCavasElement = (): HTMLCanvasElement => {
    return document.getElementById(canvasId) as HTMLCanvasElement;
  };

  const dwawContour = (
    ctx: CanvasRenderingContext2D,
    contour: number[][],
    lineWidth: number,
    lineColor: string
  ) => {
    ctx.lineWidth = lineWidth;
    ctx.strokeStyle = lineColor;
    ctx.beginPath();
    for (var c = 0; c < contour.length; c++) {
      const point = contour;
      if (c === 0) {
        ctx.moveTo(point[0], point[1]);
      } else if (c === contour.length - 1) {
        ctx.closePath();
      } else {
        ctx.lineTo(point[0], point[1]);
      }
    }
    const point = contour[0];
    ctx.moveTo(point[0], point[1]);
    for (const point of contour) {
      ctx.lineTo(point[0], point[1]);
    }
    ctx.stroke();
  };

  // 多角形の範囲内かどうか判定
  const isRange = (contour: number[][], x: number, y: number) => {
    let minX = 9999;
    let maxX = 0;
    let minY = 9999;
    let maxY = 0;
    contour.forEach((point) => {
      if (maxX < point[0]) {
        maxX = point[0];
      }
      if (point[0] < minX) {
        minX = point[0];
      }
      if (maxY < point[1]) {
        maxY = point[1];
      }
      if (point[1] < minY) {
        minY = point[1];
      }
    });
    if (minX <= x && x <= maxX && minY <= y && y <= maxY) {
      return true;
    }
    return false;
  };

  const getIndex = (targetData: any, mouseX: number, mouseY: number) => {
    let index = -1;
    targetData.data.forEach((data: any, i: number) => {
      const contour: number[][] = data["contour"] as number[][];
      if (isRange(contour, mouseX, mouseY)) {
        index = i;
      }
    });
    return index;
  };

  useEffect(() => {
    const ctx = getContext();

    // 初期描画
    ctx!.fillStyle = canvasColor;
    ctx!.fillRect(0, 0, canvasWidth, canvasHeight);
    const img = new Image();
    img.src = targetImageList[targetIndex];
    img.onload = () => {
      ctx!.drawImage(img, 0, 0); // 画像は拡大縮小しないで表示する
    };

    const handleWindowMouseMove = (mouseEvent: MouseEvent) => {
      // 画像の位置
      const canvasElementRect = getCavasElement().getBoundingClientRect();
      // マウスの相対座標
      const mouseX = mouseEvent.clientX - canvasElementRect.left;
      const mouseY = mouseEvent.clientY - canvasElementRect.top;

      // マウス位置が、どのマスクに一致しているかどうか
      const mouseIndex = getIndex(targetDataList[targetIndex], mouseX, mouseY);
      // マウス位置のオブジェクトのクラスインデックス
      const classIndex =
        mouseIndex !== -1
          ? (targetDataList[targetIndex].data[mouseIndex][
              "classIndex"
            ] as number)
          : -1;

      const className = classIndex !== -1 ? classNames.names[classIndex] : "";
      setInfo({
        x: mouseX,
        y: mouseY,
        mouseIndex: mouseIndex,
        classIndex: classIndex,
        className: className,
      });
      console.log(
        `x:${mouseX} y:${mouseY} mouseIndex:${mouseIndex} classIndex:${classIndex} ${className}`
      );

      // マウスの位置に基づく描画
      const ctx = getContext();

      const img = new Image();
      img.src = targetImageList[targetIndex];
      ctx!.drawImage(img, 0, 0); // 画像は拡大縮小しないで表示する

      if (0 < mouseX && 0 < mouseY) {
        targetDataList[targetIndex].data.forEach((data, i) => {
          const contour: number[][] = data["contour"] as number[][];
          let lineColor = "#0000ff";
          let lineWidth = 3;
          if (i === mouseIndex) {
            lineColor = "#aaffff";
            lineWidth = 5;
          }
          dwawContour(ctx, contour, lineWidth, lineColor);
        });
      }
    };
    window.addEventListener("mousemove", handleWindowMouseMove);
  }, [targetIndex]); // 再描画は、dispLine()に移譲する

  return (
    <div>
      <div className="selector">
        {[...Array(4)].map((_, i) => {
          return (
            <Button
              variant="outlined"
              color="primary"
              onClick={() => {
                setTargetIndex(i);
              }}
            >
              Image_{i}
            </Button>
          );
        })}
      </div>
      <div className="title">
        <img src={beerImageList[info.classIndex + 1]} />
        <span className="title">{info.className}</span>
      </div>
      <canvas id={canvasId} width={canvasWidth} height={canvasHeight}></canvas>
      <div></div>
      <div className="info">
        targetIndex={targetIndex} x={info.x} y={info.y} mouseIndex:
        {info.mouseIndex} classIndex:
        {info.classIndex}
      </div>
    </div>
  );
};

7 最後に

今回は、スーバーマーケットで陳列されている商品(缶ビール)を検出してみました。SAMが登場するまでの手法だと、どうしても、データセット作成にはかなりの工数が必要になったと思います。恐らく、個人できるような作業では無かったのではないか考えています。

下記は、醤油の棚です。まだ、試しておりませんが、今回の手法を応用すれば、下記のような商品棚も、うまく検出できるのでは?と妄想しております。

写真を取らせていただいた店内は、生活協同組合コープさっぽろでした。こちらでは、入り口に 「店内撮影OK」 と掲示されており、お言葉に甘えて撮影させて頂きました。ありがとうございます。

データセットを作成するために買ってきた缶ビールは、今、押し入れに積み上がっています。この後、飲みます。\(^o^)/