1 はじめに
CX 事業本部 delivery部の平内(SIN)です。
Meta社による Segment Anything Model(SAM)は、セグメンテーションのための汎用モデルで、ファインチューニングなしで、あらゆる物体がセグメンテーションできます。
前回は、これを使用してUSBカメラからの入力をセグメンテーションしてみました。
今回は、「コンピュータビジョン用のデータセット画像」の切り出しに焦点をあてて、事前に録画した動画から、画像を切り出してみました。
最初に動作している様子です。
動画を読み込むと最初のフレームで停止し、対象オブジェクトの指定が行えます。マウスで、対象を囲むと、その後は、そのオブジェクトを追従しながらデータを切り出して保存します。(動画の後半では、同じ動画で、別のアヒルを抽出しています)
抽出された画像は、背景が白と透過の2種類となります。
2 Object masks from prompts
今回も、使用しているのは、Object masks from promptsです。
input_boxに検出範囲を指定することで、特定のオブジェクトが対象となるようになっています。
self.predictor.set_image(image)
masks, _, _ = self.predictor.predict(
point_coords=None,
point_labels=None,
box=input_box[None, :],
multimask_output=False,
)
3 対象の選択
最初のフレームで、マウスを使用してオブジェクト検出範囲を指定するコードです。 matplotlib.pyplotで画像を表示し、マウスの操作をトラップしています。
# マウスで範囲指定する
class BoundingBox:
def __init__(self, image):
self.x1 = -1
self.x2 = -1
self.y1 = -1
self.y2 = -1
self.image = image.copy()
plt.figure()
plt.connect("motion_notify_event", self.motion)
plt.connect("button_press_event", self.press)
plt.connect("button_release_event", self.release)
self.ln_v = plt.axvline(0)
self.ln_h = plt.axhline(0)
plt.imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB))
plt.show()
# 選択中のカーソル表示
def motion(self, event):
if event.xdata is not None and event.ydata is not None:
self.ln_v.set_xdata(event.xdata)
self.ln_h.set_ydata(event.ydata)
self.x2 = event.xdata.astype("int16")
self.y2 = event.ydata.astype("int16")
if self.x1 != -1 and self.x2 != -1 and self.y1 != -1 and self.y2 != -1:
plt.clf()
plt.imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB))
ax = plt.gca()
rect = patches.Rectangle(
(self.x1, self.y1),
self.x2 - self.x1,
self.y2 - self.y1,
angle=0.0,
fill=False,
edgecolor="#00FFFF",
)
ax.add_patch(rect)
plt.draw()
# ドラッグ開始位置
def press(self, event):
self.x1 = event.xdata.astype("int16")
self.y1 = event.ydata.astype("int16")
# ドラッグ終了位置、表示終了
def release(self, event):
plt.clf()
plt.close()
def get_area(self):
return int(self.x1), int(self.y1), int(self.x2), int(self.y2)
当初、十字のカーソルが表示され、ドラッグ開始以降は、選択された矩形を表示します。また、クリックが離された時点で、画像の表示は終了します。
4 ノイズ除去
動画で撮影した場合、対象のオブジェクトが、最初に指定した範囲から移動・拡大・縮小する可能性があります。
そこで、取得したマスクの上下左右15%拡大した範囲を、次のフレームの抽出範囲として順次使用しています。これにより、ある程度の追跡が可能となっています。
しかし、ここで問題となるのは、ノイズです。検出されたマスクは、その状況により、やや、ノイズが入ったものとなることがあり、このノイズの入ったマスクを基準にすると、次のフレームの指定範囲が、対象オブジェクトより、大きなものとなってしまい、結果的に、対象以外が検出されてしまうことになります。
そこで、取得したマスクは、以下のような手順で、ノイズを除去しています。
- 取得したマスクを、2値画像に展開
- 2値画像から輪郭を取得する
- 最大面積の輪郭のみを使用して新たに2値画像を生成
- 上記の2値画像からマスクを再構成
# ノイズ除去
def _remove_noise(self, image, mask):
# 2値画像(白及び黒)を生成する
height, width, _ = image.shape
tmp_black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8)
tmp_white_image = np.full(np.array([height, width, 1]), 255, dtype=np.uint8)
# マスクによって黒画像の上に白を描画する
tmp_black_image[:] = np.where(
mask[:height, :width, np.newaxis] == True, tmp_white_image, tmp_black_image
)
# 輪郭の取得
contours, _ = cv2.findContours(
tmp_black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
# 最も面積が大きい輪郭を選択
max_contours = max(contours, key=lambda x: cv2.contourArea(x))
# 黒画面に一番大きい輪郭だけ塗りつぶして描画する
black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8)
black_image = cv2.drawContours(
black_image, [max_contours], -1, color=255, thickness=-1
)
# 輪郭を保存
self._contours, _ = cv2.findContours(
black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
# マスクを作り直す
new_mask = np.full(np.array([height, width, 1]), False, dtype=np.bool_)
new_mask[::] = np.where(black_image[:height, :width] == 0, False, True)
new_mask = np.squeeze(new_mask)
return new_mask
5 抽出画像
抽出された画像は、output_pathで指定されたフォルダの下に、フレーム番号で保存されます。
- 0000000001_w.png(背景が白の画像)
- 0000000001_t.png(背景が透過の画像)
背景が白の画像は、分類モデルのデータセットに利用できると思います。また、透過の画像は、別途用意した背景の上に重ねて検出モデル用のデータセットが生成できると思います。
参考:下記では、透過画像からYOLOv5のデータセットを生成しています。
6 最後に
今回は、Segment Anythingを使用して、コンピュータビジョン用のデータセット画像を生成してみました。
これにより、軽易に動画を撮影するだけで、後は半自動で画像の切り出しまでできるようになりました。
この作業が、簡単になれば、手返し良くデータセットを試せるので、モデルの精度を上げることに貢献できると、個人的には信じています。
動画で動作していたソースコードは、以下です。説明が不足している部分については、こちらを参照頂ければ幸いです。
index.py
import os
import datetime
import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
from matplotlib import patches
from segment_anything import sam_model_registry, SamPredictor
# マウスで範囲指定する
class BoundingBox:
def __init__(self, image):
self.x1 = -1
self.x2 = -1
self.y1 = -1
self.y2 = -1
self.image = image.copy()
plt.figure()
plt.connect("motion_notify_event", self.motion)
plt.connect("button_press_event", self.press)
plt.connect("button_release_event", self.release)
self.ln_v = plt.axvline(0)
self.ln_h = plt.axhline(0)
plt.imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB))
plt.show()
# 選択中のカーソル表示
def motion(self, event):
if event.xdata is not None and event.ydata is not None:
self.ln_v.set_xdata(event.xdata)
self.ln_h.set_ydata(event.ydata)
self.x2 = event.xdata.astype("int16")
self.y2 = event.ydata.astype("int16")
if self.x1 != -1 and self.x2 != -1 and self.y1 != -1 and self.y2 != -1:
plt.clf()
plt.imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB))
ax = plt.gca()
rect = patches.Rectangle(
(self.x1, self.y1),
self.x2 - self.x1,
self.y2 - self.y1,
angle=0.0,
fill=False,
edgecolor="#00FFFF",
)
ax.add_patch(rect)
plt.draw()
# ドラッグ開始位置
def press(self, event):
self.x1 = event.xdata.astype("int16")
self.y1 = event.ydata.astype("int16")
# ドラッグ終了位置、表示終了
def release(self, event):
plt.clf()
plt.close()
def get_area(self):
return int(self.x1), int(self.y1), int(self.x2), int(self.y2)
# SAM
class SegmentAnything:
def __init__(self, device, model_type, sam_checkpoint):
print("init Segment Anything")
self.device = device
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(self.device)
self.predictor = SamPredictor(sam)
@property
def contours(self):
return self._contours
@property
def transparent_image(self):
return self._transparent_image
@property
def white_back_image(self):
return self._white_back_image
@property
def box(self):
return self._box
# マスク取得
def predict(self, frame, input_box):
image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
self.predictor.set_image(image)
masks, _, _ = self.predictor.predict(
point_coords=None,
point_labels=None,
box=input_box[None, :],
multimask_output=False,
)
# ノイズ除去
self._mask = self._remove_noise(frame, masks[0])
# 範囲取得
self._box = self._get_box()
# 部分画像取得
self._white_back_image, self._transparent_image = self._get_extract_image(frame)
# マスクの範囲取得
def _get_box(self):
mask_indexes = np.where(self._mask)
y_min = np.min(mask_indexes[0])
y_max = np.max(mask_indexes[0])
x_min = np.min(mask_indexes[1])
x_max = np.max(mask_indexes[1])
return np.array([x_min, y_min, x_max, y_max])
# ノイズ除去
def _remove_noise(self, image, mask):
# 2値画像(白及び黒)を生成する
height, width, _ = image.shape
tmp_black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8)
tmp_white_image = np.full(np.array([height, width, 1]), 255, dtype=np.uint8)
# マスクによって黒画像の上に白を描画する
tmp_black_image[:] = np.where(
mask[:height, :width, np.newaxis] == True, tmp_white_image, tmp_black_image
)
# 輪郭の取得
contours, _ = cv2.findContours(
tmp_black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
# 最も面積が大きい輪郭を選択
max_contours = max(contours, key=lambda x: cv2.contourArea(x))
# 黒画面に一番大きい輪郭だけ塗りつぶして描画する
black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8)
black_image = cv2.drawContours(
black_image, [max_contours], -1, color=255, thickness=-1
)
# 輪郭を保存
self._contours, _ = cv2.findContours(
black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
# マスクを作り直す
new_mask = np.full(np.array([height, width, 1]), False, dtype=np.bool_)
new_mask[::] = np.where(black_image[:height, :width] == 0, False, True)
new_mask = np.squeeze(new_mask)
return new_mask
# 部分イメージの取得
def _get_extract_image(self, image):
# boxの範囲でマスクを切り取る
part_of_mask = self._mask[
self._box[1] : self._box[3], self._box[0] : self._box[2]
]
# boxの範囲で元画像を切り取る
copy_image = image.copy() # 個々の食品を切取るためのテンポラリ画像
white_back_image = copy_image[
self._box[1] : self._box[3], self._box[0] : self._box[2]
]
# boxの範囲で白一色の2値画像を作成する
h = self._box[3] - self._box[1]
w = self._box[2] - self._box[0]
white_image = np.full(np.array([h, w, 1]), 255, dtype=np.uint8)
# マスクによって白画像の上に元画像を描画する
white_back_image[:] = np.where(
part_of_mask[:h, :w, np.newaxis] == False, white_image, white_back_image
)
transparent_image = cv2.cvtColor(white_back_image, cv2.COLOR_BGR2BGRA)
transparent_image[np.logical_not(part_of_mask), 3] = 0
return white_back_image, transparent_image
class Video:
def __init__(self, filename):
self.cap = cv2.VideoCapture(filename)
if self.cap.isOpened() == False:
print("Video open faild.")
else:
self._frame_max = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
# 次のフレーム取得
def next_frame(self):
return self.cap.read()
## 総フレーム数
@property
def frame_max(self):
return self._frame_max
def destroy(self):
print("video destroy.")
self.cap.release()
cv2.destroyAllWindows()
# 縦横それぞれ、0.15倍まで広げた、ボックスを取得する
def get_next_input(box):
x1 = box[0]
y1 = box[1]
x2 = box[2]
y2 = box[3]
w = x2 - x1
h = y2 - y1
x_margen = int(w * 0.15)
y_margen = int(h * 0.15)
return np.array([x1 - x_margen, y1 - y_margen, x2 + x_margen, y2 + y_margen])
def main():
print("PyTorch version:", torch.__version__)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))
step = 3
start_frame = 0 # 684から乱れる
# filename = "DuckBrothers2.mp4"
filename = "post_1.mp4"
output_path = "./output"
basename = os.path.splitext(os.path.basename(filename))[0]
os.makedirs("{}/{}".format(output_path, basename), exist_ok=True)
video = Video(filename)
# Segment Anything
sam = SegmentAnything(device, "vit_h", "sam_vit_h_4b8939.pth")
try:
print("start")
for i in range(video.frame_max):
ret, frame = video.next_frame()
if ret == False:
continue
# 開始位置まで読み飛ばす
if i < start_frame:
continue
# フレーム省略
if i % step != 0:
continue
# 最初のフレームで、バウンディングボックスを取得する
if i == start_frame:
bounding_box = BoundingBox(frame)
x1, y1, x2, y2 = bounding_box.get_area()
input_box = np.array([x1, y1, x2, y2])
print(
"{} filename:{} shape:{} start_frame:{} input_box:{} frams:{}/{}".format(
datetime.datetime.now(),
filename,
frame.shape,
start_frame,
input_box,
i + 1,
video.frame_max,
)
)
# マスク生成
sam.predict(frame, input_box)
# 輪郭描画
frame = cv2.drawContours(
frame, sam.contours, -1, color=[255, 255, 0], thickness=6
)
# バウンディングボックス描画
frame = cv2.rectangle(
frame,
pt1=(input_box[0], input_box[1]),
pt2=(input_box[2], input_box[3]),
color=(255, 255, 255),
thickness=2,
lineType=cv2.LINE_4,
)
# データ保存
cv2.imwrite(
"{}/{}/{:09}_t.png".format(output_path, basename, i),
sam.transparent_image,
)
cv2.imwrite(
"{}/{}/{:09}_w.png".format(output_path, basename, i),
sam.white_back_image,
)
# 表示
cv2.imshow("Extract", sam.white_back_image)
cv2.waitKey(1)
cv2.imshow("Video", cv2.resize(frame, None, fx=0.3, fy=0.3))
cv2.waitKey(1)
# 次のFrameで、検出範囲よりひと回り大きい範囲をBOX指定する
input_box = get_next_input(sam.box)
except KeyboardInterrupt:
video.destroy()
if __name__ == "__main__":
main()