【Swift】簡単なリアルタイムに画像を分類するアプリを作ってみた

2021.07.07

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

妻に「レタスを買ってきて」と頼まれてキャベツを買って帰って怒られることが頻繁にあったのでこの問題を何とか解決できないか考えていました。

そこで調べてみると、Appleが公式に交付している機械学習モデルにResnet50というものがあり、

説明:木、動物、食べ物、乗り物、人など、1000のカテゴリのセットから画像に存在する主要なオブジェクトを検出します。

この機械学習モデルを使って、レタスとキャベツの区別が出来れば今後妻に怒られることはないかもしれないと思い、一筋の希望を抱いてカメラ撮影を行い、そのフレーム内の画像を分類し表示するアプリを作ってみる事にしました。

作ったもの

環境

  • Xcode 12.5
  • Swift 5.4

処理の流れ

ざっとの説明になりますが、

  1. カメラでビデオ撮影
  2. ビデオ撮影したものからフレームをキャプチャし、イメージを切り出す
  3. 切り出したイメージから機械学習モデルを使用し、画像を分類する
  4. 画像分類結果のIDと信頼値をUIに描画する

といった流れになっています。

VideoCapture

カメラの映像をキャプチャするクラスを作っていきたいと思います。

import AVFoundation

protocol VideoCaptureDelegate: AnyObject {
    func didSet(_ previewLayer: AVCaptureVideoPreviewLayer)
    func didCaptureFrame(from imageBuffer: CVImageBuffer)
}

class VideoCapture: NSObject {

    weak var delegate: VideoCaptureDelegate?

    // AVの出入力のキャプチャを管理するセッションオブジェクト
    private let captureSession = AVCaptureSession()

    // ビデオを記録し、処理を行うためにビデオフレームへのアクセスを提供するoutput
    private let videoOutput = AVCaptureVideoDataOutput()

    // カメラセットアップとフレームキャプチャを処理する為のDispathQueue
    private let sessionQueue = DispatchQueue(label: "object-detection-queue")

    func startCapturing() {

        // capture deviceのメディアタイプを決め、
        // 何からインプットするかを決める
        guard let captureDevice = AVCaptureDevice.default(for: .video),
              let deviceInput = try? AVCaptureDeviceInput(device: captureDevice)
        else { return }

        // captureSessionにdeviceInputを入力値として入れる
        captureSession.addInput(deviceInput)

        // キャプチャセッションの開始
        captureSession.startRunning()

        // ビデオフレームの更新ごとに呼ばれるデリゲートをセット
        videoOutput.setSampleBufferDelegate(self, queue: sessionQueue)

        // captureSessionから出力を取得するためにdataOutputをセット
        captureSession.addOutput(videoOutput)

        // captureSessionをUIに描画するためにPreviewLayerにsessionを追加
        let previewLayer = AVCaptureVideoPreviewLayer(session: captureSession)
        delegate?.didSet(previewLayer)
    }

    func stopCapturing() {
        // キャプチャセッションの終了
        captureSession.stopRunning()
    }
}

以下に詳細を説明します。

プロパティ

Delegate

weak var delegate: VideoCaptureDelegate?

AVCaptureSession

// AVの出入力のキャプチャを管理するセッションオブジェクト
private let captureSession = AVCaptureSession()

AVCaptureVideoDataOutput()

// ビデオを記録し、処理を行うためにビデオフレームへのアクセスを提供するoutput
private let videoOutput = AVCaptureVideoDataOutput()

SessionQueue

// カメラセットアップとフレームキャプチャを処理する為のDispathQueue
private let sessionQueue = DispatchQueue(label: "object-detection-queue")

キャプチャの開始

キャプチャの開始にはstartCapturing()メソッドを使用します。

func startCapturing() {

    // capture deviceのメディアタイプを決め、
    // 何からインプットするかを決める
    guard let captureDevice = AVCaptureDevice.default(for: .video),
          let deviceInput = try? AVCaptureDeviceInput(device: captureDevice)
    else { return }

    // captureSessionにdeviceInputを入力値として入れる
    captureSession.addInput(deviceInput)

    // キャプチャセッションの開始
    captureSession.startRunning()

    // ビデオフレームの更新ごとに呼ばれるデリゲートをセット
    videoOutput.setSampleBufferDelegate(self, queue: sessionQueue)

    // captureSessionから出力を取得するためにdataOutputをセット
    captureSession.addOutput(videoOutput)

    // captureSessionをUIに描画するためにPreviewLayerにsessionを追加
    let previewLayer = AVCaptureVideoPreviewLayer(session: captureSession)
    delegate?.didSet(previewLayer)
}

まず最初に、captureDeviceとして、AVCaptureDevice.default(for: .video)を設定しています。

AVCaptureDevice.default(for:)では様々なAVMediaTypeを指定できますが、今回はvideoからキャプチャしたので.videoにしています。

captrureDeviceからcaptureSessionへ取得データを提供する為に、AVCaptureDeviceInput(device: captureDevice)を準備しておきます。

addInput(_ input)

captureSession.addInput(_ input:)で準備したdeviceInputcaptureSessionに追加します。

startRunning()

インプットデータも決まったので、captureSession.startRunning()captureSessionの実行をスタートさせます。

setSampleBufferDelegate(_:queue:)

カメラのビデオをキャプチャし、そのフレームが更新される度に画像の分類やViewへの描画を行いたいので、setSampleBufferDelegate(_:queue:)をセットしてデリゲートメソッドを使えるようにしておきます。

addOutput(_ output)

captureSession.addOutput(_ output)captureSessionの出力にvideoOutputを追加します

AVCaptureVideoPreviewLayer

AVCaptureVideoPreviewLayerは、キャプチャされたビデオを表示するレイヤーで、今回はパラメーターにカメラからのキャプチャSessionとしてcaptureSessionを渡し、プレビューレイヤーをイニシャライズしています。

let previewLayer = AVCaptureVideoPreviewLayer(session: captureSession)

previewLayerを定義したら、そのpreviewLayerVideoCaptureDelegateメソッドのdidSet(_ previewLayer: AVCaptureVideoPreviewLayer)に渡しています。

キャプチャの停止

stopRunnningcaptureSessionを停止します。

func stopCapturing() {
    // キャプチャセッションの終了
    captureSession.stopRunning()
}

AVCaptureVideoDataOutputSampleBufferDelegate

ビデオデータ出力からサンプルバッファを受信し、そのステータスを監視できるメソッドで、その中のcaptureOutput(_ output: , didOutput sampleBuffer: , from connection:)で新しいビデオフレームが書き込まれたことを検知することができます。

extension VideoCapture: AVCaptureVideoDataOutputSampleBufferDelegate {

    func captureOutput(_ output: AVCaptureOutput,
                       didOutput sampleBuffer: CMSampleBuffer,
                       from connection: AVCaptureConnection) {

        // フレームからImageBufferに変換
        guard let imageBuffer = CMSampleBufferGetImageBuffer(sampleBuffer)
        else { return }

        delegate?.didCaptureFrame(from: imageBuffer)
    }
}

このパラメーターCMSampleBufferには、ビデオフレームのデータやフレームに関する情報が含まれており、ここからキャプチャしたフレームをImageBufferに変換しています。

そして、ImageBufferVideoCaptureDelegateメソッドのdidCaptureFrame(from imageBuffer: CVImageBuffer)に渡しています。

Resnet50ModelManager

このクラスでは、機械学習モデルResnet50を使用して、画像を分類する処理を管理しています。

Resnet50は、Apple公式 CoreML modelでダウンロードできます!

import CoreML
import Vision

protocol Resnet50ModelManagerDelegate: AnyObject {
    func didRecieve(_ observation: VNClassificationObservation)
}

class Resnet50ModelManager: NSObject {

    weak var delegate: Resnet50ModelManagerDelegate?

    func performRequet(with imageBuffer: CVImageBuffer) {
        // 機械学習モデル
        guard let model = try? VNCoreMLModel(for: Resnet50(configuration: .init()).model)
        else { return }

        // フレーム内で機械学習モデルを使用した画像分析リクエスト
        let request = VNCoreMLRequest(model: model) { request, error in
            if let error = error {
                print(error.localizedDescription)
                return
            }

            guard let results = request.results as? [VNClassificationObservation],
                  let firstObservation = results.first
            else { return }

            self.delegate?.didRecieve(firstObservation)
        }

        // imageRequestHanderにimageBufferをセット
        let imageRequestHandler = VNImageRequestHandler(cvPixelBuffer: imageBuffer)
        // imageRequestHandlerにrequestをセットし、実行
        try? imageRequestHandler.perform([request])
    }
}

以下で詳細を説明します。

Delegate

weak var delegate: Resnet50ModelManagerDelegate?

performRequest

performRequet(with imageBuffer:)で機械学習のリクエストを実行します。

機械学習モデルの用意

まずは今回使用するResnet50のモデルを定義します。

guard let model = try? VNCoreMLModel(for: Resnet50(configuration: .init()).model)
else { return }

VNCoreMLRequestの設定

// フレーム内で機械学習モデルを使用した画像分析リクエスト
let request = VNCoreMLRequest(model: model) { request, error in
    if let error = error {
        print(error.localizedDescription)
        return
    }

guard let results = request.results as? [VNClassificationObservation],
      let firstObservation = results.first
else { return }

self.delegate?.didRecieve(firstObservation)
}

VNCoreMLRequest(model:)に今回はResnet50モデルを設定することで、Reset50を使用した画像分析リクエストを取得できるようになります。

VNClassificationObservationには、画像分析リクエストが生成した分類情報、観測値が含まれており、取得した観測値ObservationResnet50ModelManagerDelegatedidRecieve(_ observation:)のパラメーターとして渡しています。

VNImageRequestHandler

単一の画像に対してVisionリクエストを実行するハンドラーで、VNImageRequestHandler(cvPixelBuffer:)で設定した画像に対してリクエスト処理を実行します。

// imageRequestHanderにimageBufferをセット
let imageRequestHandler = VNImageRequestHandler(cvPixelBuffer: imageBuffer)
// imageRequestHandlerにrequestをセットし、実行
try? imageRequestHandler.perform([request])

初期化した時点ではまだどのリクエスト処理を実行するかを指定していないので、imageRequestHandler.perform([VNRequest])で実行するリクエストを指定して処理を開始します。

ちなみにCVPixelBufferCVImageBuffertypealiasになります。

typealias CVPixelBuffer = CVImageBuffer

RealTimeImageClassficationViewController

あとは、ViewControllerの実装を行なっていきます。

import UIKit
import AVFoundation
import Vision

class RealTimeImageClassficationViewController: UIViewController {

    @IBOutlet private weak var previewView: UIView!
    @IBOutlet private weak var observationLabel: UILabel!

    private let videoCapture = VideoCapture()
    private let resnet50ModelManager = Resnet50ModelManager()

    override func viewDidLoad() {
        super.viewDidLoad()
        resnet50ModelManager.delegate = self
        videoCapture.delegate = self
        videoCapture.startCapturing()
    }

    override func viewWillDisappear(_ animated: Bool) {
        super.viewWillDisappear(animated)
        videoCapture.stopCapturing()
    }
}

extension RealTimeImageClassficationViewController: VideoCaptureDelegate {

    // previewLayerがセットされた時に呼ばれる
    func didSet(_ previewLayer: AVCaptureVideoPreviewLayer) {
        previewView.layer.addSublayer(previewLayer)
        previewLayer.frame = previewView.frame
    }

    // フレームがキャプチャされる度に呼ばれる
    func didCaptureFrame(from imageBuffer: CVImageBuffer) {
        resnet50ModelManager.performRequet(with: imageBuffer)
    }
}

extension RealTimeImageClassficationViewController: Resnet50ModelManagerDelegate {

    // 画像分析リクエストから観測値を受け取る度に呼ばれる
    func didRecieve(_ observation: VNClassificationObservation) {

        DispatchQueue.main.async {
            self.observationLabel.text = "\(observation.confidence.convertPercent)%の確率で、\(observation.identifier)"
            print(observation.identifier, observation.confidence)
        }
    }
}

以下で詳細を説明します。

viewDidLoad

Delegateの設定とVideoCaptureのキャプチャを開始しています。

override func viewDidLoad() {
    super.viewDidLoad()
    resnet50ModelManager.delegate = self
    videoCapture.delegate = self
    videoCapture.startCapturing()
}

viewWillDisappear

VideoCaptureのキャプチャを停止しています。

override func viewWillDisappear(_ animated: Bool) {
    super.viewWillDisappear(animated)
    videoCapture.stopCapturing()
}

VideoCaptureDelegate

didSet(_ previewLayer:)

このメソッドは、VideoCaptureクラスでpreviewLayerが設定された時に呼ばれます。

func didSet(_ previewLayer: AVCaptureVideoPreviewLayer) {
    previewView.layer.addSublayer(previewLayer)
    previewLayer.frame = previewView.frame
}

設定されたpreviewLayerpreviewView.layerに追加して、キャプチャのイメージをUIに反映しています。 またframepreviewViewに合わせています。

didCaptureFrame(from imageBuffer:)

このメソッドは、VideoCaptureクラスでフレームがキャプチャされる度に呼ばれます。

func didCaptureFrame(from imageBuffer: CVImageBuffer) {
    resnet50ModelManager.performRequet(with: imageBuffer)
}

キャプチャしたイメージをパラメーターとして持っているので、そのパラメーターをresnet50ModelManager.performRequet(with:)に渡して、画像解析クエスト処理を実行しています。

Resnet50ModelManagerDelegate

didRecieve(_ observation:)

このメソッドはResnet50ModelManagerクラス内で画像分析リクエストから観測値を受け取る度に呼ばれます。

func didRecieve(_ observation: VNClassificationObservation) {

    DispatchQueue.main.async {
        self.observationLabel.text = "\(observation.confidence.convertPercent)%の確率で、\(observation.identifier)"
        print(observation.identifier, observation.confidence)
    }
}

observation.confidenceはその観測値の信頼値を持っており、0.0~1.0の値を持っています。1.0に近づけば近づくほどその画像分類結果の信頼値が高いことを示します。

observation.identifierには画像分類した画像を示す文字列が入っております。

そのconfidenceidentifierを一つの文字列にしてUILabeltextに代入しています。

VNConfidence+convertPercent

confidenceの値をパーセントに変換する為のエクステンションです。

extension VNConfidence {
    var convertPercent: Int {
        return Int(self * 100)
    }
}

おわりに

結論から言うと、レタスとキャベツを分類することが出来ませんでした、、

キャベツは分類できたのですが、レタスもキャベツと分類してしまい、この機械学習モデルも私と同じ過ちをしていました。 それはそれで何故か自分に自信が持てた気がします。笑

そもそも今回使用したResnet50モデルは全てのオブジェクトを網羅しているわけではなく、

1000のカテゴリのセットから画像に存在する主要なオブジェクトを検出します。

とのことなので、満足のいく結果を得る為には自分で機械学習モデルを作る必要がありそうです。 それはそれで面白そうですね。

キャベツ&レタスとの戦いはまだまだ続く、、、、(多分)

参考