Segment Anythingを使用して動画からデータセット用の画像を切り出してみました

2023.05.07

1 はじめに

CX 事業本部 delivery部の平内(SIN)です。

Meta社による Segment Anything Model(SAM)は、セグメンテーションのための汎用モデルで、ファインチューニングなしで、あらゆる物体がセグメンテーションできます。

前回は、これを使用してUSBカメラからの入力をセグメンテーションしてみました。

今回は、「コンピュータビジョン用のデータセット画像」の切り出しに焦点をあてて、事前に録画した動画から、画像を切り出してみました。

最初に動作している様子です。

動画を読み込むと最初のフレームで停止し、対象オブジェクトの指定が行えます。マウスで、対象を囲むと、その後は、そのオブジェクトを追従しながらデータを切り出して保存します。(動画の後半では、同じ動画で、別のアヒルを抽出しています)

抽出された画像は、背景が白と透過の2種類となります。

2 Object masks from prompts

今回も、使用しているのは、Object masks from promptsです。

input_boxに検出範囲を指定することで、特定のオブジェクトが対象となるようになっています。

self.predictor.set_image(image)
masks, _, _ = self.predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box[None, :],
    multimask_output=False,
)

3 対象の選択

最初のフレームで、マウスを使用してオブジェクト検出範囲を指定するコードです。 matplotlib.pyplotで画像を表示し、マウスの操作をトラップしています。

# マウスで範囲指定する
class BoundingBox:
    def __init__(self, image):
        self.x1 = -1
        self.x2 = -1
        self.y1 = -1
        self.y2 = -1
        self.image = image.copy()
        plt.figure()
        plt.connect("motion_notify_event", self.motion)
        plt.connect("button_press_event", self.press)
        plt.connect("button_release_event", self.release)
        self.ln_v = plt.axvline(0)
        self.ln_h = plt.axhline(0)
        plt.imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB))
        plt.show()

    # 選択中のカーソル表示
    def motion(self, event):
        if event.xdata is not None and event.ydata is not None:
            self.ln_v.set_xdata(event.xdata)
            self.ln_h.set_ydata(event.ydata)
            self.x2 = event.xdata.astype("int16")
            self.y2 = event.ydata.astype("int16")
        if self.x1 != -1 and self.x2 != -1 and self.y1 != -1 and self.y2 != -1:
            plt.clf()
            plt.imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB))
            ax = plt.gca()
            rect = patches.Rectangle(
                (self.x1, self.y1),
                self.x2 - self.x1,
                self.y2 - self.y1,
                angle=0.0,
                fill=False,
                edgecolor="#00FFFF",
            )
            ax.add_patch(rect)
        plt.draw()

    # ドラッグ開始位置
    def press(self, event):
        self.x1 = event.xdata.astype("int16")
        self.y1 = event.ydata.astype("int16")

    # ドラッグ終了位置、表示終了
    def release(self, event):
        plt.clf()
        plt.close()

    def get_area(self):
        return int(self.x1), int(self.y1), int(self.x2), int(self.y2)

当初、十字のカーソルが表示され、ドラッグ開始以降は、選択された矩形を表示します。また、クリックが離された時点で、画像の表示は終了します。

4 ノイズ除去

動画で撮影した場合、対象のオブジェクトが、最初に指定した範囲から移動・拡大・縮小する可能性があります。

そこで、取得したマスクの上下左右15%拡大した範囲を、次のフレームの抽出範囲として順次使用しています。これにより、ある程度の追跡が可能となっています。

しかし、ここで問題となるのは、ノイズです。検出されたマスクは、その状況により、やや、ノイズが入ったものとなることがあり、このノイズの入ったマスクを基準にすると、次のフレームの指定範囲が、対象オブジェクトより、大きなものとなってしまい、結果的に、対象以外が検出されてしまうことになります。

そこで、取得したマスクは、以下のような手順で、ノイズを除去しています。

  • 取得したマスクを、2値画像に展開
  • 2値画像から輪郭を取得する
  • 最大面積の輪郭のみを使用して新たに2値画像を生成
  • 上記の2値画像からマスクを再構成

# ノイズ除去
def _remove_noise(self, image, mask):
    # 2値画像(白及び黒)を生成する
    height, width, _ = image.shape
    tmp_black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8)
    tmp_white_image = np.full(np.array([height, width, 1]), 255, dtype=np.uint8)
    # マスクによって黒画像の上に白を描画する
    tmp_black_image[:] = np.where(
        mask[:height, :width, np.newaxis] == True, tmp_white_image, tmp_black_image
    )

    # 輪郭の取得
    contours, _ = cv2.findContours(
        tmp_black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    # 最も面積が大きい輪郭を選択
    max_contours = max(contours, key=lambda x: cv2.contourArea(x))
    # 黒画面に一番大きい輪郭だけ塗りつぶして描画する
    black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8)
    black_image = cv2.drawContours(
        black_image, [max_contours], -1, color=255, thickness=-1
    )
    # 輪郭を保存
    self._contours, _ = cv2.findContours(
        black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # マスクを作り直す
    new_mask = np.full(np.array([height, width, 1]), False, dtype=np.bool_)
    new_mask[::] = np.where(black_image[:height, :width] == 0, False, True)
    new_mask = np.squeeze(new_mask)

    return new_mask

5 抽出画像

抽出された画像は、output_pathで指定されたフォルダの下に、フレーム番号で保存されます。

  • 0000000001_w.png(背景が白の画像)
  • 0000000001_t.png(背景が透過の画像)

背景が白の画像は、分類モデルのデータセットに利用できると思います。また、透過の画像は、別途用意した背景の上に重ねて検出モデル用のデータセットが生成できると思います。

参考:下記では、透過画像からYOLOv5のデータセットを生成しています。

6 最後に

今回は、Segment Anythingを使用して、コンピュータビジョン用のデータセット画像を生成してみました。

これにより、軽易に動画を撮影するだけで、後は半自動で画像の切り出しまでできるようになりました。

この作業が、簡単になれば、手返し良くデータセットを試せるので、モデルの精度を上げることに貢献できると、個人的には信じています。

動画で動作していたソースコードは、以下です。説明が不足している部分については、こちらを参照頂ければ幸いです。

index.py
import os
import datetime
import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
from matplotlib import patches
from segment_anything import sam_model_registry, SamPredictor


# マウスで範囲指定する
class BoundingBox:
    def __init__(self, image):
        self.x1 = -1
        self.x2 = -1
        self.y1 = -1
        self.y2 = -1
        self.image = image.copy()
        plt.figure()
        plt.connect("motion_notify_event", self.motion)
        plt.connect("button_press_event", self.press)
        plt.connect("button_release_event", self.release)
        self.ln_v = plt.axvline(0)
        self.ln_h = plt.axhline(0)
        plt.imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB))
        plt.show()

    # 選択中のカーソル表示
    def motion(self, event):
        if event.xdata is not None and event.ydata is not None:
            self.ln_v.set_xdata(event.xdata)
            self.ln_h.set_ydata(event.ydata)
            self.x2 = event.xdata.astype("int16")
            self.y2 = event.ydata.astype("int16")
        if self.x1 != -1 and self.x2 != -1 and self.y1 != -1 and self.y2 != -1:
            plt.clf()
            plt.imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB))
            ax = plt.gca()
            rect = patches.Rectangle(
                (self.x1, self.y1),
                self.x2 - self.x1,
                self.y2 - self.y1,
                angle=0.0,
                fill=False,
                edgecolor="#00FFFF",
            )
            ax.add_patch(rect)
        plt.draw()

    # ドラッグ開始位置
    def press(self, event):
        self.x1 = event.xdata.astype("int16")
        self.y1 = event.ydata.astype("int16")

    # ドラッグ終了位置、表示終了
    def release(self, event):
        plt.clf()
        plt.close()

    def get_area(self):
        return int(self.x1), int(self.y1), int(self.x2), int(self.y2)


# SAM
class SegmentAnything:
    def __init__(self, device, model_type, sam_checkpoint):
        print("init Segment Anything")
        self.device = device
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        sam.to(self.device)
        self.predictor = SamPredictor(sam)

    @property
    def contours(self):
        return self._contours

    @property
    def transparent_image(self):
        return self._transparent_image

    @property
    def white_back_image(self):
        return self._white_back_image

    @property
    def box(self):
        return self._box

    # マスク取得
    def predict(self, frame, input_box):
        image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        self.predictor.set_image(image)
        masks, _, _ = self.predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_box[None, :],
            multimask_output=False,
        )

        # ノイズ除去
        self._mask = self._remove_noise(frame, masks[0])
        # 範囲取得
        self._box = self._get_box()
        # 部分画像取得
        self._white_back_image, self._transparent_image = self._get_extract_image(frame)

    # マスクの範囲取得
    def _get_box(self):
        mask_indexes = np.where(self._mask)
        y_min = np.min(mask_indexes[0])
        y_max = np.max(mask_indexes[0])
        x_min = np.min(mask_indexes[1])
        x_max = np.max(mask_indexes[1])
        return np.array([x_min, y_min, x_max, y_max])

    # ノイズ除去
    def _remove_noise(self, image, mask):
        # 2値画像(白及び黒)を生成する
        height, width, _ = image.shape
        tmp_black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8)
        tmp_white_image = np.full(np.array([height, width, 1]), 255, dtype=np.uint8)
        # マスクによって黒画像の上に白を描画する
        tmp_black_image[:] = np.where(
            mask[:height, :width, np.newaxis] == True, tmp_white_image, tmp_black_image
        )

        # 輪郭の取得
        contours, _ = cv2.findContours(
            tmp_black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        # 最も面積が大きい輪郭を選択
        max_contours = max(contours, key=lambda x: cv2.contourArea(x))
        # 黒画面に一番大きい輪郭だけ塗りつぶして描画する
        black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8)
        black_image = cv2.drawContours(
            black_image, [max_contours], -1, color=255, thickness=-1
        )
        # 輪郭を保存
        self._contours, _ = cv2.findContours(
            black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        # マスクを作り直す
        new_mask = np.full(np.array([height, width, 1]), False, dtype=np.bool_)
        new_mask[::] = np.where(black_image[:height, :width] == 0, False, True)
        new_mask = np.squeeze(new_mask)

        return new_mask

    # 部分イメージの取得
    def _get_extract_image(self, image):
        # boxの範囲でマスクを切り取る
        part_of_mask = self._mask[
            self._box[1] : self._box[3], self._box[0] : self._box[2]
        ]
        # boxの範囲で元画像を切り取る
        copy_image = image.copy()  # 個々の食品を切取るためのテンポラリ画像
        white_back_image = copy_image[
            self._box[1] : self._box[3], self._box[0] : self._box[2]
        ]
        # boxの範囲で白一色の2値画像を作成する
        h = self._box[3] - self._box[1]
        w = self._box[2] - self._box[0]
        white_image = np.full(np.array([h, w, 1]), 255, dtype=np.uint8)
        # マスクによって白画像の上に元画像を描画する
        white_back_image[:] = np.where(
            part_of_mask[:h, :w, np.newaxis] == False, white_image, white_back_image
        )

        transparent_image = cv2.cvtColor(white_back_image, cv2.COLOR_BGR2BGRA)
        transparent_image[np.logical_not(part_of_mask), 3] = 0
        return white_back_image, transparent_image


class Video:
    def __init__(self, filename):
        self.cap = cv2.VideoCapture(filename)
        if self.cap.isOpened() == False:
            print("Video open faild.")
        else:
            self._frame_max = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # 次のフレーム取得
    def next_frame(self):
        return self.cap.read()

    ## 総フレーム数
    @property
    def frame_max(self):
        return self._frame_max

    def destroy(self):
        print("video destroy.")
        self.cap.release()
        cv2.destroyAllWindows()


# 縦横それぞれ、0.15倍まで広げた、ボックスを取得する
def get_next_input(box):
    x1 = box[0]
    y1 = box[1]
    x2 = box[2]
    y2 = box[3]
    w = x2 - x1
    h = y2 - y1
    x_margen = int(w * 0.15)
    y_margen = int(h * 0.15)

    return np.array([x1 - x_margen, y1 - y_margen, x2 + x_margen, y2 + y_margen])


def main():
    print("PyTorch version:", torch.__version__)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using {} device".format(device))

    step = 3
    start_frame = 0  # 684から乱れる
    # filename = "DuckBrothers2.mp4"
    filename = "post_1.mp4"
    output_path = "./output"
    basename = os.path.splitext(os.path.basename(filename))[0]
    os.makedirs("{}/{}".format(output_path, basename), exist_ok=True)

    video = Video(filename)

    # Segment Anything
    sam = SegmentAnything(device, "vit_h", "sam_vit_h_4b8939.pth")

    try:
        print("start")
        for i in range(video.frame_max):
            ret, frame = video.next_frame()
            if ret == False:
                continue

            # 開始位置まで読み飛ばす
            if i < start_frame:
                continue

            # フレーム省略
            if i % step != 0:
                continue

            # 最初のフレームで、バウンディングボックスを取得する
            if i == start_frame:
                bounding_box = BoundingBox(frame)
                x1, y1, x2, y2 = bounding_box.get_area()
                input_box = np.array([x1, y1, x2, y2])

            print(
                "{} filename:{} shape:{} start_frame:{} input_box:{} frams:{}/{}".format(
                    datetime.datetime.now(),
                    filename,
                    frame.shape,
                    start_frame,
                    input_box,
                    i + 1,
                    video.frame_max,
                )
            )

            # マスク生成
            sam.predict(frame, input_box)
            # 輪郭描画
            frame = cv2.drawContours(
                frame, sam.contours, -1, color=[255, 255, 0], thickness=6
            )

            # バウンディングボックス描画
            frame = cv2.rectangle(
                frame,
                pt1=(input_box[0], input_box[1]),
                pt2=(input_box[2], input_box[3]),
                color=(255, 255, 255),
                thickness=2,
                lineType=cv2.LINE_4,
            )

            # データ保存
            cv2.imwrite(
                "{}/{}/{:09}_t.png".format(output_path, basename, i),
                sam.transparent_image,
            )
            cv2.imwrite(
                "{}/{}/{:09}_w.png".format(output_path, basename, i),
                sam.white_back_image,
            )

            # 表示
            cv2.imshow("Extract", sam.white_back_image)
            cv2.waitKey(1)
            cv2.imshow("Video", cv2.resize(frame, None, fx=0.3, fy=0.3))
            cv2.waitKey(1)

            # 次のFrameで、検出範囲よりひと回り大きい範囲をBOX指定する
            input_box = get_next_input(sam.box)
    except KeyboardInterrupt:
        video.destroy()


if __name__ == "__main__":
    main()