[YOLOv8 Instance Segmentation] 「きのこの山」に潜伏する「たけのこの里」を機械学習で見つけてみました 〜データセットは、Segment Anything Modelで自動的に生成されています〜
1 はじめに
CX事業本部製造ビジネステクノロジー部の平内(SIN)です。
YOLOv8は、イメージ分類・物体検出・セグメンテーション・骨格検出などに対応していますが、今回は、セグメンテーションモデルをファインチューニングして、「きのこの山」と「たけのこの里」を検出してみました。
最初に、動作している様子をご確認下さい。比較的に精度高くセグメンテーション出来ていると思います。
2 データセット作成
セグメンテーションモデルを学習する為のデータは、下記のように、対象物の輪郭座標が必要であり、これを大量に作成するのは、結構な膨大な作業量になってしまいます。
そこで、この作業は、Segment Anything Model(SAM)を使用し、対象物を撮影した動画から自動的に作成してみました。 作成されたデータセットは、3,000枚の画像、27,000個のアノテーションとなりました。
データセット作成の工程は、概ね以下のとおりです。
- (1) 動画の撮影
- (2) オブジェクトの切出し
- (3) 背景と合成してデータセット作成
- (4) YOLOv8形式への変換
(1) 動画の撮影
撮影は、回転台の上に載せてスマフォで撮っています。
下に、紺色の布を敷いているのは、影が目立つと、この後、Segment Anything Modelでオブジェクトを切り抜く際に、影も対象物とご判断されてしまうことを避けるためです。
撮影された動画は、以下のような感じです。
(2) オブジェクトの切出し
オブジェクト切出しのプログラムでは、SAM(Segment Anything Model)を使用して動画の各フレームから対象オブジェクトを切出しています。
対象物の検出範囲は、動画の1フレーム目で、設定するようになっています。
最初のフレームで、対象物を検出できたら、その物体の一回り大きな枠を、次のフレームの検出範囲としているため対象物が、少々ずれても、追従させながら切り出すことができるようになっています。
参考
切り出されたデータは、以下のとおりとなります。
- xxxxxxxxxx_w.png 背景が白の切り取り画像
- xxxxxxxxxx_t.png 背景が透過の切り取り画像(今回、未使用)
- xxxxxxxxxx_l.txt 切り取り画像の左上を原点(0,0)とした輪郭の座標
dataset └── output_png └── KINOKO | ├── 000000000_l.txt | ├── 000000000_t.png | ├── 000000000_w.png ... | ├── 000000102_l.txt | ├── 000000102_t.png | └── 000000102_w.png └── TAKENOKO ├── 000000000_l.txt ├── 000000000_t.png ├── 000000000_w.png ... ├── 000000102_l.txt ├── 000000102_t.png └── 000000102_w.png
以下の動画は、オブジェクトの切出し作業をしているところです。
そして、オブジェクトの切出しのプログラムです。
mp4_2_png.py
# SAMによる mp4ファイルからのターゲット摘出 import os import datetime import numpy as np import torch import cv2 import matplotlib.pyplot as plt from matplotlib import patches from segment_anything import sam_model_registry, SamPredictor # マウスで範囲指定する class BoundingBox: def __init__(self, image): self.x1 = -1 self.x2 = -1 self.y1 = -1 self.y2 = -1 self.image = image.copy() plt.figure() plt.connect("motion_notify_event", self.motion) plt.connect("button_press_event", self.press) plt.connect("button_release_event", self.release) self.ln_v = plt.axvline(0) self.ln_h = plt.axhline(0) plt.imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)) plt.show() # 選択中のカーソル表示 def motion(self, event): if event.xdata is not None and event.ydata is not None: self.ln_v.set_xdata(event.xdata) self.ln_h.set_ydata(event.ydata) self.x2 = event.xdata.astype("int16") self.y2 = event.ydata.astype("int16") if self.x1 != -1 and self.x2 != -1 and self.y1 != -1 and self.y2 != -1: plt.clf() plt.imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)) ax = plt.gca() rect = patches.Rectangle( (self.x1, self.y1), self.x2 - self.x1, self.y2 - self.y1, angle=0.0, fill=False, edgecolor="#00FFFF", ) ax.add_patch(rect) plt.draw() # ドラッグ開始位置 def press(self, event): self.x1 = event.xdata.astype("int16") self.y1 = event.ydata.astype("int16") # ドラッグ終了位置、表示終了 def release(self, event): plt.clf() plt.close() def get_area(self): return int(self.x1), int(self.y1), int(self.x2), int(self.y2) # SAM class SegmentAnything: def __init__(self, device, model_type, sam_checkpoint): print("init Segment Anything") self.device = device sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) sam.to(self.device) self.predictor = SamPredictor(sam) @property def contours(self): return self._contours @property def transparent_image(self): return self._transparent_image @property def white_back_image(self): return self._white_back_image @property def box(self): return self._box @property def segment(self): return self._segment # マスク取得 def predict(self, frame, input_box): image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) self.predictor.set_image(image) masks, _, _ = self.predictor.predict( point_coords=None, point_labels=None, box=input_box[None, :], multimask_output=False, ) # ノイズ除去 self._mask = self._remove_noise(frame, masks[0]) # 範囲取得 self._box = self._get_box() # 部分画像取得 self._white_back_image, self._transparent_image = self._get_extract_image(frame) # セグメンテーション取得 self._segment = self._get_segment() # マスクの範囲取得 def _get_box(self): mask_indexes = np.where(self._mask) y_min = np.min(mask_indexes[0]) y_max = np.max(mask_indexes[0]) x_min = np.min(mask_indexes[1]) x_max = np.max(mask_indexes[1]) return np.array([x_min, y_min, x_max, y_max]) # ノイズ除去 def _remove_noise(self, image, mask): # 2値画像(白及び黒)を生成する height, width, _ = image.shape tmp_black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8) tmp_white_image = np.full(np.array([height, width, 1]), 255, dtype=np.uint8) # マスクによって黒画像の上に白を描画する tmp_black_image[:] = np.where( mask[:height, :width, np.newaxis] == True, tmp_white_image, tmp_black_image ) # 輪郭の取得 contours, _ = cv2.findContours( tmp_black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE ) # 最も面積が大きい輪郭を選択 max_contours = max(contours, key=lambda x: cv2.contourArea(x)) # 黒画面に一番大きい輪郭だけ塗りつぶして描画する black_image = np.full(np.array([height, width, 1]), 0, dtype=np.uint8) black_image = cv2.drawContours( black_image, [max_contours], -1, color=255, thickness=-1 ) # 輪郭を保存 self._contours, _ = cv2.findContours( black_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE ) # マスクを作り直す new_mask = np.full(np.array([height, width, 1]), False, dtype=np.bool_) new_mask[::] = np.where(black_image[:height, :width] == 0, False, True) new_mask = np.squeeze(new_mask) return new_mask # 部分イメージの取得 def _get_extract_image(self, image): # boxの範囲でマスクを切り取る part_of_mask = self._mask[ self._box[1] : self._box[3], self._box[0] : self._box[2] ] # boxの範囲で元画像を切り取る copy_image = image.copy() # 個々の食品を切取るためのテンポラリ画像 white_back_image = copy_image[ self._box[1] : self._box[3], self._box[0] : self._box[2] ] # boxの範囲で白一色の2値画像を作成する h = self._box[3] - self._box[1] w = self._box[2] - self._box[0] white_image = np.full(np.array([h, w, 1]), 255, dtype=np.uint8) # マスクによって白画像の上に元画像を描画する white_back_image[:] = np.where( part_of_mask[:h, :w, np.newaxis] == False, white_image, white_back_image ) transparent_image = cv2.cvtColor(white_back_image, cv2.COLOR_BGR2BGRA) transparent_image[np.logical_not(part_of_mask), 3] = 0 return white_back_image, transparent_image def _get_segment(self): box = self._box.tolist() contours = self._contours[0].tolist() width = box[2] - box[0] height = box[3] - box[1] # print("box type:{} len:{}".format(type(box), len(box))) # print("v type:{} len:{}".format(type(contours), len(contours))) # print("box {},{},{},{}".format(box[0], box[1], box[2], box[3])) # print("width:{} height:{}".format(width, height)) text = "" for i, data in enumerate(contours): d = data[0] x = d[0] - box[0] y = d[1] - box[1] text += "{},{},".format(x, y) return text class Video: def __init__(self, filename): self.cap = cv2.VideoCapture(filename) if self.cap.isOpened() == False: print("Video open faild.") else: self._frame_max = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) # 次のフレーム取得 def next_frame(self): return self.cap.read() ## 総フレーム数 @property def frame_max(self): return self._frame_max def destroy(self): print("video destroy.") self.cap.release() cv2.destroyAllWindows() # 縦横それぞれ、0.15倍まで広げた、ボックスを取得する def get_next_input(box): x1 = box[0] y1 = box[1] x2 = box[2] y2 = box[3] w = x2 - x1 h = y2 - y1 x_margen = int(w * 0.15) y_margen = int(h * 0.15) return np.array([x1 - x_margen, y1 - y_margen, x2 + x_margen, y2 + y_margen]) def main(): print("PyTorch version:", torch.__version__) device = "cuda" if torch.cuda.is_available() else "cpu" print("Using {} device".format(device)) step = 3 start_frame = 0 filename = "KINOKO.mp4" # filename = "TAKENOKO.mp4" output_path = "./dataset/output_png" basename = os.path.splitext(os.path.basename(filename))[0] os.makedirs("{}/{}".format(output_path, basename), exist_ok=True) video = Video(filename) # Segment Anything sam = SegmentAnything(device, "vit_h", "sam_vit_h_4b8939.pth") try: print("start") for i in range(video.frame_max): ret, frame = video.next_frame() if ret == False: continue # 開始位置まで読み飛ばす if i < start_frame: continue # フレーム省略 if i % step != 0: continue # 最初のフレームで、バウンディングボックスを取得する if i == start_frame: bounding_box = BoundingBox(frame) x1, y1, x2, y2 = bounding_box.get_area() input_box = np.array([x1, y1, x2, y2]) print( "{} filename:{} shape:{} start_frame:{} input_box:{} frams:{}/{}".format( datetime.datetime.now(), filename, frame.shape, start_frame, input_box, i + 1, video.frame_max, ) ) # マスク生成 sam.predict(frame, input_box) # 輪郭描画 frame = cv2.drawContours( frame, sam.contours, -1, color=[255, 255, 0], thickness=6 ) # バウンディングボックス描画 frame = cv2.rectangle( frame, pt1=(input_box[0], input_box[1]), pt2=(input_box[2], input_box[3]), color=(255, 255, 255), thickness=2, lineType=cv2.LINE_4, ) # データ保存 cv2.imwrite( "{}/{}/{:09}_t.png".format(output_path, basename, i), sam.transparent_image, ) cv2.imwrite( "{}/{}/{:09}_w.png".format(output_path, basename, i), sam.white_back_image, ) # 座標保存 with open( "{}/{}/{:09}_l.txt".format(output_path, basename, i), mode="w" ) as f: f.write(sam.segment) # 表示 cv2.imshow("Extract", sam.white_back_image) cv2.waitKey(1) cv2.imshow("Video", cv2.resize(frame, None, fx=0.3, fy=0.3)) cv2.waitKey(1) # 次のFrameで、検出範囲よりひと回り大きい範囲をBOX指定する input_box = get_next_input(sam.box) except KeyboardInterrupt: video.destroy() if __name__ == "__main__": main()
(3) 背景と合成してデータセット作成
切出した背景透過のデータを背景に合成して、データセット用の画像を生成します。
背景への合成は、ランダム位置で、ランダムなサイズで行われます。
なお、背景に対する、オブジェクトのサイズが一定の範囲に収まるように、変換は下記を基準としています。
- 背景画像 640 * 480
- 切り出し画像 横幅が200になるようリサイズしてから、0.3〜1倍に変換する
切出し画像を貼り付けた後は、相対座標のデータもリサイズに合わせて変換し、最終的に背景画像の幅を基準に0〜1で正規化され、画像と同名の.txt として出力しています。
この時点で.txtは、YOLOのデータセット形式となっています。
(クラスIDが決定されるのは、この時点となります)
./dataset/output_merge ├── 00000.png ├── 00000.txt ├── 00001.png ├── 00001.txt ・・・ ├── 02999.png ├── 02999.txt ├── 03000.png └── 03000.txt
こちらは、背景と合成してデータセット作成している様子です。
合成に使用したプログラムです。
png_2_merge_data.py
""" SAMで切り出した画像データと背景を組み合わせてデータセットを作成する """ import json import glob import random import os import shutil import math import numpy as np import cv2 from PIL import Image MAX = 3000 # 生成する画像数 CLASS_NAME = ["KINOKO", "TAKENOKO"] COLORS = [(0, 0, 175), (175, 0, 0)] BACKGROUND_IMAGE_PATH = "./dataset/background_images" TARGET_IMAGE_PATH = "./dataset/output_png" OUTPUT_PATH = "./dataset/output_merge" BASE_WIDTH = 200 # 商品の基本サイズは、背景画像とのバランスより、横幅を200を基準とする BACK_WIDTH = 640 # 背景画像ファイルのサイズを合わせる必要がある BACK_HEIGHT = 480 # 背景画像ファイルのサイズを合わせる必要がある # 背景画像取得クラス class Background: def __init__(self, backPath): self.__backPath = backPath def get(self): imagePath = random.choice(glob.glob(self.__backPath + "/*.jpg")) return cv2.imread(imagePath, cv2.IMREAD_UNCHANGED) # 検出対象取得クラス (base_widthで指定された横幅を基準にリサイズされる) class Target: def __init__(self, target_path, base_width, class_name): self.__target_path = target_path self.__base_width = base_width self.__class_name = class_name def get(self, class_id): # 切り出し画像 class_name = self.__class_name[class_id] image_path = random.choice( glob.glob(self.__target_path + "/" + class_name + "/*_t.png") ) target_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # 基準(横)サイズに基づきリサイズされる h, w, _ = target_image.shape aspect = h / w target_image = cv2.resize( target_image, (int(self.__base_width * aspect), self.__base_width) ) bai = self.__base_width / w # labelも基準サイズへのリサイズに併せて、変換する( x * bai ) target_label = "" label_text = "" label_path = image_path.replace("_t.png", "_l.txt") with open(label_path, encoding="utf-8") as f: label_text = f.read() orign_label_list = label_text.split(",") for label in orign_label_list: if label != "": target_label += "{},".format(float(label) * bai) return target_image, target_label # 変換クラス class Transformer: def __init__(self, width, height): self.__width = width self.__height = height self.__min_scale = 0.3 self.__max_scale = 1 def warp(self, target_image): # サイズ変更 target_image, scale = self.__resize(target_image) # 配置位置決定 h, w, _ = target_image.shape left = random.randint(0, self.__width - w) top = random.randint(0, self.__height - h) rect = ((left, top), (left + w, top + h)) # 背景面との合成 new_image = self.__synthesize(target_image, left, top) return (new_image, rect, scale) def __resize(self, img): scale = random.uniform(self.__min_scale, self.__max_scale) w, h, _ = img.shape return cv2.resize(img, (int(w * scale), int(h * scale))), scale def __rote(self, target_image, angle): h, w, _ = target_image.shape rate = h / w scale = 1 if rate < 0.9 or 1.1 < rate: scale = 0.9 elif rate < 0.8 or 1.2 < rate: scale = 0.6 center = (int(w / 2), int(h / 2)) trans = cv2.getRotationMatrix2D(center, angle, scale) return cv2.warpAffine(target_image, trans, (w, h)) def __synthesize(self, target_image, left, top): background_image = np.zeros((self.__height, self.__width, 4), np.uint8) back_pil = Image.fromarray(background_image) front_pil = Image.fromarray(target_image) back_pil.paste(front_pil, (left, top), front_pil) return np.array(back_pil) class Effecter: # Gauss def gauss(self, img, level): return cv2.blur(img, (level * 2 + 1, level * 2 + 1)) # Noise def noise(self, img): img = img.astype("float64") img[:, :, 0] = self.__single_channel_noise(img[:, :, 0]) img[:, :, 1] = self.__single_channel_noise(img[:, :, 1]) img[:, :, 2] = self.__single_channel_noise(img[:, :, 2]) return img.astype("uint8") def __single_channel_noise(self, single): diff = 255 - single.max() noise = np.random.normal(0, random.randint(1, 100), single.shape) noise = (noise - noise.min()) / (noise.max() - noise.min()) noise = diff * noise noise = noise.astype(np.uint8) dst = single + noise return dst # バウンディングボックス描画 def box(frame, rect, class_id): ((x1, y1), (x2, y2)) = rect label = "{}".format(CLASS_NAME[class_id]) img = cv2.rectangle(frame, (x1, y1), (x2, y2), COLORS[class_id], 2) img = cv2.rectangle(img, (x1, y1), (x1 + 150, y1 - 20), COLORS[class_id], -1) cv2.putText( img, label, (x1 + 2, y1 - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA, ) return img # 背景と商品の合成 def marge_image(background_image, front_image): back_pil = Image.fromarray(background_image) front_pil = Image.fromarray(front_image) back_pil.paste(front_pil, (0, 0), front_pil) return np.array(back_pil) # 1画像分のデータを保持するクラス class MergeData: def __init__(self, rate): self.__rects = [] self.__images = [] self.__class_ids = [] self.__rate = rate def get_class_ids(self): return self.__class_ids def max(self): return len(self.__rects) def get(self, i): return (self.__images[i], self.__rects[i], self.__class_ids[i]) # 追加(重複率が指定値以上の場合は失敗する) def append(self, target_image, rect, class_id): conflict = False for i in range(len(self.__rects)): iou = self.__multiplicity(self.__rects[i], rect) if iou > self.__rate: conflict = True break if conflict == False: self.__rects.append(rect) self.__images.append(target_image) self.__class_ids.append(class_id) return True return False # 重複率 def __multiplicity(self, a, b): (ax_mn, ay_mn) = a[0] (ax_mx, ay_mx) = a[1] (bx_mn, by_mn) = b[0] (bx_mx, by_mx) = b[1] a_area = (ax_mx - ax_mn + 1) * (ay_mx - ay_mn + 1) b_area = (bx_mx - bx_mn + 1) * (by_mx - by_mn + 1) abx_mn = max(ax_mn, bx_mn) aby_mn = max(ay_mn, by_mn) abx_mx = min(ax_mx, bx_mx) aby_mx = min(ay_mx, by_mx) w = max(0, abx_mx - abx_mn + 1) h = max(0, aby_mx - aby_mn + 1) intersect = w * h return intersect / (a_area + b_area - intersect) # 各クラスのデータ数が同一になるようにカウントする class Counter: def __init__(self, max): self.__counter = np.zeros(max) def get(self): n = np.argmin(self.__counter) return int(n) def inc(self, index): self.__counter[index] += 1 def print(self): print(self.__counter) def main(): # 出力先の初期化 if os.path.exists(OUTPUT_PATH): shutil.rmtree(OUTPUT_PATH) os.mkdir(OUTPUT_PATH) target = Target(TARGET_IMAGE_PATH, BASE_WIDTH, CLASS_NAME) background = Background(BACKGROUND_IMAGE_PATH) transformer = Transformer(BACK_WIDTH, BACK_HEIGHT) # manifest = Manifest(CLASS_NAME) counter = Counter(len(CLASS_NAME)) effecter = Effecter() no = 0 while True: # 背景画像の取得 background_image = background.get() height, width, _ = background_image.shape # mergeデータ merge_data = MergeData(0.1) label = "" for _ in range(20): # 現時点で作成数の少ないクラスIDを取得 class_id = counter.get() # 切り出しデータの取得 target_image, targhet_label = target.get(class_id) # 変換 (transform_image, rect, scale) = transformer.warp(target_image) frame = marge_image(background_image, transform_image) # 商品の追加(重複した場合は、失敗する) ret = merge_data.append(transform_image, rect, class_id) if ret: counter.inc(class_id) label += "{}".format(class_id) for i, l in enumerate(targhet_label.split(",")): if l != "": # 変換したscaleに併せてLabelも変換する z = int(float(l) * scale) # 貼り付け位置へシフトと、全体画像を基準としてた0..1の正規化 if i % 2 == 0: # X座標 x = z + rect[0][0] label += " {:.5f}".format(x / width) else: # Y座標 y = z + rect[0][1] label += " {:.5f}".format(y / height) label += "\n" print("max:{}".format(merge_data.max())) frame = background_image for index in range(merge_data.max()): (target_image, _, _) = merge_data.get(index) # 合成 frame = marge_image(frame, target_image) # アルファチャンネル削除 frame = cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR) # エフェクト frame = effecter.gauss(frame, random.randint(0, 2)) frame = effecter.noise(frame) # 画像名 png_file_name = "{:05d}.png".format(no) # Segmentpoint label_file_name = "{:05d}.txt".format(no) no += 1 # 画像保存 cv2.imwrite("{}/{}".format(OUTPUT_PATH, png_file_name), frame) # テキスト保存 with open("{}/{}".format(OUTPUT_PATH, label_file_name), mode="w") as f: f.write(label) # manifest追加 # manifest.appned(fileName, merge_data, frame.shape[0], frame.shape[1]) for i in range(merge_data.max()): (_, rect, class_id) = merge_data.get(i) # バウンディングボックス描画(確認用) frame = box(frame, rect, class_id) counter.print() print("no:{}".format(no)) if MAX <= no: break # 表示(確認用) cv2.imshow("frame", frame) cv2.waitKey(1) main()
(4) YOLOv8形式への変換
YOLOv8で学習させるためには、学習用と検証用のデータが必要ですが、上記で作成したデータを8対2に分割して、保存します。
dataset/yolo_data/ ├── train │ ├── images │ │ ├── 00000.png │ │ ├── 00001.png ... │ │ ├── 02998.png │ │ └── 02999.png │ └── labels │ ├── 00000.txt │ ├── 00001.txt ... │ ├── 02998.txt │ └── 02999.txt └── val ├── images │ ├── 00002.png │ ├── 00006.png ... │ ├── 02989.png │ └── 02992.png └── labels ├── 00002.txt ├── 00006.txt ... ├── 02989.txt └── 02992.txt
そして、上記のパスを記述したyamlファイルも作成し、学習用のデータセットとします。(クラス名が決定されるのは、この時点となります)
data.yaml
path: /home/dataset/yolo_data train: train/images val: val/images names: 0: KINOKO 1: TAKENOKO
YOLOv8形式への変換に使用したプログラムです。
merge_data_to_yolo_data.py
""" 背景と組み合わせたデータをyolo形式に変換する """ import glob import os import shutil import numpy as np from PIL import Image def main(): input_path = "./dataset/output_merge" output_path = "./dataset/yolo_data" stages = ["train", "val"] train_path = "{}/train".format(output_path) train_images_path = "{}/images".format(train_path) train_labels_path = "{}/labels".format(train_path) val_path = "{}/val".format(output_path) val_images_path = "{}/images".format(val_path) val_labels_path = "{}/labels".format(val_path) os.makedirs(train_path, exist_ok=True) os.makedirs(train_images_path, exist_ok=True) os.makedirs(train_labels_path, exist_ok=True) os.makedirs(val_path, exist_ok=True) os.makedirs(val_images_path, exist_ok=True) os.makedirs(val_labels_path, exist_ok=True) image_files = glob.glob("{}/*.png".format(input_path)) for i, input_image_file in enumerate(image_files): basename = os.path.splitext(os.path.basename(input_image_file))[0] if i % 10 < 8: stage = stages[0] else: stage = stages[1] input_label_file = "{}/{}.txt".format(input_path, basename) output_label_file = "{}/{}/labels/{}.txt".format(output_path, stage, basename) output_image_file = "{}/{}/images/{}.png".format(output_path, stage, basename) shutil.copy(input_label_file, output_label_file) shutil.copy(input_image_file, output_image_file) if __name__ == "__main__": main()
3 学習
学習は、下記の要領です。
from ultralytics import YOLO model = YOLO("yolov8n-seg.pt") model.train(data="./dataset/yolo_data/data.yaml", epochs=10, batch=8, workers=4)
正常に学習が終わると、runs/segment/train/weights/best.ptに生成されたモデルが保存されます。
$ tree runs runs └── segment └── train ├── args.yaml ├── BoxF1_curve.png ├── BoxP_curve.png ├── BoxPR_curve.png ├── BoxR_curve.png ├── confusion_matrix_normalized.png ├── confusion_matrix.png ├── labels_correlogram.jpg ├── labels.jpg ├── MaskF1_curve.png ├── MaskP_curve.png ├── MaskPR_curve.png ├── MaskR_curve.png ├── results.csv ├── results.png ├── train_batch0.jpg ├── train_batch1.jpg ├── train_batch2.jpg ├── val_batch0_labels.jpg ├── val_batch0_pred.jpg ├── val_batch1_labels.jpg ├── val_batch1_pred.jpg ├── val_batch2_labels.jpg ├── val_batch2_pred.jpg └── weights ├── best.pt └── last.pt
4 推論
best.ptを使用して、Webカメラの画像を推論している様子です。
この時使用したプログラムです。
inference_webcam.py
# 参考にさせて頂きました # https://github.com/ultralytics/ultralytics/issues/561 from ultralytics import YOLO import numpy as np import cv2 model = YOLO("./best.pt") colors = [(0, 0, 200), (0, 200, 0)] linewidth = 5 fontScale = 1.5 fontFace = cv2.FONT_HERSHEY_SIMPLEX thickness = 5 def overlay(image, mask, color, alpha, resize=None): colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0) colored_mask = np.moveaxis(colored_mask, 0, -1) masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color) image_overlay = masked.filled() if resize is not None: image = cv2.resize(image.transpose(1, 2, 0), resize) image_overlay = cv2.resize(image_overlay.transpose(1, 2, 0), resize) image_combined = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0) return image_combined def label(box, img, color, label, line_thickness=3): x1 = int(box[0]) y1 = int(box[1]) x2 = int(box[2]) y2 = int(box[3]) text_size = cv2.getTextSize( label, 0, fontScale=fontScale, thickness=line_thickness )[0] cv2.rectangle( img, (x1, y1), (x1 + text_size[0], y1 - text_size[1] - 2), color, -1 ) # fill cv2.putText( img, label, (x1, y1 - 3), fontFace, fontScale, [225, 255, 255], thickness=line_thickness, lineType=cv2.LINE_AA, ) cv2.rectangle( img, (x1, y1), (x2, y2), color, linewidth, ) cap = cv2.VideoCapture(0) if cap.isOpened() is False: raise IOError while True: try: ret, frame = cap.read() if ret is False: raise IOError h, w, _ = frame.shape results = model(frame, conf=0.80, iou=0.5) result = results[0] image = frame.copy() if result.masks is not None: for r in results: boxes = r.boxes conf_list = r.boxes.conf.tolist() for i, (seg, box) in enumerate(zip(result.masks.data.cpu().numpy(), boxes)): seg = cv2.resize(seg, (w, h)) image = overlay(image, seg, colors[int(box.cls)], 0.5) class_id = int(box.cls) box = box.xyxy.tolist()[0] class_name = result.names[class_id] label( box, image, colors[class_id], "{} {:.2f}".format(class_name, conf_list[i]), line_thickness=3, ) frame = cv2.resize(frame, (640, 480)) image = cv2.resize(image, (640, 480)) mergeImg = np.hstack((frame, image)) cv2.imshow("YOLO", mergeImg) if cv2.waitKey(1) & 0xFF == ord("q"): break except KeyboardInterrupt: break cap.release() cv2.destroyAllWindows()
参考に、./input に置かれた画像を対象に、推論するプログラムも置いておきます。
inference_images.py
# 参考にさせて頂きました # https://github.com/ultralytics/ultralytics/issues/561 from ultralytics import YOLO import numpy as np import cv2 import os import glob model = YOLO("./best.pt") colors = [(0, 0, 200), (0, 200, 0)] linewidth = 5 fontScale = 1.5 fontFace = cv2.FONT_HERSHEY_SIMPLEX thickness = 5 input_path = "./input" output_path = "./output" def overlay(image, mask, color, alpha, resize=None): colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0) colored_mask = np.moveaxis(colored_mask, 0, -1) masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color) image_overlay = masked.filled() if resize is not None: image = cv2.resize(image.transpose(1, 2, 0), resize) image_overlay = cv2.resize(image_overlay.transpose(1, 2, 0), resize) image_combined = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0) return image_combined def label(box, img, color, label, line_thickness=3): x1 = int(box[0]) y1 = int(box[1]) x2 = int(box[2]) y2 = int(box[3]) text_size = cv2.getTextSize( label, 0, fontScale=fontScale, thickness=line_thickness )[0] cv2.rectangle( img, (x1, y1), (x1 + text_size[0], y1 - text_size[1] - 2), color, -1 ) # fill cv2.putText( img, label, (x1, y1 - 3), fontFace, fontScale, [225, 255, 255], thickness=line_thickness, lineType=cv2.LINE_AA, ) cv2.rectangle( img, (x1, y1), (x2, y2), color, linewidth, ) def main(): os.makedirs(output_path, exist_ok=True) files = glob.glob("{}/*.png".format(input_path)) for file in files: basename = os.path.splitext(os.path.basename(file))[0] image = cv2.imread(file) h, w, _ = image.shape results = model(image, conf=0.60, iou=0.5) result = results[0] if result.masks is not None: for r in results: boxes = r.boxes conf_list = r.boxes.conf.tolist() for i, (seg, box) in enumerate(zip(result.masks.data.cpu().numpy(), boxes)): seg = cv2.resize(seg, (w, h)) image = overlay(image, seg, colors[int(box.cls)], 0.5) class_id = int(box.cls) box = box.xyxy.tolist()[0] class_name = result.names[class_id] label( box, image, colors[class_id], "{} {:.2f}".format(class_name, conf_list[i]), line_thickness=3, ) output_filename = "{}/{}.png".format(output_path, basename) print(output_filename) cv2.imwrite(output_filename, image) main()
5 最後に
今回は、YOLOv8のセグメンテーションを試してみました。 データセットさえ作れれば、インスタンスセグメンテーションのファインチューニングも、それほど難しくないのかなと言うのが、今回作業した印象です。
やはり、問題は、如何に効率よく精度の高いデータを大量に用意するかでしょう。 SAM(Segment Anything Model)は、やはりここでもゲームチェンジャーだと思います。
やっと、完成したので、全部食べます。 どちらも好きですー \(^o^)/