[YOLOv8 Instance Segmentation] 「きのこの山」に潜伏する「たけのこの里」を機械学習で見つけてみました 〜データセットは、Segment Anything Modelで自動的に生成されています〜

2024.02.19

1 はじめに

CX事業本部製造ビジネステクノロジー部の平内(SIN)です。

YOLOv8は、イメージ分類・物体検出・セグメンテーション・骨格検出などに対応していますが、今回は、セグメンテーションモデルをファインチューニングして、「きのこの山」と「たけのこの里」を検出してみました。

最初に、動作している様子をご確認下さい。比較的に精度高くセグメンテーション出来ていると思います。

2 データセット作成

セグメンテーションモデルを学習する為のデータは、下記のように、対象物の輪郭座標が必要であり、これを大量に作成するのは、結構な膨大な作業量になってしまいます。

そこで、この作業は、Segment Anything Model(SAM)を使用し、対象物を撮影した動画から自動的に作成してみました。 作成されたデータセットは、3,000枚の画像、27,000個のアノテーションとなりました。

データセット作成の工程は、概ね以下のとおりです。

  • (1) 動画の撮影
  • (2) オブジェクトの切出し
  • (3) 背景と合成してデータセット作成
  • (4) YOLOv8形式への変換

(1) 動画の撮影

撮影は、回転台の上に載せてスマフォで撮っています。

下に、紺色の布を敷いているのは、影が目立つと、この後、Segment Anything Modelでオブジェクトを切り抜く際に、影も対象物とご判断されてしまうことを避けるためです。

撮影された動画は、以下のような感じです。

(2) オブジェクトの切出し

オブジェクト切出しのプログラムでは、SAM(Segment Anything Model)を使用して動画の各フレームから対象オブジェクトを切出しています。

対象物の検出範囲は、動画の1フレーム目で、設定するようになっています。

最初のフレームで、対象物を検出できたら、その物体の一回り大きな枠を、次のフレームの検出範囲としているため対象物が、少々ずれても、追従させながら切り出すことができるようになっています。

参考

切り出されたデータは、以下のとおりとなります。

  • xxxxxxxxxx_w.png 背景が白の切り取り画像
  • xxxxxxxxxx_t.png 背景が透過の切り取り画像(今回、未使用)
  • xxxxxxxxxx_l.txt 切り取り画像の左上を原点(0,0)とした輪郭の座標
dataset
└── output_png
    └── KINOKO
    |   ├── 000000000_l.txt
    |   ├── 000000000_t.png
    |   ├── 000000000_w.png
...
    |   ├── 000000102_l.txt
    |   ├── 000000102_t.png
    |   └── 000000102_w.png
    └── TAKENOKO
        ├── 000000000_l.txt
        ├── 000000000_t.png
        ├── 000000000_w.png
...
        ├── 000000102_l.txt
        ├── 000000102_t.png
        └── 000000102_w.png

以下の動画は、オブジェクトの切出し作業をしているところです。

そして、オブジェクトの切出しのプログラムです。

mp4_2_png.py
# SAMによる mp4ファイルからのターゲット摘出

import os
import datetime
import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
from matplotlib import patches
from segment_anything import sam_model_registry, SamPredictor


# マウスで範囲指定する
class BoundingBox:
    def __init__(self, image):
        self.x1 = -1
        self.x2 = -1
        self.y1 = -1
        self.y2 = -1
        self.image = image.copy()
        plt.figure()
        plt.connect("motion_notify_event", self.motion)
        plt.connect("button_press_event", self.press)
        plt.connect("button_release_event", self.release)
        self.ln_v = plt.axvline(0)
        self.ln_h = plt.axhline(0)
        plt.imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB))
        plt.show()

    # 選択中のカーソル表示
    def motion(self, event):
        if event.xdata is not None and event.ydata is not None:
            self.ln_v.set_xdata(event.xdata)
            self.ln_h.set_ydata(event.ydata)
            self.x2 = event.xdata.astype("int16")
            self.y2 = event.ydata.astype("int16")
        if self.x1 != -1 and self.x2 != -1 and self.y1 != -1 and self.y2 != -1:
            plt.clf()
            plt.imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB))
            ax = plt.gca()
            rect = patches.Rectangle(
                (self.x1, self.y1),
                self.x2 - self.x1,
                self.y2 - self.y1,
                angle=0.0,
                fill=False,
                edgecolor="#00FFFF",
            )
            ax.add_patch(rect)
        plt.draw()

    # ドラッグ開始位置
    def press(self, event):
        self.x1 = event.xdata.astype("int16")
        self.y1 = event.ydata.astype("int16")

    # ドラッグ終了位置、表示終了
    def release(self, event):
        plt.clf()
        plt.close()

    def get_area(self):
        return int(self.x1), int(self.y1), int(self.x2), int(self.y2)


# SAM
class SegmentAnything:
    def __init__(self, device, model_type, sam_checkpoint):
        print("init Segment Anything")
        self.device = device
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        sam.to(self.device)
        self.predictor = SamPredictor(sam)

    @property
    def contours(self):
        return self._contours

    @property
    def transparent_image(self):
        return self._transparent_image

    @property
    def white_back_image(self):
        return self._white_back_image

    @property
    def box(self):
        return self._box

    @property
    def segment(self):
        return self._segment

    # マスク取得
    def predict(self, frame, input_box):
        image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        self.predictor.set_image(image)
        masks, _, _ = self.predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_box[None, :],
            multimask_output=False,
        )

        # ノイズ除去
        self._mask = self._remove_noise(frame, masks[0])
        # 範囲取得
        self._box = self._get_box()
        # 部分画像取得
        self._white_back_image, self._transparent_image = self._get_extract_image(frame)
        # セグメンテーション取得
        self._segment = self._get_segment()

    # マスクの範囲取得
    def _get_box(self):
        mask_indexes = np.where(self._mask)
        y_min = np.min(mask_indexes[0])
        y_max = np.max(mask_indexes[0])
        x_min = np.min(mask_indexes[1])
        x_max = np.max(mask_indexes[1])
        return np.array([x_min, y_min, x_max, y_max])

    # ノイズ除去
    def _remove_noise(self, image, mask):
        # 2値画像(白及び黒)を生成する
        height, width, _ = image.shape
        tmp_black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8)
        tmp_white_image = np.full(np.array([height, width, 1]), 255, dtype=np.uint8)
        # マスクによって黒画像の上に白を描画する
        tmp_black_image[:] = np.where(
            mask[:height, :width, np.newaxis] == True, tmp_white_image, tmp_black_image
        )

        # 輪郭の取得
        contours, _ = cv2.findContours(
            tmp_black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        # 最も面積が大きい輪郭を選択
        max_contours = max(contours, key=lambda x: cv2.contourArea(x))
        # 黒画面に一番大きい輪郭だけ塗りつぶして描画する
        black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8)
        black_image = cv2.drawContours(
            black_image, [max_contours], -1, color=255, thickness=-1
        )
        # 輪郭を保存
        self._contours, _ = cv2.findContours(
            black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        # マスクを作り直す
        new_mask = np.full(np.array([height, width, 1]), False, dtype=np.bool_)
        new_mask[::] = np.where(black_image[:height, :width] == 0, False, True)
        new_mask = np.squeeze(new_mask)

        return new_mask

    # 部分イメージの取得
    def _get_extract_image(self, image):
        # boxの範囲でマスクを切り取る
        part_of_mask = self._mask[
            self._box[1] : self._box[3], self._box[0] : self._box[2]
        ]
        # boxの範囲で元画像を切り取る
        copy_image = image.copy()  # 個々の食品を切取るためのテンポラリ画像
        white_back_image = copy_image[
            self._box[1] : self._box[3], self._box[0] : self._box[2]
        ]
        # boxの範囲で白一色の2値画像を作成する
        h = self._box[3] - self._box[1]
        w = self._box[2] - self._box[0]
        white_image = np.full(np.array([h, w, 1]), 255, dtype=np.uint8)
        # マスクによって白画像の上に元画像を描画する
        white_back_image[:] = np.where(
            part_of_mask[:h, :w, np.newaxis] == False, white_image, white_back_image
        )

        transparent_image = cv2.cvtColor(white_back_image, cv2.COLOR_BGR2BGRA)
        transparent_image[np.logical_not(part_of_mask), 3] = 0
        return white_back_image, transparent_image

    def _get_segment(self):
        box = self._box.tolist()
        contours = self._contours[0].tolist()
        width = box[2] - box[0]
        height = box[3] - box[1]
        # print("box type:{} len:{}".format(type(box), len(box)))
        # print("v type:{} len:{}".format(type(contours), len(contours)))
        # print("box {},{},{},{}".format(box[0], box[1], box[2], box[3]))
        # print("width:{} height:{}".format(width, height))
        text = ""
        for i, data in enumerate(contours):
            d = data[0]
            x = d[0] - box[0]
            y = d[1] - box[1]
            text += "{},{},".format(x, y)
        return text


class Video:
    def __init__(self, filename):
        self.cap = cv2.VideoCapture(filename)
        if self.cap.isOpened() == False:
            print("Video open faild.")
        else:
            self._frame_max = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # 次のフレーム取得
    def next_frame(self):
        return self.cap.read()

    ## 総フレーム数
    @property
    def frame_max(self):
        return self._frame_max

    def destroy(self):
        print("video destroy.")
        self.cap.release()
        cv2.destroyAllWindows()


# 縦横それぞれ、0.15倍まで広げた、ボックスを取得する
def get_next_input(box):
    x1 = box[0]
    y1 = box[1]
    x2 = box[2]
    y2 = box[3]
    w = x2 - x1
    h = y2 - y1
    x_margen = int(w * 0.15)
    y_margen = int(h * 0.15)

    return np.array([x1 - x_margen, y1 - y_margen, x2 + x_margen, y2 + y_margen])


def main():
    print("PyTorch version:", torch.__version__)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using {} device".format(device))

    step = 3
    start_frame = 0
    filename = "KINOKO.mp4"
    # filename = "TAKENOKO.mp4"
    output_path = "./dataset/output_png"
    basename = os.path.splitext(os.path.basename(filename))[0]
    os.makedirs("{}/{}".format(output_path, basename), exist_ok=True)

    video = Video(filename)

    # Segment Anything
    sam = SegmentAnything(device, "vit_h", "sam_vit_h_4b8939.pth")

    try:
        print("start")
        for i in range(video.frame_max):
            ret, frame = video.next_frame()
            if ret == False:
                continue

            # 開始位置まで読み飛ばす
            if i < start_frame:
                continue

            # フレーム省略
            if i % step != 0:
                continue

            # 最初のフレームで、バウンディングボックスを取得する
            if i == start_frame:
                bounding_box = BoundingBox(frame)
                x1, y1, x2, y2 = bounding_box.get_area()
                input_box = np.array([x1, y1, x2, y2])

            print(
                "{} filename:{} shape:{} start_frame:{} input_box:{} frams:{}/{}".format(
                    datetime.datetime.now(),
                    filename,
                    frame.shape,
                    start_frame,
                    input_box,
                    i + 1,
                    video.frame_max,
                )
            )

            # マスク生成
            sam.predict(frame, input_box)
            # 輪郭描画
            frame = cv2.drawContours(
                frame, sam.contours, -1, color=[255, 255, 0], thickness=6
            )

            # バウンディングボックス描画
            frame = cv2.rectangle(
                frame,
                pt1=(input_box[0], input_box[1]),
                pt2=(input_box[2], input_box[3]),
                color=(255, 255, 255),
                thickness=2,
                lineType=cv2.LINE_4,
            )

            # データ保存
            cv2.imwrite(
                "{}/{}/{:09}_t.png".format(output_path, basename, i),
                sam.transparent_image,
            )
            cv2.imwrite(
                "{}/{}/{:09}_w.png".format(output_path, basename, i),
                sam.white_back_image,
            )
            # 座標保存
            with open(
                "{}/{}/{:09}_l.txt".format(output_path, basename, i), mode="w"
            ) as f:
                f.write(sam.segment)

            # 表示
            cv2.imshow("Extract", sam.white_back_image)
            cv2.waitKey(1)
            cv2.imshow("Video", cv2.resize(frame, None, fx=0.3, fy=0.3))
            cv2.waitKey(1)

            # 次のFrameで、検出範囲よりひと回り大きい範囲をBOX指定する
            input_box = get_next_input(sam.box)

    except KeyboardInterrupt:
        video.destroy()


if __name__ == "__main__":
    main()

(3) 背景と合成してデータセット作成

切出した背景透過のデータを背景に合成して、データセット用の画像を生成します。

背景への合成は、ランダム位置で、ランダムなサイズで行われます。

なお、背景に対する、オブジェクトのサイズが一定の範囲に収まるように、変換は下記を基準としています。

  • 背景画像 640 * 480
  • 切り出し画像 横幅が200になるようリサイズしてから、0.3〜1倍に変換する

切出し画像を貼り付けた後は、相対座標のデータもリサイズに合わせて変換し、最終的に背景画像の幅を基準に0〜1で正規化され、画像と同名の.txt として出力しています。

この時点で.txtは、YOLOのデータセット形式となっています。

(クラスIDが決定されるのは、この時点となります)

./dataset/output_merge
├── 00000.png
├── 00000.txt
├── 00001.png
├── 00001.txt
・・・
├── 02999.png
├── 02999.txt
├── 03000.png
└── 03000.txt

こちらは、背景と合成してデータセット作成している様子です。

合成に使用したプログラムです。

png_2_merge_data.py
"""
SAMで切り出した画像データと背景を組み合わせてデータセットを作成する
"""

import json
import glob
import random
import os
import shutil
import math
import numpy as np
import cv2
from PIL import Image

MAX = 3000  # 生成する画像数

CLASS_NAME = ["KINOKO", "TAKENOKO"]
COLORS = [(0, 0, 175), (175, 0, 0)]

BACKGROUND_IMAGE_PATH = "./dataset/background_images"
TARGET_IMAGE_PATH = "./dataset/output_png"
OUTPUT_PATH = "./dataset/output_merge"


BASE_WIDTH = 200  # 商品の基本サイズは、背景画像とのバランスより、横幅を200を基準とする
BACK_WIDTH = 640  # 背景画像ファイルのサイズを合わせる必要がある
BACK_HEIGHT = 480  # 背景画像ファイルのサイズを合わせる必要がある


# 背景画像取得クラス
class Background:
    def __init__(self, backPath):
        self.__backPath = backPath

    def get(self):
        imagePath = random.choice(glob.glob(self.__backPath + "/*.jpg"))
        return cv2.imread(imagePath, cv2.IMREAD_UNCHANGED)


# 検出対象取得クラス (base_widthで指定された横幅を基準にリサイズされる)
class Target:
    def __init__(self, target_path, base_width, class_name):
        self.__target_path = target_path
        self.__base_width = base_width
        self.__class_name = class_name

    def get(self, class_id):
        # 切り出し画像
        class_name = self.__class_name[class_id]
        image_path = random.choice(
            glob.glob(self.__target_path + "/" + class_name + "/*_t.png")
        )
        target_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)

        # 基準(横)サイズに基づきリサイズされる
        h, w, _ = target_image.shape
        aspect = h / w
        target_image = cv2.resize(
            target_image, (int(self.__base_width * aspect), self.__base_width)
        )
        bai = self.__base_width / w

        # labelも基準サイズへのリサイズに併せて、変換する( x * bai )
        target_label = ""

        label_text = ""
        label_path = image_path.replace("_t.png", "_l.txt")
        with open(label_path, encoding="utf-8") as f:
            label_text = f.read()

        orign_label_list = label_text.split(",")
        for label in orign_label_list:
            if label != "":
                target_label += "{},".format(float(label) * bai)

        return target_image, target_label


# 変換クラス
class Transformer:
    def __init__(self, width, height):
        self.__width = width
        self.__height = height
        self.__min_scale = 0.3
        self.__max_scale = 1

    def warp(self, target_image):
        # サイズ変更
        target_image, scale = self.__resize(target_image)

        # 配置位置決定
        h, w, _ = target_image.shape
        left = random.randint(0, self.__width - w)
        top = random.randint(0, self.__height - h)
        rect = ((left, top), (left + w, top + h))

        # 背景面との合成
        new_image = self.__synthesize(target_image, left, top)
        return (new_image, rect, scale)

    def __resize(self, img):
        scale = random.uniform(self.__min_scale, self.__max_scale)
        w, h, _ = img.shape
        return cv2.resize(img, (int(w * scale), int(h * scale))), scale

    def __rote(self, target_image, angle):
        h, w, _ = target_image.shape
        rate = h / w
        scale = 1
        if rate < 0.9 or 1.1 < rate:
            scale = 0.9
        elif rate < 0.8 or 1.2 < rate:
            scale = 0.6
        center = (int(w / 2), int(h / 2))
        trans = cv2.getRotationMatrix2D(center, angle, scale)
        return cv2.warpAffine(target_image, trans, (w, h))

    def __synthesize(self, target_image, left, top):
        background_image = np.zeros((self.__height, self.__width, 4), np.uint8)
        back_pil = Image.fromarray(background_image)
        front_pil = Image.fromarray(target_image)
        back_pil.paste(front_pil, (left, top), front_pil)
        return np.array(back_pil)


class Effecter:

    # Gauss
    def gauss(self, img, level):
        return cv2.blur(img, (level * 2 + 1, level * 2 + 1))

    # Noise
    def noise(self, img):
        img = img.astype("float64")
        img[:, :, 0] = self.__single_channel_noise(img[:, :, 0])
        img[:, :, 1] = self.__single_channel_noise(img[:, :, 1])
        img[:, :, 2] = self.__single_channel_noise(img[:, :, 2])
        return img.astype("uint8")

    def __single_channel_noise(self, single):
        diff = 255 - single.max()
        noise = np.random.normal(0, random.randint(1, 100), single.shape)
        noise = (noise - noise.min()) / (noise.max() - noise.min())
        noise = diff * noise
        noise = noise.astype(np.uint8)
        dst = single + noise
        return dst


# バウンディングボックス描画
def box(frame, rect, class_id):
    ((x1, y1), (x2, y2)) = rect
    label = "{}".format(CLASS_NAME[class_id])
    img = cv2.rectangle(frame, (x1, y1), (x2, y2), COLORS[class_id], 2)
    img = cv2.rectangle(img, (x1, y1), (x1 + 150, y1 - 20), COLORS[class_id], -1)
    cv2.putText(
        img,
        label,
        (x1 + 2, y1 - 2),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5,
        (255, 255, 255),
        1,
        cv2.LINE_AA,
    )
    return img


# 背景と商品の合成
def marge_image(background_image, front_image):
    back_pil = Image.fromarray(background_image)
    front_pil = Image.fromarray(front_image)
    back_pil.paste(front_pil, (0, 0), front_pil)
    return np.array(back_pil)


# 1画像分のデータを保持するクラス
class MergeData:
    def __init__(self, rate):
        self.__rects = []
        self.__images = []
        self.__class_ids = []
        self.__rate = rate

    def get_class_ids(self):
        return self.__class_ids

    def max(self):
        return len(self.__rects)

    def get(self, i):
        return (self.__images[i], self.__rects[i], self.__class_ids[i])

    # 追加(重複率が指定値以上の場合は失敗する)
    def append(self, target_image, rect, class_id):
        conflict = False
        for i in range(len(self.__rects)):
            iou = self.__multiplicity(self.__rects[i], rect)
            if iou > self.__rate:
                conflict = True
                break
        if conflict == False:
            self.__rects.append(rect)
            self.__images.append(target_image)
            self.__class_ids.append(class_id)
            return True
        return False

    # 重複率
    def __multiplicity(self, a, b):
        (ax_mn, ay_mn) = a[0]
        (ax_mx, ay_mx) = a[1]
        (bx_mn, by_mn) = b[0]
        (bx_mx, by_mx) = b[1]
        a_area = (ax_mx - ax_mn + 1) * (ay_mx - ay_mn + 1)
        b_area = (bx_mx - bx_mn + 1) * (by_mx - by_mn + 1)
        abx_mn = max(ax_mn, bx_mn)
        aby_mn = max(ay_mn, by_mn)
        abx_mx = min(ax_mx, bx_mx)
        aby_mx = min(ay_mx, by_mx)
        w = max(0, abx_mx - abx_mn + 1)
        h = max(0, aby_mx - aby_mn + 1)
        intersect = w * h
        return intersect / (a_area + b_area - intersect)


# 各クラスのデータ数が同一になるようにカウントする
class Counter:
    def __init__(self, max):
        self.__counter = np.zeros(max)

    def get(self):
        n = np.argmin(self.__counter)
        return int(n)

    def inc(self, index):
        self.__counter[index] += 1

    def print(self):
        print(self.__counter)


def main():

    # 出力先の初期化
    if os.path.exists(OUTPUT_PATH):
        shutil.rmtree(OUTPUT_PATH)
    os.mkdir(OUTPUT_PATH)

    target = Target(TARGET_IMAGE_PATH, BASE_WIDTH, CLASS_NAME)
    background = Background(BACKGROUND_IMAGE_PATH)

    transformer = Transformer(BACK_WIDTH, BACK_HEIGHT)
    # manifest = Manifest(CLASS_NAME)
    counter = Counter(len(CLASS_NAME))
    effecter = Effecter()

    no = 0

    while True:
        # 背景画像の取得
        background_image = background.get()
        height, width, _ = background_image.shape
        # mergeデータ
        merge_data = MergeData(0.1)
        label = ""
        for _ in range(20):
            # 現時点で作成数の少ないクラスIDを取得
            class_id = counter.get()
            # 切り出しデータの取得
            target_image, targhet_label = target.get(class_id)
            # 変換
            (transform_image, rect, scale) = transformer.warp(target_image)
            frame = marge_image(background_image, transform_image)

            # 商品の追加(重複した場合は、失敗する)
            ret = merge_data.append(transform_image, rect, class_id)
            if ret:
                counter.inc(class_id)

                label += "{}".format(class_id)

                for i, l in enumerate(targhet_label.split(",")):
                    if l != "":
                        # 変換したscaleに併せてLabelも変換する
                        z = int(float(l) * scale)
                        # 貼り付け位置へシフトと、全体画像を基準としてた0..1の正規化
                        if i % 2 == 0:
                            # X座標
                            x = z + rect[0][0]
                            label += " {:.5f}".format(x / width)
                        else:
                            # Y座標
                            y = z + rect[0][1]
                            label += " {:.5f}".format(y / height)

                label += "\n"

        print("max:{}".format(merge_data.max()))
        frame = background_image
        for index in range(merge_data.max()):
            (target_image, _, _) = merge_data.get(index)
            # 合成
            frame = marge_image(frame, target_image)

        # アルファチャンネル削除
        frame = cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)

        # エフェクト
        frame = effecter.gauss(frame, random.randint(0, 2))
        frame = effecter.noise(frame)

        # 画像名
        png_file_name = "{:05d}.png".format(no)
        # Segmentpoint
        label_file_name = "{:05d}.txt".format(no)

        no += 1

        # 画像保存
        cv2.imwrite("{}/{}".format(OUTPUT_PATH, png_file_name), frame)

        # テキスト保存
        with open("{}/{}".format(OUTPUT_PATH, label_file_name), mode="w") as f:
            f.write(label)

        # manifest追加
        # manifest.appned(fileName, merge_data, frame.shape[0], frame.shape[1])

        for i in range(merge_data.max()):
            (_, rect, class_id) = merge_data.get(i)
            # バウンディングボックス描画(確認用)
            frame = box(frame, rect, class_id)

        counter.print()
        print("no:{}".format(no))
        if MAX <= no:
            break

        # 表示(確認用)
        cv2.imshow("frame", frame)
        cv2.waitKey(1)


main()

(4) YOLOv8形式への変換

YOLOv8で学習させるためには、学習用と検証用のデータが必要ですが、上記で作成したデータを8対2に分割して、保存します。

dataset/yolo_data/
├── train
│   ├── images
│   │   ├── 00000.png
│   │   ├── 00001.png
...
│   │   ├── 02998.png
│   │   └── 02999.png
│   └── labels
│       ├── 00000.txt
│       ├── 00001.txt
...
│       ├── 02998.txt
│       └── 02999.txt
└── val
    ├── images
    │   ├── 00002.png
    │   ├── 00006.png
...
    │   ├── 02989.png
    │   └── 02992.png
    └── labels
        ├── 00002.txt
        ├── 00006.txt
...
        ├── 02989.txt
        └── 02992.txt

そして、上記のパスを記述したyamlファイルも作成し、学習用のデータセットとします。(クラス名が決定されるのは、この時点となります)

data.yaml

path: /home/dataset/yolo_data
train: train/images
val: val/images


names:
  0: KINOKO
  1: TAKENOKO

YOLOv8形式への変換に使用したプログラムです。

merge_data_to_yolo_data.py
"""
背景と組み合わせたデータをyolo形式に変換する
"""

import glob
import os
import shutil
import numpy as np
from PIL import Image


def main():
    input_path = "./dataset/output_merge"
    output_path = "./dataset/yolo_data"
    stages = ["train", "val"]

    train_path = "{}/train".format(output_path)
    train_images_path = "{}/images".format(train_path)
    train_labels_path = "{}/labels".format(train_path)

    val_path = "{}/val".format(output_path)
    val_images_path = "{}/images".format(val_path)
    val_labels_path = "{}/labels".format(val_path)

    os.makedirs(train_path, exist_ok=True)
    os.makedirs(train_images_path, exist_ok=True)
    os.makedirs(train_labels_path, exist_ok=True)

    os.makedirs(val_path, exist_ok=True)
    os.makedirs(val_images_path, exist_ok=True)
    os.makedirs(val_labels_path, exist_ok=True)

    image_files = glob.glob("{}/*.png".format(input_path))
    for i, input_image_file in enumerate(image_files):
        basename = os.path.splitext(os.path.basename(input_image_file))[0]
        if i % 10 < 8:
            stage = stages[0]
        else:
            stage = stages[1]

        input_label_file = "{}/{}.txt".format(input_path, basename)
        output_label_file = "{}/{}/labels/{}.txt".format(output_path, stage, basename)
        output_image_file = "{}/{}/images/{}.png".format(output_path, stage, basename)
        shutil.copy(input_label_file, output_label_file)
        shutil.copy(input_image_file, output_image_file)


if __name__ == "__main__":
    main()

3 学習

学習は、下記の要領です。

from ultralytics import YOLO

model = YOLO("yolov8n-seg.pt")

model.train(data="./dataset/yolo_data/data.yaml", epochs=10, batch=8, workers=4)

正常に学習が終わると、runs/segment/train/weights/best.ptに生成されたモデルが保存されます。

$ tree runs
runs
└── segment
    └── train
        ├── args.yaml
        ├── BoxF1_curve.png
        ├── BoxP_curve.png
        ├── BoxPR_curve.png
        ├── BoxR_curve.png
        ├── confusion_matrix_normalized.png
        ├── confusion_matrix.png
        ├── labels_correlogram.jpg
        ├── labels.jpg
        ├── MaskF1_curve.png
        ├── MaskP_curve.png
        ├── MaskPR_curve.png
        ├── MaskR_curve.png
        ├── results.csv
        ├── results.png
        ├── train_batch0.jpg
        ├── train_batch1.jpg
        ├── train_batch2.jpg
        ├── val_batch0_labels.jpg
        ├── val_batch0_pred.jpg
        ├── val_batch1_labels.jpg
        ├── val_batch1_pred.jpg
        ├── val_batch2_labels.jpg
        ├── val_batch2_pred.jpg
        └── weights
            ├── best.pt
            └── last.pt

4 推論

best.ptを使用して、Webカメラの画像を推論している様子です。

この時使用したプログラムです。

inference_webcam.py
# 参考にさせて頂きました
# https://github.com/ultralytics/ultralytics/issues/561

from ultralytics import YOLO
import numpy as np
import cv2

model = YOLO("./best.pt")

colors = [(0, 0, 200), (0, 200, 0)]
linewidth = 5
fontScale = 1.5
fontFace = cv2.FONT_HERSHEY_SIMPLEX
thickness = 5


def overlay(image, mask, color, alpha, resize=None):
    colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0)
    colored_mask = np.moveaxis(colored_mask, 0, -1)
    masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
    image_overlay = masked.filled()
    if resize is not None:
        image = cv2.resize(image.transpose(1, 2, 0), resize)
        image_overlay = cv2.resize(image_overlay.transpose(1, 2, 0), resize)

    image_combined = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
    return image_combined


def label(box, img, color, label, line_thickness=3):
    x1 = int(box[0])
    y1 = int(box[1])
    x2 = int(box[2])
    y2 = int(box[3])

    text_size = cv2.getTextSize(
        label, 0, fontScale=fontScale, thickness=line_thickness
    )[0]
    cv2.rectangle(
        img, (x1, y1), (x1 + text_size[0], y1 - text_size[1] - 2), color, -1
    )  # fill
    cv2.putText(
        img,
        label,
        (x1, y1 - 3),
        fontFace,
        fontScale,
        [225, 255, 255],
        thickness=line_thickness,
        lineType=cv2.LINE_AA,
    )

    cv2.rectangle(
        img,
        (x1, y1),
        (x2, y2),
        color,
        linewidth,
    )


cap = cv2.VideoCapture(0)
if cap.isOpened() is False:
    raise IOError

while True:
    try:
        ret, frame = cap.read()
        if ret is False:
            raise IOError

        h, w, _ = frame.shape

        results = model(frame, conf=0.80, iou=0.5)
        result = results[0]
        image = frame.copy()

        if result.masks is not None:
            for r in results:
                boxes = r.boxes
                conf_list = r.boxes.conf.tolist()

            for i, (seg, box) in enumerate(zip(result.masks.data.cpu().numpy(), boxes)):

                seg = cv2.resize(seg, (w, h))
                image = overlay(image, seg, colors[int(box.cls)], 0.5)

                class_id = int(box.cls)
                box = box.xyxy.tolist()[0]
                class_name = result.names[class_id]
                label(
                    box,
                    image,
                    colors[class_id],
                    "{} {:.2f}".format(class_name, conf_list[i]),
                    line_thickness=3,
                )

        frame = cv2.resize(frame, (640, 480))
        image = cv2.resize(image, (640, 480))
        mergeImg = np.hstack((frame, image))

        cv2.imshow("YOLO", mergeImg)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
    except KeyboardInterrupt:
        break

cap.release()
cv2.destroyAllWindows()

参考に、./input に置かれた画像を対象に、推論するプログラムも置いておきます。

inference_images.py
# 参考にさせて頂きました
# https://github.com/ultralytics/ultralytics/issues/561

from ultralytics import YOLO
import numpy as np
import cv2
import os
import glob

model = YOLO("./best.pt")

colors = [(0, 0, 200), (0, 200, 0)]
linewidth = 5
fontScale = 1.5
fontFace = cv2.FONT_HERSHEY_SIMPLEX
thickness = 5

input_path = "./input"
output_path = "./output"


def overlay(image, mask, color, alpha, resize=None):
    colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0)
    colored_mask = np.moveaxis(colored_mask, 0, -1)
    masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
    image_overlay = masked.filled()
    if resize is not None:
        image = cv2.resize(image.transpose(1, 2, 0), resize)
        image_overlay = cv2.resize(image_overlay.transpose(1, 2, 0), resize)

    image_combined = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
    return image_combined


def label(box, img, color, label, line_thickness=3):
    x1 = int(box[0])
    y1 = int(box[1])
    x2 = int(box[2])
    y2 = int(box[3])

    text_size = cv2.getTextSize(
        label, 0, fontScale=fontScale, thickness=line_thickness
    )[0]
    cv2.rectangle(
        img, (x1, y1), (x1 + text_size[0], y1 - text_size[1] - 2), color, -1
    )  # fill
    cv2.putText(
        img,
        label,
        (x1, y1 - 3),
        fontFace,
        fontScale,
        [225, 255, 255],
        thickness=line_thickness,
        lineType=cv2.LINE_AA,
    )

    cv2.rectangle(
        img,
        (x1, y1),
        (x2, y2),
        color,
        linewidth,
    )


def main():
    os.makedirs(output_path, exist_ok=True)

    files = glob.glob("{}/*.png".format(input_path))
    for file in files:
        basename = os.path.splitext(os.path.basename(file))[0]
        image = cv2.imread(file)

        h, w, _ = image.shape

        results = model(image, conf=0.60, iou=0.5)
        result = results[0]

        if result.masks is not None:
            for r in results:
                boxes = r.boxes
                conf_list = r.boxes.conf.tolist()

            for i, (seg, box) in enumerate(zip(result.masks.data.cpu().numpy(), boxes)):

                seg = cv2.resize(seg, (w, h))
                image = overlay(image, seg, colors[int(box.cls)], 0.5)

                class_id = int(box.cls)
                box = box.xyxy.tolist()[0]
                class_name = result.names[class_id]
                label(
                    box,
                    image,
                    colors[class_id],
                    "{} {:.2f}".format(class_name, conf_list[i]),
                    line_thickness=3,
                )
        output_filename = "{}/{}.png".format(output_path, basename)
        print(output_filename)
        cv2.imwrite(output_filename, image)


main()

5 最後に

今回は、YOLOv8のセグメンテーションを試してみました。 データセットさえ作れれば、インスタンスセグメンテーションのファインチューニングも、それほど難しくないのかなと言うのが、今回作業した印象です。

やはり、問題は、如何に効率よく精度の高いデータを大量に用意するかでしょう。 SAM(Segment Anything Model)は、やはりここでもゲームチェンジャーだと思います。

やっと、完成したので、全部食べます。 どちらも好きですー \(^o^)/