MMDetection – YOLOXを使って自作データセットで物体検出してみた

MMDetection – YOLOXを使って自作データセットで物体検出してみた

Clock Icon2022.07.11

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

こんちには。

データアナリティクス事業本部機械学習チームの中村です。

今回は、以下の記事で紹介したMMDetectionを使って自作データのトレーニング方法をご紹介します。

最初に結果から

ゴールとしては以下のようなメロンソーダアメとオレンジアメを物体検出します。

実行環境

今回はGoogle Colaboratory環境で実行しました。

ハードウェア情報は以下の通りです。

  • GPU: Tesla T4 (GPUメモリ16GB搭載)
  • メモリ: 26GB

主なソフトウェア・ライブラリのバージョンは以下となります。

  • CUDA: 11.1
  • PyTorch: 1.11.0+cu113
  • MMDetection: 2.25.0

またデータを永続化するために、Google Drive上に作業ディレクトリを実施します。

この記事の通りに学習すると、約20GB程度を使用するため空き容量を考慮して実行してください。

使用するモデル

YOLOXを使って学習します。サイズが以下にあるように複数ありますが、今回はYOLOX-lを使用します。

データセットの準備

撮影

iphoneで以下のような画像を撮影しました。

IMG_XXXX.JPGという名前になりますので、ffmpegでファイル名を日時情報に修正します。

ffmpegは以下からダウンロードします。

ついでにそのままだと画素数が大きめなので、デモ画像と同じwidth=640pixelに縮小します。

以下のシェルを使用しました。

#!/bin/bash

INPUT_DIR="."
OUTPUT_DIR="."

for f in `find . | grep "JPG"`; do
    datetime=`ffprobe -show_frames -hide_banner ${f} 2>&1 \
        | grep "DateTimeOriginal" \
        | awk -F"=" '{print $2}' \
        | sed -e "s/:/-/g" -e "s/ /_/g"`.jpg
    echo ${datetime}
    ffmpeg -i ${f} -vf scale=640:-1 ${datetime}
done

-show_framesオプションでExif情報を取得できます。その中で、DateTimeOriginal、DateTimeDigitizedが撮影日時になっています。 何か経由でアップロードした場合などは、このExif情報がプライバシーの関係で削除される可能性がありますので、ご注意ください。

面倒な場合は特にリネーム等せずにffmpeg -i ${f} -vf scale=640:-1 ${datetime}の行の縮小処理のみ実行してください。

アノテーション

次に撮影したデータのラベル付けをアノテーションツールを使って行います。

ツールは、CVATを使用します。

公式のインストールガイドは以下となります。

今回はこの中でDocker環境を使って構築しました。

実際の使い方については以下の記事も参考にされてください。

アノテーション結果の出力フォーマットですが、MMDetectionは、COCOフォーマットに対応・推奨していますので、COCOフォーマットを指定して出力します。

最終的には、以下のようなフォルダ構造となるように構成します。

└─candy
    ├─train
    │  ├─annotations
    │  │      instances_default.json
    │  └─images
    │          2022-07-07_07-40-32.jpg
    │          ...
    └─valid
        ├─annotations
        │      instances_default.json
        └─images
                2022-07-07_09-43-43.jpg
                ...

またアノテーションした結果としてサンプル数は以下のようになりました。

set JPGファイル数 物体数(melon-soda) 物体数(orange)
train 12 28 153
valid 6 4 29

Google Colaboratoryでの実行

セットアップ

データを永続化するため、ストレージはGoogle Driveを使用します。

実行前に、Google Driveをマウントします。

そして以下のコマンドで作業ディレクトリをGoogle Driveに変更します。

%cd /content/drive/MyDrive

ここで、各種ライブラリをインストールします。

!pip3 install openmim
!mim install mmcv-full
!git clone https://github.com/open-mmlab/mmdetection.git
%cd mmdetection
!pip install -e .

再実行する場合は、上記のうち!git clone ...のみコメントアウトして実行してください。

学習済みモデルの取得

YOLOXの学習済みモデルを取得します。

リンクについては今後変わる可能性がありますので、以下のページで確認してください。

本記事執筆時点では、以下のリンクでダウンロードできました。

!wget https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth -P ./checkpoints

データセットの配置

アノテーション済みのデータセットを以下のフォルダに配置します。

dataset/candy

configファイルの修正

YOLOXのconfigファイルを以下から以下に複製します。

%cp configs/yolox/yolox_s_8x8_300e_coco.py dataset/candy/yolox_l_8x8_300e_candy.py

ここで本来はYOLOX-lのconfigファイルを使用したかったのですが、YOLOX-lはYOLOX-sのパラメータをオーバーライドしているため、元となったYOLOX-sを複製元にしています。

以降、configファイルの修正点について、解説しながら記載していきます。

configファイルの置き場所がかわっているので、_base_部分は以下のように書き換えます。

_base_ = ['../../configs/_base_/schedules/schedule_1x.py', 
    '../../configs/_base_/default_runtime.py']

モデルの部分は、以下のポイントを修正します。

  • YOLOX-l向けにチャンネル数等を修正
  • head部分の最終的な出力を示すnum_classesを2に書き換え

以下が編集内容です。編集前の部分は参考のためコメントアウトしています。

# model settings
model = dict(
    type='YOLOX',
    input_size=img_scale,
    random_size_range=(15, 25),
    random_size_interval=10,
    # backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
    backbone=dict(type='CSPDarknet', deepen_factor=1.0, widen_factor=1.0),
    neck=dict(
        type='YOLOXPAFPN',
        in_channels=[256, 512, 1024], # in_channels=[128, 256, 512],
        out_channels=256, # out_channels=128,
        num_csp_blocks=3), # num_csp_blocks=1),
    bbox_head=dict(
        # type='YOLOXHead', num_classes=80, in_channels=128, feat_channels=128),
        type='YOLOXHead', num_classes=2, in_channels=256, feat_channels=256),
    train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
    # In order to align the source code, the threshold of the val phase is
    # 0.01, and the threshold of the test phase is 0.001.
    test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))

この数値の意味については、物体検出のアーキテクチャにある程度慣れておく必要がありますが、以下に簡単に解説しておきます。

近年の物体検出の大枠の処理ブロックは以下のような構成です。

  • backbone: 画像を複数解像度の特徴量マップに変換するブロック
  • neck: 各特徴量マップを混合するブロック
  • head: 最終的な検出に向けた変換ブロック

最終出力ブロックはheadですので、その部分に検出対象のクラス数がnum_classesでパラメータ化されています。

各クラスがどの種類の物体を表すかは後述のclassesパラメータで設定します。

そして各ブロックの表現力をdeepen_factor, widen_factor, 各channels, num_csp_blocksなどでコントロールします。

配置したデータセットに合わせて、data_rootを変更します。

# dataset settings
data_root = 'dataset/candy/'

各クラスのラベルを設定する必要があるため、以下を追記します。

今回は2種類のアメがあるので、それらを設定します。

classes = ('melon-soda', 'orange')

train_datasetclassesの指定と、データセットのdata_root以降のパスを修正します。

train_dataset = dict(
    type='MultiImageMixDataset',
    dataset=dict(
        type=dataset_type,
        classes=classes,
        ann_file=data_root + 'train/annotations/instances_default.json',
        img_prefix=data_root + 'train/images/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True)
        ],
        filter_empty_gt=False,
    ),
    pipeline=train_pipeline)

またvalidとtestのセットも同様に修正します。

更にここでは、使用するハードウェアに合わせて、sample_per_gpuworkers_per_gpuを減らました。

(デフォルトではGPUメモリが不足してエラーとなりました)

data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    persistent_workers=True,
    train=train_dataset,
    val=dict(
        type=dataset_type,
        classes=classes,
        ann_file=data_root + 'valid/annotations/instances_default.json',
        img_prefix=data_root + 'valid/images/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        classes=classes,
        ann_file=data_root + 'valid/annotations/instances_default.json',
        img_prefix=data_root + 'valid/images/',
        pipeline=test_pipeline))

最後に流用するモデルの重みファイルを以下のように追記します。

# We can use the pre-trained model to obtain higher performance
load_from = 'checkpoints/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth'

ちなみに流用せずしなくても学習自体はできますが、今回は流用した方が精度が良くなりました。

以上で、configファイルの修正は完了です。

学習の実行

configファイルの修正ができれば、以下で学習を行うことができます。

!python tools/train.py dataset/candy/yolox_l_8x8_300e_candy.py

前述の実行環境では約30分で学習が終わりました。

推論テスト

まずはvalid用のデータで推論テストをします。

work_dirsにモデルが出力されているので、読み込みを行います。

config_file = 'work_dirs/yolox_l_8x8_300e_candy/yolox_l_8x8_300e_candy.py'
checkpoint_file = 'work_dirs/yolox_l_8x8_300e_candy/latest.pth'

device = 'cuda:0'
# device = 'cpu'
model = init_detector(config_file, checkpoint_file, device=device)

以下のコードで推論結果を出力します。

import pathlib

for f in pathlib.Path('dataset/candy/valid/images').glob('*jpg'):

    result = inference_detector(model, f)

    model.show_result(f,
        result, 
        out_file=pathlib.Path('work_dirs/yolox_l_8x8_300e_candy/inference/valid') / f.name
    )

以下のように検出ができました。

また動画データの推論も以下のコードで試すことができます。

import mmcv
from tqdm import tqdm

video = mmcv.VideoReader('dataset/candy/test/2022-07-09_09-54-05.mp4')

for frame_num, frame in tqdm(enumerate(video), ascii=True, total=len(video)):
    result = inference_detector(model, frame)
    model.show_result(frame, result, 
        out_file=pathlib.Path(f'work_dirs/yolox_l_8x8_300e_candy/inference/test/2022-07-09_09-54-05/frame-{frame_num:08d}.jpg')
    )

mmcv.frames2video(
    'work_dirs/yolox_l_8x8_300e_candy/inference/test/2022-07-09_09-54-05', 
    'work_dirs/yolox_l_8x8_300e_candy/inference/test/2022-07-09_09-54-05.mp4',
    filename_tmpl='frame-{:08d}.jpg'
)

動画データはdataset/candy/test/に配置しました。

また直接出力するAPIが見当たらなかったため、一旦各フレームの静止画を出力し、mmcv.frames2videoで動画に変換しています。

処理結果をお見せできることができませんが、推論がきちんと動作していました。

参考までに、動画ファイルのリネーム処理は静止画データと少し異なりますので、ここに記載しておきます。

INPUT_DIR="."
OUTPUT_DIR="."

for f in `find . | grep "MOV"`; do
    datetime=`ffprobe -hide_banner ${f} 2>&1 \
        | grep "com.apple.quicktime.creationdate:" \
        | sed -e "s/+/ /g" \
        | awk '{print $2}' \
        | sed -e "s/:/-/g" -e "s/T/_/g"`.mp4
    echo ${datetime}
    # cp ${f} ${datetime}
    ffmpeg -i ${f} -vf scale=640:-1 ${datetime}
done

まとめ

いかがでしたでしょうか。

configファイルの修正には少し知識が必要になりますが、わかりやすいパラメータ名になっていますので、名称に慣れればそこまで苦労せずに使うことができるのではないかと思います。

また編集に慣れておけば、他の物体検出モデルやInstanse Segmentationのモデルも、MMDetectionを使って同様に使っていけるので、覚えておいても損は無さそうです。

configファイルに書かれているアーキテクチャ名を論文等で調べていけば、物体検出の理論的な知識を得ることも可能ですので、興味がある方はぜひ調べてみてください。

今後のブログの記事でも、上記の理論的な話含めて、デプロイ方法などについても取り扱っていけたらと思います。

この記事をシェアする

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.