機械学習でスーパーマーケットの缶ビールを検出してみました 〜Segment Anythingでセグメンテーションして、YOLOv8の分類モデルで銘柄を判定〜
1 はじめに
CX 事業本部 delivery部の平内(SIN)です。
Meta社による Segment Anything Model(以下、SAM)は、セグメンテーションのための汎用モデルで、ファインチューニングなしで、あらゆる物体がセグメンテーションできます。
今回は、こちらを使用して、スーパーマーケットで冷蔵庫の缶ビールを検出してみました。
最初に、動作確認している様子です。
缶ビールがセグメンテーションされていることと、それぞれの銘柄が判定できていることを確認できると思います。
2 学習済みモデルによる物体検出 (YOLOv8)
YOLOv8などの物体検出モデルでは、配布されている学習済みモデルで、ある程度の物体検出が可能です。
下記は、テーブルに置かれた、缶ビールが、chair、dining tableとともに、cupとして検出されている様子が確認できます。
しかし、このような学習済みモデルでは、店舗に陳列されている缶ビールを検出するようなことは、出来ませんでした。
3 フィルタによる抽出 (Segment Anything Model)
そこで、使用したのが、SAMです。
パラメータにもよりますが、SAMで検出してみると。なんとなく、缶ビールが検出できそうな予感がしてきます。
上記では、缶ビールの検出もできているようですが、缶上のラベルや、値札などもセグメンテーションされていて、ちょっと、欲しいものとは違う感じです。
そこで、目的としている缶ビール以外のマスク情報をフィルタするために、画像上での缶ビールのサイズを調べて、そのサイズに基づいてフィルタしてみました。
まず、缶ビールのサイズの幅及び高さの最大、最小値を決めます。
# 缶ビールは、画像サイズの何%を占めるか height, width, _ = image.shape minH = int(height * 0.14) maxH = int(height * 0.29) minW = int(width * 0.04) maxW = int(width * 0.09)
そして、SAMで検出したすべてのマスクの外接矩形を求め、サイズが範囲内に収まるものだけをフィルタします。
for contour in contours: # 外接の矩形を求める x, y, w, h = cv2.boundingRect(contour) if ( # 一定のサイズのみ検出対象にする minH <= h and h <= maxH and minW <= w and w <= maxW ): all_contours.append(contour)
フィルタしたマスクだけを表示したものです。フィルタ前のマスクは300個程度で、フィルタ後は30個程度となりました。
作業時の様子ですが、計算した、min/maxの矩形を画面の左端に表示して、缶ビールのサイズと比べながら、丁度よいフィルタのサイズを試行錯誤しました。
4 分類モデルによる銘柄判定
セグメンテーションされた缶ビールの銘柄を求めるためには、YOLOv8の分類モデルを使用しました。
分類モデルを作成した手順は、以下のとおりです。
- 動画撮影
- png画像の抽出
- データセット
- データ増幅
- ファインチューニング
(1) 動画撮影
缶ビールは、回転台に乗せて、ちょうど1周分の動画を撮影しました。
各銘柄ごと通常缶とロング缶の2種類を撮影しています。
(2) png画像の抽出
撮影した動画から、データセット用のpng画像を抽出するには、SAMを使用しました。
上記のブログで紹介しているように、動画の1フレーム目で、缶ビールの範囲を指定すると、3フレームごとに缶ビールを切り取り、背景が「透過」と「白」の2種類の画像を生成できます。
データセット用に、背景が白のもの、そして、ブラウザ表示用に透過背景のものを使用しました。
MOVファイルからPNGを切り出したコードです。
mov_2_png.py
import os import datetime import numpy as np import torch import cv2 import glob import matplotlib.pyplot as plt from matplotlib import patches from segment_anything import sam_model_registry, SamPredictor # マウスで範囲指定する class BoundingBox: def __init__(self, image): self.x1 = -1 self.x2 = -1 self.y1 = -1 self.y2 = -1 self.image = image.copy() plt.figure() plt.connect("motion_notify_event", self.motion) plt.connect("button_press_event", self.press) plt.connect("button_release_event", self.release) self.ln_v = plt.axvline(0) self.ln_h = plt.axhline(0) plt.imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)) plt.show() # 選択中のカーソル表示 def motion(self, event): if event.xdata is not None and event.ydata is not None: self.ln_v.set_xdata(event.xdata) self.ln_h.set_ydata(event.ydata) self.x2 = event.xdata.astype("int16") self.y2 = event.ydata.astype("int16") if self.x1 != -1 and self.x2 != -1 and self.y1 != -1 and self.y2 != -1: plt.clf() plt.imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)) ax = plt.gca() rect = patches.Rectangle( (self.x1, self.y1), self.x2 - self.x1, self.y2 - self.y1, angle=0.0, fill=False, edgecolor="#00FFFF", ) ax.add_patch(rect) plt.draw() # ドラッグ開始位置 def press(self, event): self.x1 = event.xdata.astype("int16") self.y1 = event.ydata.astype("int16") # ドラッグ終了位置、表示終了 def release(self, event): plt.clf() plt.close() def get_area(self): return int(self.x1), int(self.y1), int(self.x2), int(self.y2) # SAM class SegmentAnything: def __init__(self, device, model_type, sam_checkpoint): print("init Segment Anything") self.device = device sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) sam.to(self.device) self.predictor = SamPredictor(sam) @property def contours(self): return self._contours @property def transparent_image(self): return self._transparent_image @property def white_back_image(self): return self._white_back_image @property def box(self): return self._box # マスク取得 def predict(self, frame, input_box): image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) self.predictor.set_image(image) masks, _, _ = self.predictor.predict( point_coords=None, point_labels=None, box=input_box[None, :], multimask_output=False, ) # ノイズ除去 self._mask = self._remove_noise(frame, masks[0]) # 範囲取得 self._box = self._get_box() # 部分画像取得 self._white_back_image, self._transparent_image = self._get_extract_image(frame) # マスクの範囲取得 def _get_box(self): mask_indexes = np.where(self._mask) y_min = np.min(mask_indexes[0]) y_max = np.max(mask_indexes[0]) x_min = np.min(mask_indexes[1]) x_max = np.max(mask_indexes[1]) return np.array([x_min, y_min, x_max, y_max]) # ノイズ除去 def _remove_noise(self, image, mask): # 2値画像(白及び黒)を生成する height, width, _ = image.shape tmp_black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8) tmp_white_image = np.full(np.array([height, width, 1]), 255, dtype=np.uint8) # マスクによって黒画像の上に白を描画する tmp_black_image[:] = np.where( mask[:height, :width, np.newaxis] == True, tmp_white_image, tmp_black_image ) # 輪郭の取得 contours, _ = cv2.findContours( tmp_black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE ) # 最も面積が大きい輪郭を選択 max_contours = max(contours, key=lambda x: cv2.contourArea(x)) # 黒画面に一番大きい輪郭だけ塗りつぶして描画する black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8) black_image = cv2.drawContours( black_image, [max_contours], -1, color=255, thickness=-1 ) # 輪郭を保存 self._contours, _ = cv2.findContours( black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE ) # マスクを作り直す new_mask = np.full(np.array([height, width, 1]), False, dtype=np.bool_) new_mask[::] = np.where(black_image[:height, :width] == 0, False, True) new_mask = np.squeeze(new_mask) return new_mask # 部分イメージの取得 def _get_extract_image(self, image): # boxの範囲でマスクを切り取る part_of_mask = self._mask[ self._box[1] : self._box[3], self._box[0] : self._box[2] ] # boxの範囲で元画像を切り取る copy_image = image.copy() # 個々の食品を切取るためのテンポラリ画像 white_back_image = copy_image[ self._box[1] : self._box[3], self._box[0] : self._box[2] ] # boxの範囲で白一色の2値画像を作成する h = self._box[3] - self._box[1] w = self._box[2] - self._box[0] white_image = np.full(np.array([h, w, 1]), 255, dtype=np.uint8) # マスクによって白画像の上に元画像を描画する white_back_image[:] = np.where( part_of_mask[:h, :w, np.newaxis] == False, white_image, white_back_image ) transparent_image = cv2.cvtColor(white_back_image, cv2.COLOR_BGR2BGRA) transparent_image[np.logical_not(part_of_mask), 3] = 0 return white_back_image, transparent_image class Video: def __init__(self, filename): self.cap = cv2.VideoCapture(filename) if self.cap.isOpened() == False: print("Video open faild.") else: self._frame_max = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) # 次のフレーム取得 def next_frame(self): return self.cap.read() ## 総フレーム数 @property def frame_max(self): return self._frame_max def destroy(self): print("video destroy.") self.cap.release() cv2.destroyAllWindows() # 縦横それぞれ、0.15倍まで広げた、ボックスを取得する def get_next_input(box): x1 = box[0] y1 = box[1] x2 = box[2] y2 = box[3] w = x2 - x1 h = y2 - y1 x_margen = int(w * 0.15) y_margen = int(h * 0.15) return np.array([x1 - x_margen, y1 - y_margen, x2 + x_margen, y2 + y_margen]) def main(): print("PyTorch version:", torch.__version__) device = "cuda" if torch.cuda.is_available() else "cpu" print("Using {} device".format(device)) # Segment Anything sam = SegmentAnything(device, "vit_h", "sam_vit_h_4b8939.pth") output_path = "./png" files = glob.glob("mov/*.MOV") for filename in files: step = 3 start_frame = 0 # 684から乱れる # start_frame = 525 # 684から乱れる # filename = "mov/hon_kirin.MOV" basename = os.path.splitext(os.path.basename(filename))[0] os.makedirs("{}/{}".format(output_path, basename), exist_ok=True) video = Video(filename) try: print("start") for i in range(video.frame_max): ret, frame = video.next_frame() if ret == False: continue # 開始位置まで読み飛ばす if i < start_frame: continue # フレーム省略 if i % step != 0: continue # 最初のフレームで、バウンディングボックスを取得する if i == start_frame: bounding_box = BoundingBox(frame) x1, y1, x2, y2 = bounding_box.get_area() input_box = np.array([x1, y1, x2, y2]) print( "{} filename:{} shape:{} start_frame:{} input_box:{} frams:{}/{}".format( datetime.datetime.now(), filename, frame.shape, start_frame, input_box, i + 1, video.frame_max, ) ) # マスク生成 sam.predict(frame, input_box) # 輪郭描画 frame = cv2.drawContours( frame, sam.contours, -1, color=[255, 255, 0], thickness=6 ) # バウンディングボックス描画 frame = cv2.rectangle( frame, pt1=(input_box[0], input_box[1]), pt2=(input_box[2], input_box[3]), color=(255, 255, 255), thickness=2, lineType=cv2.LINE_4, ) # データ保存 cv2.imwrite( "{}/{}/{:09}_t.png".format(output_path, basename, i), sam.transparent_image, ) cv2.imwrite( "{}/{}/{:09}_w.png".format(output_path, basename, i), sam.white_back_image, ) # 表示 cv2.imshow("Extract", sam.white_back_image) cv2.waitKey(1) cv2.imshow("Video", cv2.resize(frame, None, fx=0.3, fy=0.3)) cv2.waitKey(1) # 次のFrameで、検出範囲よりひと回り大きい範囲をBOX指定する input_box = get_next_input(sam.box) except KeyboardInterrupt: video.destroy() if __name__ == "__main__": main()
(3) データセット
切り取った缶ビールのpng画像は、YOLOv8の分類モデルでファインチューニングするためのデータセットとして使用できるように、定められた階層に配置します。
回転台で1週回した動画は、概ね600フレームとなり、そこから3フレーム毎に画像を切り取るので、それぞれ約200枚となっています。しかし、動画は、全てピタリ1周分とはなっておらず、いくらか長短が出てしまっています。各クラスのデータ数に偏りが出ないようにと、ここでは、上限を最も短かった動画に合わせて186枚としました。
通常缶 186枚、及び、ロング缶 186枚を合計した、372枚が1つのクラス(銘柄)のデータセットで、それを8対2に分割して、学習用と検証用としました。
$ tree dataset -L 2 dataset ├── train │ ├── asahi_clear │ ├── asahi_off │ │ ├── 000_000000001_w.png │ │ ├── 000_000000004_w.png │ │ ├── 000_000000007_w.png ・・・略・・・ │ │ ├── 111_000000237_w.png │ │ ├── 111_000000243_w.png │ │ └── 111_000000249_w.png │ ├── asahi_style_free │ ├── asahi_the_rich │ ├── asahi_zeitaku_zero │ ├── kirin_hon_kirin │ ├── kirin_nodogoshi_nama │ ├── kirin_nodogoshi_zero │ ├── kirin_tanrei │ ├── sappoeo_mugi_to_hop_black │ ├── sappoeo_mugi_to_hop_extra_rich │ ├── sapporo_draft_one │ ├── sapporo_goku_zero │ ├── sapporo_gold_star │ ├── sapporo_gold_star_red │ ├── sapporo_nama_sibori │ ├── suntory_kinmugi_blue │ ├── suntory_kinmugi_gold │ ├── suntory_kinmugi_red │ └── suntpry_kinmugi_kohakunoaki └── val ├── asahi_clear ├── asahi_off ├── asahi_style_free ├── asahi_the_rich ├── asahi_zeitaku_zero ├── kirin_hon_kirin ├── kirin_nodogoshi_nama ├── kirin_nodogoshi_zero ├── kirin_tanrei ├── sappoeo_mugi_to_hop_black ├── sappoeo_mugi_to_hop_extra_rich ├── sapporo_draft_one ├── sapporo_goku_zero ├── sapporo_gold_star ├── sapporo_gold_star_red ├── sapporo_nama_sibori ├── suntory_kinmugi_blue ├── suntory_kinmugi_gold ├── suntory_kinmugi_red └── suntpry_kinmugi_kohakunoaki
ファイル名の数値部分は、フレーム数です。また、プレフィックスとして000を付与されているものが、通常缶、111が、ロング缶の画像です。
PNG画像をデータセットに配置したコードです。
png_2_dataset.py
import os import shutil import glob input_path = "./png" dataset_path = "./dataset" stage_list = ["train", "val"] max = 186 # あまりデータ数に差が出ないように、多くあるものは、一定数以上を削除する # pngディレクトリ内のディレクトリを列挙 class_name_list = [] dir_list = glob.glob("{}/*".format(input_path)) for dir in dir_list: basename = os.path.basename(dir) class_name_list.append(basename) print("class_name_list:{}".format(class_name_list)) # dataset作成 os.makedirs(dataset_path, exist_ok=True) for stage in stage_list: os.makedirs("{}/{}".format(dataset_path, stage), exist_ok=True) for class_name in class_name_list: if class_name.endswith("_000"): kind = "000" # レギュラー缶 else: kind = "111" # ロング缶 cls_name = class_name.replace("_000", "").replace("_001", "") for stage in stage_list: os.makedirs("{}/{}/{}".format(dataset_path, stage, cls_name), exist_ok=True) file_list = glob.glob( "{}/{}/*_w.png".format(input_path, class_name) ) # 白バックのPNGのみ対象とする file_list.sort() file_len = len(file_list) train_len = int(file_len * 0.8) val_len = file_len - train_len print("{} {} train:{} val:{}".format(class_name, file_len, train_len, val_len)) val_counter = 0 for i, file in enumerate(file_list): if i > max: break stage = stage_list[i % 2] if i % 2 == 1: val_counter += 1 if val_len < val_counter: stage = stage_list[0] # trainに変更 basename = os.path.basename(file) input_file = "{}/{}/{}".format(input_path, class_name, basename) # 出力時は、通常缶とロング缶を同じクラスにまとめる # output_file = "{}/{}/{}/{}".format(dataset_path, stage, class_name, basename) output_file = "{}/{}/{}/{}_{}".format( dataset_path, stage, cls_name, kind, basename ) shutil.copyfile(input_file, output_file)
(4) データ増幅
データは、下記の手順で約10倍(各クラス 3720枚 train:2976 val:744)に増幅しています。
変換に使用したパラメータです。
# 変換リスト convertList = [ {"function": saturation, "param": 0.2}, # 彩度 {"function": saturation, "param": 3.0}, {"function": brightness, "param": 0.6}, # 明度 {"function": contrast, "param": 0.8}, # コントラスト {"function": mosaic, "param": 0.05}, # モザイク {"function": mosaic, "param": 0.1}, {"function": mosaic, "param": 0.2}, {"function": gaussian, "param": 20.0}, # ガウスノイズ {"function": noise, "param": 100}, # ごま塩ノイズ ]
この作業で、学習用、検証用共に、1枚のpng画像が、10倍に増幅されることになります。
amplify.py
import os import glob import cv2 import numpy as np from matplotlib import pyplot as plt # 彩度 def saturation(src, saturation): # 一旦、BGRをHSVに変換して彩度を変換する img = cv2.cvtColor(src, cv2.COLOR_BGR2HSV) img[:, :, (1)] = img[:, :, (1)] * saturation img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR) return img # 明度 def brightness(src, brightness): # 一旦、BGRをHSVに変換して明度を変換する img = cv2.cvtColor(src, cv2.COLOR_BGR2HSV) img[:, :, (2)] = img[:, :, (2)] * brightness img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR) return img # コントラスト def contrast(src, alpha): # 各色相をalphaで演算する img = alpha * src return np.clip(img, 0, 255).astype(np.uint8) # モザイク def mosaic(src, ratio): # ratio倍でリサイズして戻す img = cv2.resize(src, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST) return cv2.resize(img, src.shape[:2][::-1], interpolation=cv2.INTER_NEAREST) # ガウスノイズ def gaussian(src, sigma): # ランダム値で生成した画像と合成する row, col, ch = src.shape mean = 0 gauss = np.random.normal(mean, sigma, (row, col, ch)).astype("u1") gauss = gauss.reshape(row, col, ch) return src + gauss # ごま塩ノイズ def noise(src, sigma): img = src.copy() # 画素数xRGBのノイズを生成 noise = np.random.normal(0, sigma, img.shape) # ノイズを付加して8bitの範囲にクリップ noisy_img = img.astype(np.float64) + noise return np.clip(noisy_img, 0, 255).astype(np.uint8) def get_image_list(dataset_path, class_name, stage): image_list = [] files = glob.glob("{}/{}/{}/*".format(dataset_path, stage, class_name)) for file in files: image_list.append(file) return image_list def main(): dataset_path = "./dataset" # 変換リスト convertList = [ {"function": saturation, "param": 0.2}, # 彩度 {"function": saturation, "param": 3.0}, {"function": brightness, "param": 0.6}, # 明度 {"function": contrast, "param": 0.8}, # コントラスト {"function": mosaic, "param": 0.05}, # モザイク {"function": mosaic, "param": 0.1}, {"function": mosaic, "param": 0.2}, {"function": gaussian, "param": 20.0}, # ガウスノイズ {"function": noise, "param": 100}, # ごま塩ノイズ ] # クラス一覧 class_name_list = [] dir_list = glob.glob("{}/train/*".format(dataset_path)) for dir in dir_list: basename = os.path.basename(dir) class_name_list.append(basename) # class_name_list = ["kirin_tanrei"] print(class_name_list) stage_list = ["train", "val"] for class_name in class_name_list: for stage in stage_list: image_list = get_image_list(dataset_path, class_name, stage) print("images: {}件 {} {}".format(len(image_list), class_name, stage)) for image in image_list: basename = os.path.basename(image) dirname = os.path.dirname(image) filename = os.path.splitext(os.path.basename(image))[0] ext = os.path.splitext(os.path.basename(image))[1] ext = ext.replace(".", "") print( "basename:{} dirname:{} basename_without_ext:{} ext:{}".format( basename, dirname, filename, ext ) ) # 変換リストに基づく増幅処理 for i, convert in enumerate(convertList): # 出力ファイル名 output_image = "{}/{}_{}.{}".format(dirname, filename, i, ext) print("output_image: {}".format(output_image)) img = cv2.imread(image) img = convert["function"](img, convert["param"]) # 変換後の画像のエクスポート cv2.imwrite(output_image, img) # print("増幅後データ: {}件 ".format(len(outputManifest.split("\n")))) for class_name in class_name_list: for stage in stage_list: image_list = get_image_list(dataset_path, class_name, stage) print("images: {}件 {} {}".format(len(image_list), class_name, stage)) main()
(5) ファインチューニング
YOLOv8では、下記のように簡単にファインチューニングが出来ます。
model = YOLO("yolov8n-cls.pt") # load a pretrained model results = model.train(data="./dataset", epochs=10, imgsz=640)
下記は、モデルの性能を確認するため、テスト的に推論してみた結果ですが、非常に高い正解率となっています。
🌝 0.89 asahi_clear_001 asahi_clear 🌝 0.88 asahi_clear_002 asahi_clear 🌝 1.00 asahi_off_001 asahi_off 🌝 1.00 asahi_off_002 asahi_off 🌝 1.00 asahi_style_free_001 asahi_style_free 🌝 1.00 asahi_style_free_002 asahi_style_free ❌ 0.76 asahi_the_rich_001 sappoeo_mugi_to_hop_black 🌝 0.97 asahi_the_rich_002 asahi_the_rich 🌝 0.69 asahi_zeitaku_zero_001 asahi_zeitaku_zero 🌝 0.84 asahi_zeitaku_zero_002 asahi_zeitaku_zero 🌝 0.85 kirin_hon_kirin_001 kirin_hon_kirin 🌝 0.76 kirin_hon_kirin_002 kirin_hon_kirin 🌝 0.64 kirin_nodogoshi_nama_001 kirin_nodogoshi_nama 🌝 0.80 kirin_nodogoshi_nama_002 kirin_nodogoshi_nama 🌝 0.99 kirin_nodogoshi_zero_001 kirin_nodogoshi_zero 🌝 0.86 kirin_nodogoshi_zero_002 kirin_nodogoshi_zero 🌝 0.98 kirin_tanrei_001 kirin_tanrei 🌝 0.96 kirin_tanrei_002 kirin_tanrei 🌝 0.98 sappoeo_mugi_to_hop_extra_rich_001 sappoeo_mugi_to_hop_extra_rich 🌝 0.95 sappoeo_mugi_to_hop_extra_rich_002 sappoeo_mugi_to_hop_extra_rich 🌝 0.98 sapporo_draft_one_001 sapporo_draft_one 🌝 1.00 sapporo_draft_one_002 sapporo_draft_one 🌝 0.71 sapporo_goku_zero_001 sapporo_goku_zero 🌝 0.83 sapporo_goku_zero_002 sapporo_goku_zero 🌝 1.00 sapporo_gold_star_red_001 sapporo_gold_star_red 🌝 0.96 sapporo_gold_star_red_002 sapporo_gold_star_red 🌝 0.87 sapporo_nama_sibori_001 sapporo_nama_sibori 🌝 0.69 sapporo_nama_sibori_002 sapporo_nama_sibori 🌝 1.00 suntory_kinmugi_blue_001 suntory_kinmugi_blue 🌝 1.00 suntory_kinmugi_blue_002 suntory_kinmugi_blue 🌝 1.00 suntory_kinmugi_gold_001 suntory_kinmugi_gold 🌝 0.99 suntory_kinmugi_gold_002 suntory_kinmugi_gold
これは、データセットの画像と、推論対象の画像が、非常に似ているからだと考えています。
下図は一例ですが、上段が、セグメンテーションで切り取られた推論対象の画像、そして、下段が、データセットの画像です。どちらも、概ね横からみた画像で背景が白となっているため、類似性はかなり高いと思います。
5 検出及び判定
SAMによるセグメンテーション、サイズによるフィルタ、そして、分類モデルによる銘柄判定の一連のコードは、下記となります。
beer_detection.py
import torch import torchvision import matplotlib.pyplot as plt from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor from ultralytics import YOLO import numpy as np import cv2 import shutil import os import glob print("PyTorch version: {}".format(torch.__version__)) print("Torchvision version: {}".format(torchvision.__version__)) print("CUDA is available: {}".format(torch.cuda.is_available())) def create_json(basename, all_contours, all_indexes): json_file_name = "{}.json".format(basename) print("create json :{}".format(json_file_name)) str = '{\n\t"data": [\n' for n, contour in enumerate(all_contours): str += "\t\t{\n" str += '\t\t\t"contour": [' for i, c in enumerate(contour): if i != 0: str += "," str += "[{},{}]".format(c[0][0], c[0][1]) str += "],\n" str += '\t\t\t"classIndex": {}\n'.format(all_indexes[n]) str += "\t\t}" if n != len(all_contours) - 1: str += "," str += "\n" str += "\t]\n}\n" with open(json_file_name, "w") as f: f.write(str) def main(): train = "train5" model_path = "runs/classify/{}/weights/best.pt".format(train) print("YOLOv8 Image Classification initialize. mode_path:{}".format(model_path)) model = YOLO(model_path) sam_checkpoint = "sam_vit_h_4b8939.pth" model_type = "vit_h" device = "cuda" print( "SAM initialize. checkpoint:{} model_type:{}".format(sam_checkpoint, model_type) ) sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) sam.to(device=device) mask_generator_ = SamAutomaticMaskGenerator( model=sam, points_per_side=32, pred_iou_thresh=0.9, stability_score_thresh=0.96, crop_n_layers=1, crop_n_points_downscale_factor=2, min_mask_region_area=100, # Requires open-cv to run post-processing ) print("job start.") files = glob.glob("./shop_images/*") for file in files: basename = os.path.splitext(os.path.basename(file))[0] org_image = cv2.imread(file) copy_image = org_image.copy() image = cv2.cvtColor(org_image, cv2.COLOR_BGR2RGB) print("file:{} basename:{} shape:{}".format(file, basename, image.shape)) # 検出されたマスクかを、サイズでフィルタするための値 height, width, _ = image.shape minH = int(height * 0.14) maxH = int(height * 0.29) minW = int(width * 0.04) maxW = int(width * 0.09) cv2.rectangle(copy_image, (0, 0), (minW, minH), (255, 255, 0), 2) cv2.rectangle(copy_image, (0, 0), (maxW, maxH), (255, 255, 0), 2) # ブログ用(切り取った画像を保存する) beer_images = "./beer_images" if os.path.isdir(beer_images): shutil.rmtree(beer_images) os.makedirs(beer_images, exist_ok=True) print("start generate") masks = mask_generator_.generate(image) print("masks: {} ".format(len(masks))) np.random.seed(seed=32) all_contours = [] # 1画像に含まれるすべての缶ビール分 all_indexes = [] for i, mask_data in enumerate(masks): mask = mask_data["segmentation"] mask = mask.astype(np.uint8) # 2値画像(白及び黒)を生成する # height, width, _ = image.shape tmp_black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8) tmp_white_image = np.full(np.array([height, width, 1]), 255, dtype=np.uint8) # マスクによって黒画像の上に白を描画する tmp_black_image[:] = np.where( mask[:height, :width, np.newaxis] == True, tmp_white_image, tmp_black_image, ) # 輪郭の取得 contours, _ = cv2.findContours( tmp_black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE ) contours2 = [] # 缶ビール1個分(フィルタ後) for contour in contours: # 外接の矩形を求める x, y, w, h = cv2.boundingRect(contour) # 外接矩形 if ( # 一定のサイズのみ検出対象にする minH <= h and h <= maxH and minW <= w and w <= maxW ): contours2.append(contour) all_contours.append(contour) # boxの範囲でマスクを切り取る part_of_mask = mask[y : y + h, x : x + w] # boxの範囲で元画像を切り取る tmp_image = org_image.copy() # 個々の食品を切取るためのテンポラリ画像 white_back_image = tmp_image[y : y + h, x : x + w] # boxの範囲で白一色の2値画像を作成する white_image = np.full(np.array([h, w, 3]), 255, dtype=np.uint8) # マスクによって白画像の上に元画像を描画する white_back_image[:] = np.where( part_of_mask[:h, :w, np.newaxis] == False, white_image, white_back_image, ) # ブログ用(切り取った画像を保存する) cv2.imwrite( "./{}/{}_{}.png".format( beer_images, basename, len(all_contours) ), white_back_image, ) tmp_image = "./tmp.png" cv2.imwrite( tmp_image, white_back_image, ) # 各缶ビール画像を分類モデルで判定する results = model(tmp_image) names = results[0].names probs = results[0].probs i = probs.top5[0] print(" {:.2f} {}".format(probs.data[i], names[i])) all_indexes.append(i) # 輪郭を描画 cv2.drawContours(copy_image, contours2, -1, (255, 0, 0), 2) # 1画像で検出されたマスク数 print("contours: {}".format(len(all_contours))) print("all_indexes: {}".format(all_indexes)) # 1画像分のデータファイルとして出力する create_json(basename, all_contours, all_indexes) cv2.imshow("Mask", copy_image) cv2.waitKey(0) # 画像にフォーカスをあてて、キーを押すと、次に進む main()
そして、ここで出力されるは、ブラウザで表示するためのJSONファイルです。 JSONファイルには、セグメンテーションの座標及び、銘柄を表すクラスインデックスが含まれています。
{ "data": [ { "contour": [[387,187],[386,188],[378,188],[377,189],[367,189], ・・・略・・・ [426,189],[425,188],[419,188],[418,187]], "classIndex": 17 }, { "contour": [[315,186],[314,187],[313,187],[312,188],[311,187], ・・・略・・・ [350,187],[349,188],[348,187],[323,187],[322,186]], "classIndex": 15 }, ・・・略・・・ }
6 ブラウザ表示
上記で作成されたJSONを使用して、Reactで検出結果を表示しています。
検出対象の画像(shop001〜shop004.png)、検出結果のJSON(shop001〜shop004.json)などは、assetsに配置して使用しました。
sample-app % tree -L 2 . ├── node_modules ├── public ├── src │ ├── App.css │ ├── App.test.tsx │ ├── App.tsx │ ├── assets │ │ └── data │ │ ├── beer000.png │ │ ├── beer001.png ・・・略・・・ │ │ ├── beer018.png │ │ ├── beer019.png │ │ ├── names.json │ │ ├── shop001.json │ │ ├── shop001.png ・・・略・・・ │ │ ├── shop004.json │ │ └── shop004.png │ ├── components │ ├── index.css │ ├── index.tsx │ ├── react-app-env.d.ts │ ├── reportWebVitals.ts │ └── setupTests.ts ├── tsconfig.json └── yarn.lock
メインとなるコンポーネントは、以下のようになっています。
Stage.tsx
import React, { useEffect, useState } from "react"; import targetImage001 from "../assets/data/shop001.png"; import targetData001 from "../assets/data/shop001.json"; import targetImage002 from "../assets/data/shop002.png"; import targetData002 from "../assets/data/shop002.json"; import targetImage003 from "../assets/data/shop003.png"; import targetData003 from "../assets/data/shop003.json"; import targetImage004 from "../assets/data/shop004.png"; import targetData004 from "../assets/data/shop004.json"; import beerImage999 from "../assets/data/beer999.png"; import beerImage000 from "../assets/data/beer000.png"; import beerImage001 from "../assets/data/beer001.png"; import beerImage002 from "../assets/data/beer002.png"; import beerImage003 from "../assets/data/beer003.png"; import beerImage004 from "../assets/data/beer004.png"; import beerImage005 from "../assets/data/beer005.png"; import beerImage006 from "../assets/data/beer006.png"; import beerImage007 from "../assets/data/beer007.png"; import beerImage008 from "../assets/data/beer008.png"; import beerImage009 from "../assets/data/beer009.png"; import beerImage010 from "../assets/data/beer010.png"; import beerImage011 from "../assets/data/beer011.png"; import beerImage012 from "../assets/data/beer012.png"; import beerImage013 from "../assets/data/beer013.png"; import beerImage014 from "../assets/data/beer014.png"; import beerImage015 from "../assets/data/beer015.png"; import beerImage016 from "../assets/data/beer016.png"; import beerImage017 from "../assets/data/beer017.png"; import beerImage018 from "../assets/data/beer018.png"; import beerImage019 from "../assets/data/beer019.png"; import classNames from "../assets/data/names.json"; import "./Stage.css"; import { Button } from "@material-ui/core"; export const Stage: React.FC = () => { const [info, setInfo] = useState({ x: -1, y: -1, mouseIndex: -1, classIndex: -1, className: "", }); const [targetIndex, setTargetIndex] = useState(0); // shop001 shop002 shop003 const canvasWidth = 1300; const canvasHeight = 800; const canvasColor = "#000000"; const canvasId = "canvas"; //const targetIndex = 0; const targetImageList = [ targetImage001, targetImage002, targetImage003, targetImage004, ]; const targetDataList = [ targetData001, targetData002, targetData003, targetData004, ]; const beerImageList = [ beerImage999, beerImage000, beerImage001, beerImage002, beerImage003, beerImage004, beerImage005, beerImage006, beerImage007, beerImage008, beerImage009, beerImage010, beerImage011, beerImage012, beerImage013, beerImage014, beerImage015, beerImage016, beerImage017, beerImage018, beerImage019, ]; const getContext = (): CanvasRenderingContext2D => { return getCavasElement().getContext("2d")!; }; const getCavasElement = (): HTMLCanvasElement => { return document.getElementById(canvasId) as HTMLCanvasElement; }; const dwawContour = ( ctx: CanvasRenderingContext2D, contour: number[][], lineWidth: number, lineColor: string ) => { ctx.lineWidth = lineWidth; ctx.strokeStyle = lineColor; ctx.beginPath(); for (var c = 0; c < contour.length; c++) { const point = contour; if (c === 0) { ctx.moveTo(point[0], point[1]); } else if (c === contour.length - 1) { ctx.closePath(); } else { ctx.lineTo(point[0], point[1]); } } const point = contour[0]; ctx.moveTo(point[0], point[1]); for (const point of contour) { ctx.lineTo(point[0], point[1]); } ctx.stroke(); }; // 多角形の範囲内かどうか判定 const isRange = (contour: number[][], x: number, y: number) => { let minX = 9999; let maxX = 0; let minY = 9999; let maxY = 0; contour.forEach((point) => { if (maxX < point[0]) { maxX = point[0]; } if (point[0] < minX) { minX = point[0]; } if (maxY < point[1]) { maxY = point[1]; } if (point[1] < minY) { minY = point[1]; } }); if (minX <= x && x <= maxX && minY <= y && y <= maxY) { return true; } return false; }; const getIndex = (targetData: any, mouseX: number, mouseY: number) => { let index = -1; targetData.data.forEach((data: any, i: number) => { const contour: number[][] = data["contour"] as number[][]; if (isRange(contour, mouseX, mouseY)) { index = i; } }); return index; }; useEffect(() => { const ctx = getContext(); // 初期描画 ctx!.fillStyle = canvasColor; ctx!.fillRect(0, 0, canvasWidth, canvasHeight); const img = new Image(); img.src = targetImageList[targetIndex]; img.onload = () => { ctx!.drawImage(img, 0, 0); // 画像は拡大縮小しないで表示する }; const handleWindowMouseMove = (mouseEvent: MouseEvent) => { // 画像の位置 const canvasElementRect = getCavasElement().getBoundingClientRect(); // マウスの相対座標 const mouseX = mouseEvent.clientX - canvasElementRect.left; const mouseY = mouseEvent.clientY - canvasElementRect.top; // マウス位置が、どのマスクに一致しているかどうか const mouseIndex = getIndex(targetDataList[targetIndex], mouseX, mouseY); // マウス位置のオブジェクトのクラスインデックス const classIndex = mouseIndex !== -1 ? (targetDataList[targetIndex].data[mouseIndex][ "classIndex" ] as number) : -1; const className = classIndex !== -1 ? classNames.names[classIndex] : ""; setInfo({ x: mouseX, y: mouseY, mouseIndex: mouseIndex, classIndex: classIndex, className: className, }); console.log( `x:${mouseX} y:${mouseY} mouseIndex:${mouseIndex} classIndex:${classIndex} ${className}` ); // マウスの位置に基づく描画 const ctx = getContext(); const img = new Image(); img.src = targetImageList[targetIndex]; ctx!.drawImage(img, 0, 0); // 画像は拡大縮小しないで表示する if (0 < mouseX && 0 < mouseY) { targetDataList[targetIndex].data.forEach((data, i) => { const contour: number[][] = data["contour"] as number[][]; let lineColor = "#0000ff"; let lineWidth = 3; if (i === mouseIndex) { lineColor = "#aaffff"; lineWidth = 5; } dwawContour(ctx, contour, lineWidth, lineColor); }); } }; window.addEventListener("mousemove", handleWindowMouseMove); }, [targetIndex]); // 再描画は、dispLine()に移譲する return ( <div> <div className="selector"> {[...Array(4)].map((_, i) => { return ( <Button variant="outlined" color="primary" onClick={() => { setTargetIndex(i); }} > Image_{i} </Button> ); })} </div> <div className="title"> <img src={beerImageList[info.classIndex + 1]} /> <span className="title">{info.className}</span> </div> <canvas id={canvasId} width={canvasWidth} height={canvasHeight}></canvas> <div></div> <div className="info"> targetIndex={targetIndex} x={info.x} y={info.y} mouseIndex: {info.mouseIndex} classIndex: {info.classIndex} </div> </div> ); };
7 最後に
今回は、スーバーマーケットで陳列されている商品(缶ビール)を検出してみました。SAMが登場するまでの手法だと、どうしても、データセット作成にはかなりの工数が必要になったと思います。恐らく、個人できるような作業では無かったのではないか考えています。
下記は、醤油の棚です。まだ、試しておりませんが、今回の手法を応用すれば、下記のような商品棚も、うまく検出できるのでは?と妄想しております。
写真を取らせていただいた店内は、生活協同組合コープさっぽろでした。こちらでは、入り口に 「店内撮影OK」 と掲示されており、お言葉に甘えて撮影させて頂きました。ありがとうございます。
データセットを作成するために買ってきた缶ビールは、今、押し入れに積み上がっています。この後、飲みます。\(^o^)/