1 はじめに
CX 事業本部 delivery部の平内(SIN)です。
Meta社による Segment Anything Model(以下、SAM)は、セグメンテーションのための汎用モデルで、ファインチューニングなしで、あらゆる物体がセグメンテーションできます。
今回は、こちらを使用して、スーパーマーケットで冷蔵庫の缶ビールを検出してみました。
最初に、動作確認している様子です。
缶ビールがセグメンテーションされていることと、それぞれの銘柄が判定できていることを確認できると思います。
2 学習済みモデルによる物体検出 (YOLOv8)
YOLOv8などの物体検出モデルでは、配布されている学習済みモデルで、ある程度の物体検出が可能です。
下記は、テーブルに置かれた、缶ビールが、chair、dining tableとともに、cupとして検出されている様子が確認できます。
しかし、このような学習済みモデルでは、店舗に陳列されている缶ビールを検出するようなことは、出来ませんでした。
3 フィルタによる抽出 (Segment Anything Model)
そこで、使用したのが、SAMです。
パラメータにもよりますが、SAMで検出してみると。なんとなく、缶ビールが検出できそうな予感がしてきます。
上記では、缶ビールの検出もできているようですが、缶上のラベルや、値札などもセグメンテーションされていて、ちょっと、欲しいものとは違う感じです。
そこで、目的としている缶ビール以外のマスク情報をフィルタするために、画像上での缶ビールのサイズを調べて、そのサイズに基づいてフィルタしてみました。
まず、缶ビールのサイズの幅及び高さの最大、最小値を決めます。
# 缶ビールは、画像サイズの何%を占めるか
height, width, _ = image.shape
minH = int(height * 0.14)
maxH = int(height * 0.29)
minW = int(width * 0.04)
maxW = int(width * 0.09)
そして、SAMで検出したすべてのマスクの外接矩形を求め、サイズが範囲内に収まるものだけをフィルタします。
for contour in contours:
# 外接の矩形を求める
x, y, w, h = cv2.boundingRect(contour)
if ( # 一定のサイズのみ検出対象にする
minH <= h and h <= maxH and minW <= w and w <= maxW
):
all_contours.append(contour)
フィルタしたマスクだけを表示したものです。フィルタ前のマスクは300個程度で、フィルタ後は30個程度となりました。
作業時の様子ですが、計算した、min/maxの矩形を画面の左端に表示して、缶ビールのサイズと比べながら、丁度よいフィルタのサイズを試行錯誤しました。
4 分類モデルによる銘柄判定
セグメンテーションされた缶ビールの銘柄を求めるためには、YOLOv8の分類モデルを使用しました。
分類モデルを作成した手順は、以下のとおりです。
- 動画撮影
- png画像の抽出
- データセット
- データ増幅
- ファインチューニング
(1) 動画撮影
缶ビールは、回転台に乗せて、ちょうど1周分の動画を撮影しました。
各銘柄ごと通常缶とロング缶の2種類を撮影しています。
(2) png画像の抽出
撮影した動画から、データセット用のpng画像を抽出するには、SAMを使用しました。
上記のブログで紹介しているように、動画の1フレーム目で、缶ビールの範囲を指定すると、3フレームごとに缶ビールを切り取り、背景が「透過」と「白」の2種類の画像を生成できます。
データセット用に、背景が白のもの、そして、ブラウザ表示用に透過背景のものを使用しました。
MOVファイルからPNGを切り出したコードです。
mov_2_png.py
import os
import datetime
import numpy as np
import torch
import cv2
import glob
import matplotlib.pyplot as plt
from matplotlib import patches
from segment_anything import sam_model_registry, SamPredictor
# マウスで範囲指定する
class BoundingBox:
def __init__(self, image):
self.x1 = -1
self.x2 = -1
self.y1 = -1
self.y2 = -1
self.image = image.copy()
plt.figure()
plt.connect("motion_notify_event", self.motion)
plt.connect("button_press_event", self.press)
plt.connect("button_release_event", self.release)
self.ln_v = plt.axvline(0)
self.ln_h = plt.axhline(0)
plt.imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB))
plt.show()
# 選択中のカーソル表示
def motion(self, event):
if event.xdata is not None and event.ydata is not None:
self.ln_v.set_xdata(event.xdata)
self.ln_h.set_ydata(event.ydata)
self.x2 = event.xdata.astype("int16")
self.y2 = event.ydata.astype("int16")
if self.x1 != -1 and self.x2 != -1 and self.y1 != -1 and self.y2 != -1:
plt.clf()
plt.imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB))
ax = plt.gca()
rect = patches.Rectangle(
(self.x1, self.y1),
self.x2 - self.x1,
self.y2 - self.y1,
angle=0.0,
fill=False,
edgecolor="#00FFFF",
)
ax.add_patch(rect)
plt.draw()
# ドラッグ開始位置
def press(self, event):
self.x1 = event.xdata.astype("int16")
self.y1 = event.ydata.astype("int16")
# ドラッグ終了位置、表示終了
def release(self, event):
plt.clf()
plt.close()
def get_area(self):
return int(self.x1), int(self.y1), int(self.x2), int(self.y2)
# SAM
class SegmentAnything:
def __init__(self, device, model_type, sam_checkpoint):
print("init Segment Anything")
self.device = device
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(self.device)
self.predictor = SamPredictor(sam)
@property
def contours(self):
return self._contours
@property
def transparent_image(self):
return self._transparent_image
@property
def white_back_image(self):
return self._white_back_image
@property
def box(self):
return self._box
# マスク取得
def predict(self, frame, input_box):
image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
self.predictor.set_image(image)
masks, _, _ = self.predictor.predict(
point_coords=None,
point_labels=None,
box=input_box[None, :],
multimask_output=False,
)
# ノイズ除去
self._mask = self._remove_noise(frame, masks[0])
# 範囲取得
self._box = self._get_box()
# 部分画像取得
self._white_back_image, self._transparent_image = self._get_extract_image(frame)
# マスクの範囲取得
def _get_box(self):
mask_indexes = np.where(self._mask)
y_min = np.min(mask_indexes[0])
y_max = np.max(mask_indexes[0])
x_min = np.min(mask_indexes[1])
x_max = np.max(mask_indexes[1])
return np.array([x_min, y_min, x_max, y_max])
# ノイズ除去
def _remove_noise(self, image, mask):
# 2値画像(白及び黒)を生成する
height, width, _ = image.shape
tmp_black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8)
tmp_white_image = np.full(np.array([height, width, 1]), 255, dtype=np.uint8)
# マスクによって黒画像の上に白を描画する
tmp_black_image[:] = np.where(
mask[:height, :width, np.newaxis] == True, tmp_white_image, tmp_black_image
)
# 輪郭の取得
contours, _ = cv2.findContours(
tmp_black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
# 最も面積が大きい輪郭を選択
max_contours = max(contours, key=lambda x: cv2.contourArea(x))
# 黒画面に一番大きい輪郭だけ塗りつぶして描画する
black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8)
black_image = cv2.drawContours(
black_image, [max_contours], -1, color=255, thickness=-1
)
# 輪郭を保存
self._contours, _ = cv2.findContours(
black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
# マスクを作り直す
new_mask = np.full(np.array([height, width, 1]), False, dtype=np.bool_)
new_mask[::] = np.where(black_image[:height, :width] == 0, False, True)
new_mask = np.squeeze(new_mask)
return new_mask
# 部分イメージの取得
def _get_extract_image(self, image):
# boxの範囲でマスクを切り取る
part_of_mask = self._mask[
self._box[1] : self._box[3], self._box[0] : self._box[2]
]
# boxの範囲で元画像を切り取る
copy_image = image.copy() # 個々の食品を切取るためのテンポラリ画像
white_back_image = copy_image[
self._box[1] : self._box[3], self._box[0] : self._box[2]
]
# boxの範囲で白一色の2値画像を作成する
h = self._box[3] - self._box[1]
w = self._box[2] - self._box[0]
white_image = np.full(np.array([h, w, 1]), 255, dtype=np.uint8)
# マスクによって白画像の上に元画像を描画する
white_back_image[:] = np.where(
part_of_mask[:h, :w, np.newaxis] == False, white_image, white_back_image
)
transparent_image = cv2.cvtColor(white_back_image, cv2.COLOR_BGR2BGRA)
transparent_image[np.logical_not(part_of_mask), 3] = 0
return white_back_image, transparent_image
class Video:
def __init__(self, filename):
self.cap = cv2.VideoCapture(filename)
if self.cap.isOpened() == False:
print("Video open faild.")
else:
self._frame_max = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
# 次のフレーム取得
def next_frame(self):
return self.cap.read()
## 総フレーム数
@property
def frame_max(self):
return self._frame_max
def destroy(self):
print("video destroy.")
self.cap.release()
cv2.destroyAllWindows()
# 縦横それぞれ、0.15倍まで広げた、ボックスを取得する
def get_next_input(box):
x1 = box[0]
y1 = box[1]
x2 = box[2]
y2 = box[3]
w = x2 - x1
h = y2 - y1
x_margen = int(w * 0.15)
y_margen = int(h * 0.15)
return np.array([x1 - x_margen, y1 - y_margen, x2 + x_margen, y2 + y_margen])
def main():
print("PyTorch version:", torch.__version__)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))
# Segment Anything
sam = SegmentAnything(device, "vit_h", "sam_vit_h_4b8939.pth")
output_path = "./png"
files = glob.glob("mov/*.MOV")
for filename in files:
step = 3
start_frame = 0 # 684から乱れる
# start_frame = 525 # 684から乱れる
# filename = "mov/hon_kirin.MOV"
basename = os.path.splitext(os.path.basename(filename))[0]
os.makedirs("{}/{}".format(output_path, basename), exist_ok=True)
video = Video(filename)
try:
print("start")
for i in range(video.frame_max):
ret, frame = video.next_frame()
if ret == False:
continue
# 開始位置まで読み飛ばす
if i < start_frame:
continue
# フレーム省略
if i % step != 0:
continue
# 最初のフレームで、バウンディングボックスを取得する
if i == start_frame:
bounding_box = BoundingBox(frame)
x1, y1, x2, y2 = bounding_box.get_area()
input_box = np.array([x1, y1, x2, y2])
print(
"{} filename:{} shape:{} start_frame:{} input_box:{} frams:{}/{}".format(
datetime.datetime.now(),
filename,
frame.shape,
start_frame,
input_box,
i + 1,
video.frame_max,
)
)
# マスク生成
sam.predict(frame, input_box)
# 輪郭描画
frame = cv2.drawContours(
frame, sam.contours, -1, color=[255, 255, 0], thickness=6
)
# バウンディングボックス描画
frame = cv2.rectangle(
frame,
pt1=(input_box[0], input_box[1]),
pt2=(input_box[2], input_box[3]),
color=(255, 255, 255),
thickness=2,
lineType=cv2.LINE_4,
)
# データ保存
cv2.imwrite(
"{}/{}/{:09}_t.png".format(output_path, basename, i),
sam.transparent_image,
)
cv2.imwrite(
"{}/{}/{:09}_w.png".format(output_path, basename, i),
sam.white_back_image,
)
# 表示
cv2.imshow("Extract", sam.white_back_image)
cv2.waitKey(1)
cv2.imshow("Video", cv2.resize(frame, None, fx=0.3, fy=0.3))
cv2.waitKey(1)
# 次のFrameで、検出範囲よりひと回り大きい範囲をBOX指定する
input_box = get_next_input(sam.box)
except KeyboardInterrupt:
video.destroy()
if __name__ == "__main__":
main()
(3) データセット
切り取った缶ビールのpng画像は、YOLOv8の分類モデルでファインチューニングするためのデータセットとして使用できるように、定められた階層に配置します。
回転台で1週回した動画は、概ね600フレームとなり、そこから3フレーム毎に画像を切り取るので、それぞれ約200枚となっています。しかし、動画は、全てピタリ1周分とはなっておらず、いくらか長短が出てしまっています。各クラスのデータ数に偏りが出ないようにと、ここでは、上限を最も短かった動画に合わせて186枚としました。
通常缶 186枚、及び、ロング缶 186枚を合計した、372枚が1つのクラス(銘柄)のデータセットで、それを8対2に分割して、学習用と検証用としました。
$ tree dataset -L 2
dataset
├── train
│ ├── asahi_clear
│ ├── asahi_off
│ │ ├── 000_000000001_w.png
│ │ ├── 000_000000004_w.png
│ │ ├── 000_000000007_w.png
・・・略・・・
│ │ ├── 111_000000237_w.png
│ │ ├── 111_000000243_w.png
│ │ └── 111_000000249_w.png
│ ├── asahi_style_free
│ ├── asahi_the_rich
│ ├── asahi_zeitaku_zero
│ ├── kirin_hon_kirin
│ ├── kirin_nodogoshi_nama
│ ├── kirin_nodogoshi_zero
│ ├── kirin_tanrei
│ ├── sappoeo_mugi_to_hop_black
│ ├── sappoeo_mugi_to_hop_extra_rich
│ ├── sapporo_draft_one
│ ├── sapporo_goku_zero
│ ├── sapporo_gold_star
│ ├── sapporo_gold_star_red
│ ├── sapporo_nama_sibori
│ ├── suntory_kinmugi_blue
│ ├── suntory_kinmugi_gold
│ ├── suntory_kinmugi_red
│ └── suntpry_kinmugi_kohakunoaki
└── val
├── asahi_clear
├── asahi_off
├── asahi_style_free
├── asahi_the_rich
├── asahi_zeitaku_zero
├── kirin_hon_kirin
├── kirin_nodogoshi_nama
├── kirin_nodogoshi_zero
├── kirin_tanrei
├── sappoeo_mugi_to_hop_black
├── sappoeo_mugi_to_hop_extra_rich
├── sapporo_draft_one
├── sapporo_goku_zero
├── sapporo_gold_star
├── sapporo_gold_star_red
├── sapporo_nama_sibori
├── suntory_kinmugi_blue
├── suntory_kinmugi_gold
├── suntory_kinmugi_red
└── suntpry_kinmugi_kohakunoaki
ファイル名の数値部分は、フレーム数です。また、プレフィックスとして000を付与されているものが、通常缶、111が、ロング缶の画像です。
PNG画像をデータセットに配置したコードです。
png_2_dataset.py
import os
import shutil
import glob
input_path = "./png"
dataset_path = "./dataset"
stage_list = ["train", "val"]
max = 186 # あまりデータ数に差が出ないように、多くあるものは、一定数以上を削除する
# pngディレクトリ内のディレクトリを列挙
class_name_list = []
dir_list = glob.glob("{}/*".format(input_path))
for dir in dir_list:
basename = os.path.basename(dir)
class_name_list.append(basename)
print("class_name_list:{}".format(class_name_list))
# dataset作成
os.makedirs(dataset_path, exist_ok=True)
for stage in stage_list:
os.makedirs("{}/{}".format(dataset_path, stage), exist_ok=True)
for class_name in class_name_list:
if class_name.endswith("_000"):
kind = "000" # レギュラー缶
else:
kind = "111" # ロング缶
cls_name = class_name.replace("_000", "").replace("_001", "")
for stage in stage_list:
os.makedirs("{}/{}/{}".format(dataset_path, stage, cls_name), exist_ok=True)
file_list = glob.glob(
"{}/{}/*_w.png".format(input_path, class_name)
) # 白バックのPNGのみ対象とする
file_list.sort()
file_len = len(file_list)
train_len = int(file_len * 0.8)
val_len = file_len - train_len
print("{} {} train:{} val:{}".format(class_name, file_len, train_len, val_len))
val_counter = 0
for i, file in enumerate(file_list):
if i > max:
break
stage = stage_list[i % 2]
if i % 2 == 1:
val_counter += 1
if val_len < val_counter:
stage = stage_list[0] # trainに変更
basename = os.path.basename(file)
input_file = "{}/{}/{}".format(input_path, class_name, basename)
# 出力時は、通常缶とロング缶を同じクラスにまとめる
# output_file = "{}/{}/{}/{}".format(dataset_path, stage, class_name, basename)
output_file = "{}/{}/{}/{}_{}".format(
dataset_path, stage, cls_name, kind, basename
)
shutil.copyfile(input_file, output_file)
(4) データ増幅
データは、下記の手順で約10倍(各クラス 3720枚 train:2976 val:744)に増幅しています。
変換に使用したパラメータです。
# 変換リスト
convertList = [
{"function": saturation, "param": 0.2}, # 彩度
{"function": saturation, "param": 3.0},
{"function": brightness, "param": 0.6}, # 明度
{"function": contrast, "param": 0.8}, # コントラスト
{"function": mosaic, "param": 0.05}, # モザイク
{"function": mosaic, "param": 0.1},
{"function": mosaic, "param": 0.2},
{"function": gaussian, "param": 20.0}, # ガウスノイズ
{"function": noise, "param": 100}, # ごま塩ノイズ
]
この作業で、学習用、検証用共に、1枚のpng画像が、10倍に増幅されることになります。
amplify.py
import os
import glob
import cv2
import numpy as np
from matplotlib import pyplot as plt
# 彩度
def saturation(src, saturation):
# 一旦、BGRをHSVに変換して彩度を変換する
img = cv2.cvtColor(src, cv2.COLOR_BGR2HSV)
img[:, :, (1)] = img[:, :, (1)] * saturation
img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
return img
# 明度
def brightness(src, brightness):
# 一旦、BGRをHSVに変換して明度を変換する
img = cv2.cvtColor(src, cv2.COLOR_BGR2HSV)
img[:, :, (2)] = img[:, :, (2)] * brightness
img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
return img
# コントラスト
def contrast(src, alpha):
# 各色相をalphaで演算する
img = alpha * src
return np.clip(img, 0, 255).astype(np.uint8)
# モザイク
def mosaic(src, ratio):
# ratio倍でリサイズして戻す
img = cv2.resize(src, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST)
return cv2.resize(img, src.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
# ガウスノイズ
def gaussian(src, sigma):
# ランダム値で生成した画像と合成する
row, col, ch = src.shape
mean = 0
gauss = np.random.normal(mean, sigma, (row, col, ch)).astype("u1")
gauss = gauss.reshape(row, col, ch)
return src + gauss
# ごま塩ノイズ
def noise(src, sigma):
img = src.copy()
# 画素数xRGBのノイズを生成
noise = np.random.normal(0, sigma, img.shape)
# ノイズを付加して8bitの範囲にクリップ
noisy_img = img.astype(np.float64) + noise
return np.clip(noisy_img, 0, 255).astype(np.uint8)
def get_image_list(dataset_path, class_name, stage):
image_list = []
files = glob.glob("{}/{}/{}/*".format(dataset_path, stage, class_name))
for file in files:
image_list.append(file)
return image_list
def main():
dataset_path = "./dataset"
# 変換リスト
convertList = [
{"function": saturation, "param": 0.2}, # 彩度
{"function": saturation, "param": 3.0},
{"function": brightness, "param": 0.6}, # 明度
{"function": contrast, "param": 0.8}, # コントラスト
{"function": mosaic, "param": 0.05}, # モザイク
{"function": mosaic, "param": 0.1},
{"function": mosaic, "param": 0.2},
{"function": gaussian, "param": 20.0}, # ガウスノイズ
{"function": noise, "param": 100}, # ごま塩ノイズ
]
# クラス一覧
class_name_list = []
dir_list = glob.glob("{}/train/*".format(dataset_path))
for dir in dir_list:
basename = os.path.basename(dir)
class_name_list.append(basename)
# class_name_list = ["kirin_tanrei"]
print(class_name_list)
stage_list = ["train", "val"]
for class_name in class_name_list:
for stage in stage_list:
image_list = get_image_list(dataset_path, class_name, stage)
print("images: {}件 {} {}".format(len(image_list), class_name, stage))
for image in image_list:
basename = os.path.basename(image)
dirname = os.path.dirname(image)
filename = os.path.splitext(os.path.basename(image))[0]
ext = os.path.splitext(os.path.basename(image))[1]
ext = ext.replace(".", "")
print(
"basename:{} dirname:{} basename_without_ext:{} ext:{}".format(
basename, dirname, filename, ext
)
)
# 変換リストに基づく増幅処理
for i, convert in enumerate(convertList):
# 出力ファイル名
output_image = "{}/{}_{}.{}".format(dirname, filename, i, ext)
print("output_image: {}".format(output_image))
img = cv2.imread(image)
img = convert["function"](img, convert["param"])
# 変換後の画像のエクスポート
cv2.imwrite(output_image, img)
# print("増幅後データ: {}件 ".format(len(outputManifest.split("\n"))))
for class_name in class_name_list:
for stage in stage_list:
image_list = get_image_list(dataset_path, class_name, stage)
print("images: {}件 {} {}".format(len(image_list), class_name, stage))
main()
(5) ファインチューニング
YOLOv8では、下記のように簡単にファインチューニングが出来ます。
model = YOLO("yolov8n-cls.pt") # load a pretrained model
results = model.train(data="./dataset", epochs=10, imgsz=640)
下記は、モデルの性能を確認するため、テスト的に推論してみた結果ですが、非常に高い正解率となっています。
🌝 0.89 asahi_clear_001 asahi_clear
🌝 0.88 asahi_clear_002 asahi_clear
🌝 1.00 asahi_off_001 asahi_off
🌝 1.00 asahi_off_002 asahi_off
🌝 1.00 asahi_style_free_001 asahi_style_free
🌝 1.00 asahi_style_free_002 asahi_style_free
❌ 0.76 asahi_the_rich_001 sappoeo_mugi_to_hop_black
🌝 0.97 asahi_the_rich_002 asahi_the_rich
🌝 0.69 asahi_zeitaku_zero_001 asahi_zeitaku_zero
🌝 0.84 asahi_zeitaku_zero_002 asahi_zeitaku_zero
🌝 0.85 kirin_hon_kirin_001 kirin_hon_kirin
🌝 0.76 kirin_hon_kirin_002 kirin_hon_kirin
🌝 0.64 kirin_nodogoshi_nama_001 kirin_nodogoshi_nama
🌝 0.80 kirin_nodogoshi_nama_002 kirin_nodogoshi_nama
🌝 0.99 kirin_nodogoshi_zero_001 kirin_nodogoshi_zero
🌝 0.86 kirin_nodogoshi_zero_002 kirin_nodogoshi_zero
🌝 0.98 kirin_tanrei_001 kirin_tanrei
🌝 0.96 kirin_tanrei_002 kirin_tanrei
🌝 0.98 sappoeo_mugi_to_hop_extra_rich_001 sappoeo_mugi_to_hop_extra_rich
🌝 0.95 sappoeo_mugi_to_hop_extra_rich_002 sappoeo_mugi_to_hop_extra_rich
🌝 0.98 sapporo_draft_one_001 sapporo_draft_one
🌝 1.00 sapporo_draft_one_002 sapporo_draft_one
🌝 0.71 sapporo_goku_zero_001 sapporo_goku_zero
🌝 0.83 sapporo_goku_zero_002 sapporo_goku_zero
🌝 1.00 sapporo_gold_star_red_001 sapporo_gold_star_red
🌝 0.96 sapporo_gold_star_red_002 sapporo_gold_star_red
🌝 0.87 sapporo_nama_sibori_001 sapporo_nama_sibori
🌝 0.69 sapporo_nama_sibori_002 sapporo_nama_sibori
🌝 1.00 suntory_kinmugi_blue_001 suntory_kinmugi_blue
🌝 1.00 suntory_kinmugi_blue_002 suntory_kinmugi_blue
🌝 1.00 suntory_kinmugi_gold_001 suntory_kinmugi_gold
🌝 0.99 suntory_kinmugi_gold_002 suntory_kinmugi_gold
これは、データセットの画像と、推論対象の画像が、非常に似ているからだと考えています。
下図は一例ですが、上段が、セグメンテーションで切り取られた推論対象の画像、そして、下段が、データセットの画像です。どちらも、概ね横からみた画像で背景が白となっているため、類似性はかなり高いと思います。
5 検出及び判定
SAMによるセグメンテーション、サイズによるフィルタ、そして、分類モデルによる銘柄判定の一連のコードは、下記となります。
beer_detection.py
import torch
import torchvision
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from ultralytics import YOLO
import numpy as np
import cv2
import shutil
import os
import glob
print("PyTorch version: {}".format(torch.__version__))
print("Torchvision version: {}".format(torchvision.__version__))
print("CUDA is available: {}".format(torch.cuda.is_available()))
def create_json(basename, all_contours, all_indexes):
json_file_name = "{}.json".format(basename)
print("create json :{}".format(json_file_name))
str = '{\n\t"data": [\n'
for n, contour in enumerate(all_contours):
str += "\t\t{\n"
str += '\t\t\t"contour": ['
for i, c in enumerate(contour):
if i != 0:
str += ","
str += "[{},{}]".format(c[0][0], c[0][1])
str += "],\n"
str += '\t\t\t"classIndex": {}\n'.format(all_indexes[n])
str += "\t\t}"
if n != len(all_contours) - 1:
str += ","
str += "\n"
str += "\t]\n}\n"
with open(json_file_name, "w") as f:
f.write(str)
def main():
train = "train5"
model_path = "runs/classify/{}/weights/best.pt".format(train)
print("YOLOv8 Image Classification initialize. mode_path:{}".format(model_path))
model = YOLO(model_path)
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda"
print(
"SAM initialize. checkpoint:{} model_type:{}".format(sam_checkpoint, model_type)
)
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator_ = SamAutomaticMaskGenerator(
model=sam,
points_per_side=32,
pred_iou_thresh=0.9,
stability_score_thresh=0.96,
crop_n_layers=1,
crop_n_points_downscale_factor=2,
min_mask_region_area=100, # Requires open-cv to run post-processing
)
print("job start.")
files = glob.glob("./shop_images/*")
for file in files:
basename = os.path.splitext(os.path.basename(file))[0]
org_image = cv2.imread(file)
copy_image = org_image.copy()
image = cv2.cvtColor(org_image, cv2.COLOR_BGR2RGB)
print("file:{} basename:{} shape:{}".format(file, basename, image.shape))
# 検出されたマスクかを、サイズでフィルタするための値
height, width, _ = image.shape
minH = int(height * 0.14)
maxH = int(height * 0.29)
minW = int(width * 0.04)
maxW = int(width * 0.09)
cv2.rectangle(copy_image, (0, 0), (minW, minH), (255, 255, 0), 2)
cv2.rectangle(copy_image, (0, 0), (maxW, maxH), (255, 255, 0), 2)
# ブログ用(切り取った画像を保存する)
beer_images = "./beer_images"
if os.path.isdir(beer_images):
shutil.rmtree(beer_images)
os.makedirs(beer_images, exist_ok=True)
print("start generate")
masks = mask_generator_.generate(image)
print("masks: {} ".format(len(masks)))
np.random.seed(seed=32)
all_contours = [] # 1画像に含まれるすべての缶ビール分
all_indexes = []
for i, mask_data in enumerate(masks):
mask = mask_data["segmentation"]
mask = mask.astype(np.uint8)
# 2値画像(白及び黒)を生成する
# height, width, _ = image.shape
tmp_black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8)
tmp_white_image = np.full(np.array([height, width, 1]), 255, dtype=np.uint8)
# マスクによって黒画像の上に白を描画する
tmp_black_image[:] = np.where(
mask[:height, :width, np.newaxis] == True,
tmp_white_image,
tmp_black_image,
)
# 輪郭の取得
contours, _ = cv2.findContours(
tmp_black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
contours2 = [] # 缶ビール1個分(フィルタ後)
for contour in contours:
# 外接の矩形を求める
x, y, w, h = cv2.boundingRect(contour)
# 外接矩形
if ( # 一定のサイズのみ検出対象にする
minH <= h and h <= maxH and minW <= w and w <= maxW
):
contours2.append(contour)
all_contours.append(contour)
# boxの範囲でマスクを切り取る
part_of_mask = mask[y : y + h, x : x + w]
# boxの範囲で元画像を切り取る
tmp_image = org_image.copy() # 個々の食品を切取るためのテンポラリ画像
white_back_image = tmp_image[y : y + h, x : x + w]
# boxの範囲で白一色の2値画像を作成する
white_image = np.full(np.array([h, w, 3]), 255, dtype=np.uint8)
# マスクによって白画像の上に元画像を描画する
white_back_image[:] = np.where(
part_of_mask[:h, :w, np.newaxis] == False,
white_image,
white_back_image,
)
# ブログ用(切り取った画像を保存する)
cv2.imwrite(
"./{}/{}_{}.png".format(
beer_images, basename, len(all_contours)
),
white_back_image,
)
tmp_image = "./tmp.png"
cv2.imwrite(
tmp_image,
white_back_image,
)
# 各缶ビール画像を分類モデルで判定する
results = model(tmp_image)
names = results[0].names
probs = results[0].probs
i = probs.top5[0]
print(" {:.2f} {}".format(probs.data[i], names[i]))
all_indexes.append(i)
# 輪郭を描画
cv2.drawContours(copy_image, contours2, -1, (255, 0, 0), 2)
# 1画像で検出されたマスク数
print("contours: {}".format(len(all_contours)))
print("all_indexes: {}".format(all_indexes))
# 1画像分のデータファイルとして出力する
create_json(basename, all_contours, all_indexes)
cv2.imshow("Mask", copy_image)
cv2.waitKey(0) # 画像にフォーカスをあてて、キーを押すと、次に進む
main()
そして、ここで出力されるは、ブラウザで表示するためのJSONファイルです。 JSONファイルには、セグメンテーションの座標及び、銘柄を表すクラスインデックスが含まれています。
{
"data": [
{
"contour": [[387,187],[386,188],[378,188],[377,189],[367,189],
・・・略・・・
[426,189],[425,188],[419,188],[418,187]],
"classIndex": 17
},
{
"contour": [[315,186],[314,187],[313,187],[312,188],[311,187],
・・・略・・・
[350,187],[349,188],[348,187],[323,187],[322,186]],
"classIndex": 15
},
・・・略・・・
}
6 ブラウザ表示
上記で作成されたJSONを使用して、Reactで検出結果を表示しています。
検出対象の画像(shop001〜shop004.png)、検出結果のJSON(shop001〜shop004.json)などは、assetsに配置して使用しました。
sample-app % tree -L 2
.
├── node_modules
├── public
├── src
│ ├── App.css
│ ├── App.test.tsx
│ ├── App.tsx
│ ├── assets
│ │ └── data
│ │ ├── beer000.png
│ │ ├── beer001.png
・・・略・・・
│ │ ├── beer018.png
│ │ ├── beer019.png
│ │ ├── names.json
│ │ ├── shop001.json
│ │ ├── shop001.png
・・・略・・・
│ │ ├── shop004.json
│ │ └── shop004.png
│ ├── components
│ ├── index.css
│ ├── index.tsx
│ ├── react-app-env.d.ts
│ ├── reportWebVitals.ts
│ └── setupTests.ts
├── tsconfig.json
└── yarn.lock
メインとなるコンポーネントは、以下のようになっています。
Stage.tsx
import React, { useEffect, useState } from "react";
import targetImage001 from "../assets/data/shop001.png";
import targetData001 from "../assets/data/shop001.json";
import targetImage002 from "../assets/data/shop002.png";
import targetData002 from "../assets/data/shop002.json";
import targetImage003 from "../assets/data/shop003.png";
import targetData003 from "../assets/data/shop003.json";
import targetImage004 from "../assets/data/shop004.png";
import targetData004 from "../assets/data/shop004.json";
import beerImage999 from "../assets/data/beer999.png";
import beerImage000 from "../assets/data/beer000.png";
import beerImage001 from "../assets/data/beer001.png";
import beerImage002 from "../assets/data/beer002.png";
import beerImage003 from "../assets/data/beer003.png";
import beerImage004 from "../assets/data/beer004.png";
import beerImage005 from "../assets/data/beer005.png";
import beerImage006 from "../assets/data/beer006.png";
import beerImage007 from "../assets/data/beer007.png";
import beerImage008 from "../assets/data/beer008.png";
import beerImage009 from "../assets/data/beer009.png";
import beerImage010 from "../assets/data/beer010.png";
import beerImage011 from "../assets/data/beer011.png";
import beerImage012 from "../assets/data/beer012.png";
import beerImage013 from "../assets/data/beer013.png";
import beerImage014 from "../assets/data/beer014.png";
import beerImage015 from "../assets/data/beer015.png";
import beerImage016 from "../assets/data/beer016.png";
import beerImage017 from "../assets/data/beer017.png";
import beerImage018 from "../assets/data/beer018.png";
import beerImage019 from "../assets/data/beer019.png";
import classNames from "../assets/data/names.json";
import "./Stage.css";
import { Button } from "@material-ui/core";
export const Stage: React.FC = () => {
const [info, setInfo] = useState({
x: -1,
y: -1,
mouseIndex: -1,
classIndex: -1,
className: "",
});
const [targetIndex, setTargetIndex] = useState(0);
// shop001 shop002 shop003
const canvasWidth = 1300;
const canvasHeight = 800;
const canvasColor = "#000000";
const canvasId = "canvas";
//const targetIndex = 0;
const targetImageList = [
targetImage001,
targetImage002,
targetImage003,
targetImage004,
];
const targetDataList = [
targetData001,
targetData002,
targetData003,
targetData004,
];
const beerImageList = [
beerImage999,
beerImage000,
beerImage001,
beerImage002,
beerImage003,
beerImage004,
beerImage005,
beerImage006,
beerImage007,
beerImage008,
beerImage009,
beerImage010,
beerImage011,
beerImage012,
beerImage013,
beerImage014,
beerImage015,
beerImage016,
beerImage017,
beerImage018,
beerImage019,
];
const getContext = (): CanvasRenderingContext2D => {
return getCavasElement().getContext("2d")!;
};
const getCavasElement = (): HTMLCanvasElement => {
return document.getElementById(canvasId) as HTMLCanvasElement;
};
const dwawContour = (
ctx: CanvasRenderingContext2D,
contour: number[][],
lineWidth: number,
lineColor: string
) => {
ctx.lineWidth = lineWidth;
ctx.strokeStyle = lineColor;
ctx.beginPath();
for (var c = 0; c < contour.length; c++) {
const point = contour;
if (c === 0) {
ctx.moveTo(point[0], point[1]);
} else if (c === contour.length - 1) {
ctx.closePath();
} else {
ctx.lineTo(point[0], point[1]);
}
}
const point = contour[0];
ctx.moveTo(point[0], point[1]);
for (const point of contour) {
ctx.lineTo(point[0], point[1]);
}
ctx.stroke();
};
// 多角形の範囲内かどうか判定
const isRange = (contour: number[][], x: number, y: number) => {
let minX = 9999;
let maxX = 0;
let minY = 9999;
let maxY = 0;
contour.forEach((point) => {
if (maxX < point[0]) {
maxX = point[0];
}
if (point[0] < minX) {
minX = point[0];
}
if (maxY < point[1]) {
maxY = point[1];
}
if (point[1] < minY) {
minY = point[1];
}
});
if (minX <= x && x <= maxX && minY <= y && y <= maxY) {
return true;
}
return false;
};
const getIndex = (targetData: any, mouseX: number, mouseY: number) => {
let index = -1;
targetData.data.forEach((data: any, i: number) => {
const contour: number[][] = data["contour"] as number[][];
if (isRange(contour, mouseX, mouseY)) {
index = i;
}
});
return index;
};
useEffect(() => {
const ctx = getContext();
// 初期描画
ctx!.fillStyle = canvasColor;
ctx!.fillRect(0, 0, canvasWidth, canvasHeight);
const img = new Image();
img.src = targetImageList[targetIndex];
img.onload = () => {
ctx!.drawImage(img, 0, 0); // 画像は拡大縮小しないで表示する
};
const handleWindowMouseMove = (mouseEvent: MouseEvent) => {
// 画像の位置
const canvasElementRect = getCavasElement().getBoundingClientRect();
// マウスの相対座標
const mouseX = mouseEvent.clientX - canvasElementRect.left;
const mouseY = mouseEvent.clientY - canvasElementRect.top;
// マウス位置が、どのマスクに一致しているかどうか
const mouseIndex = getIndex(targetDataList[targetIndex], mouseX, mouseY);
// マウス位置のオブジェクトのクラスインデックス
const classIndex =
mouseIndex !== -1
? (targetDataList[targetIndex].data[mouseIndex][
"classIndex"
] as number)
: -1;
const className = classIndex !== -1 ? classNames.names[classIndex] : "";
setInfo({
x: mouseX,
y: mouseY,
mouseIndex: mouseIndex,
classIndex: classIndex,
className: className,
});
console.log(
`x:${mouseX} y:${mouseY} mouseIndex:${mouseIndex} classIndex:${classIndex} ${className}`
);
// マウスの位置に基づく描画
const ctx = getContext();
const img = new Image();
img.src = targetImageList[targetIndex];
ctx!.drawImage(img, 0, 0); // 画像は拡大縮小しないで表示する
if (0 < mouseX && 0 < mouseY) {
targetDataList[targetIndex].data.forEach((data, i) => {
const contour: number[][] = data["contour"] as number[][];
let lineColor = "#0000ff";
let lineWidth = 3;
if (i === mouseIndex) {
lineColor = "#aaffff";
lineWidth = 5;
}
dwawContour(ctx, contour, lineWidth, lineColor);
});
}
};
window.addEventListener("mousemove", handleWindowMouseMove);
}, [targetIndex]); // 再描画は、dispLine()に移譲する
return (
<div>
<div className="selector">
{[...Array(4)].map((_, i) => {
return (
<Button
variant="outlined"
color="primary"
onClick={() => {
setTargetIndex(i);
}}
>
Image_{i}
</Button>
);
})}
</div>
<div className="title">
<img src={beerImageList[info.classIndex + 1]} />
<span className="title">{info.className}</span>
</div>
<canvas id={canvasId} width={canvasWidth} height={canvasHeight}></canvas>
<div></div>
<div className="info">
targetIndex={targetIndex} x={info.x} y={info.y} mouseIndex:
{info.mouseIndex} classIndex:
{info.classIndex}
</div>
</div>
);
};
7 最後に
今回は、スーバーマーケットで陳列されている商品(缶ビール)を検出してみました。SAMが登場するまでの手法だと、どうしても、データセット作成にはかなりの工数が必要になったと思います。恐らく、個人できるような作業では無かったのではないか考えています。
下記は、醤油の棚です。まだ、試しておりませんが、今回の手法を応用すれば、下記のような商品棚も、うまく検出できるのでは?と妄想しております。
写真を取らせていただいた店内は、生活協同組合コープさっぽろでした。こちらでは、入り口に 「店内撮影OK」 と掲示されており、お言葉に甘えて撮影させて頂きました。ありがとうございます。
データセットを作成するために買ってきた缶ビールは、今、押し入れに積み上がっています。この後、飲みます。\(^o^)/