[Amazon SageMaker] イメージ分類(Image Classification)において、データセットの解像度が、学習及び、検出結果に与える影響について確認してみました

2020.05.25

1 はじめに

CX事業本部の平内(SIN)です。

[2020/05/26 一部のデータセットに誤りがあり、結果が変化したため、本記事を修正しました]

Amazon SageMaker(以下、SageMaker)の組み込みアルゴリズムである「イメージ分類」において、データセットの解像度の違いが、学習及び、検出状況にどのように影響するかを確認してみました。

元となるデータ自体や、その適用場面によって評価は大きく変わりそうですが、あくまで一例と言う位置づけで、以下の条件で違いを確認してみました。

  • データ(静止画)は、動画から生成し各ラベル200枚とする
  • 解像度の違うデータは、動画の解像度を変換することで生成する
  • ハイパーパラメータは、同一にする

下は、対象(商品)をWebカメラで撮影して、評価している様子です。

結果は、以下のとおりです。単純にデータセットの解像度が高いほど、検出の精度が良いとは言えないようです。

解像度 accuracy epoch 学習時間(sec) 評価
800 * 600 0.93 15 1737
640 * 480 0.94 20 2295
320 * 240 0.985 20 2232
160 * 120 0.987 20 2240
80 * 60 1.0 20 2376 ×
  • ◎ 80%以上で、すべて検出
  • ○ 60%以上で、すべて検出
  • × 誤検出多数

以下、詳細です。

2 データセット

(1) 動画

動画は、下記で作成した17種類の商品用のものを再利用しました。

元々 800*600 で撮影した動画を、下記のコードで、以下の4種類の解像度に変換しました。

  • 640 ✕ 480
  • 320 ✕ 240
  • 160 x 120
  • 80 x 60
import os
import glob
import subprocess

# 設定
WIDTH = 320
HEIGHT = 240
INPUT_PATH = "/tmp/Movie/"
OUTPUT_PATH = "/tmp/Movie/{}-{}/".format(WIDTH, HEIGHT)

# 出力先ディレクトリ作成
os.makedirs(OUTPUT_PATH, exist_ok=True)

# 入力パスにある*.mp4を列挙
files = glob.glob("{}/*.mp4".format(INPUT_PATH))

for src in files:
# ffmpegで解像度変換
subprocess.call('ffmpeg -i "{}" -vf scale={}:{} "{}{}"'.format(src, WIDTH, HEIGHT, OUTPUT_PATH, os.path.basename(src)), shell=True)

(2) イメージ形式

各動画からMAX=200で、イメージ形式のデータへ変換しました。各ラベル200なので、データ数は、2720件となっています。

全データ: 3400件
PORIPPY(GREEN) (200件) => 160:40
OREO (200件) => 160:40
CUNTRY_MAM (200件) => 160:40
PORIPPY(RED) (200件) => 160:40
BANANA (200件) => 160:40
CHEDDER_CHEESE (200件) => 160:40
PRETZEL(YELLOW) (200件) => 160:40
FURUGURA(BROWN) (200件) => 160:40
NOIR (200件) => 160:40
PRIME (200件) => 160:40
CRATZ(RED) (200件) => 160:40
CRATZ(GREEN) (200件) => 160:40
PRETZEL(BLACK) (200件) => 160:40
CRATZ(ORANGE) (200件) => 160:40
ASPARA (200件) => 160:40
FURUGURA(RED) (200件) => 160:40
PRETZEL(GREEN) (200件) => 160:40

train:2720 validation:680

各解像度ごとのデータセットとしてS3にアップロードしています。

3 ハイパーパラメータ

すべての確認で、以下のハイパーパラメータを設定しています。

一旦、epoch数を30とし、early_stoppingをtrueに設定することで、早期の収束、若しくは、収束しない場合に学習を停止するようにしています。

early_stopping true
epochs 30
image_shape 3,224,224
learning_rate 0.001
mini_batch_size 32
num_classes 17
num_layers 101
num_training_samples 2720
resize 256
use_pretrained_model 1

4 各解像度ごとの確認

(1) 800 * 600

800*600のデータセットを使用した場合の学習の進行状況と、検出結果です。

epoch Train-accuracy Validation-accuracy
-----------------------------------------
0 0.218 0.153
1 0.245 0.327
2 0.591 0.665
3 0.884 0.924
4 0.979 0.903
5 0.996 0.936
6 0.999 0.921
7 1.0 0.918
8 1.0 0.874
9 1.0 0.92
10 1.0 0.918
11 1.0 0.926
12 1.0 0.881
13 1.0 0.917
14 1.0 0.93
15 1.0 0.936
Early stopping criteria met.

殆どすべてが90%以上(最も低いものでの、75%)で安定して検出しています。事後、この精度を基準に確認して行きます。

(2) 640 * 480

640*480のデータセットを使用した場合です。

epoch Train-accuracy Validation-accuracy
-----------------------------------------
0 0.235 0.17
1 0.276 0.481
2 0.733 0.777
3 0.901 0.897
4 0.994 0.875
5 1.0 0.935
6 1.0 0.936
7 1.0 0.936
8 1.0 0.886
9 1.0 0.932
10 1.0 0.933
11 1.0 0.938
12 1.0 0.901
13 0.996 0.936
14 1.0 0.943
15 1.0 0.949
16 1.0 0.906
17 1.0 0.946
18 1.0 0.946
19 1.0 0.946
20 1.0 0.903
Early stopping criteria met.

学習の収束には時間がかかるようになりましたが、800*600と殆ど変わりません。検出は、少し精度が落ちているようです。

(3) 320 * 240

epoch Train-accuracy Validation-accuracy
-----------------------------------------
0 0.295 0.281
1 0.425 0.513
2 0.843 0.783
3 0.965 0.869
4 0.999 0.896
5 1.0 0.94
6 1.0 0.935
7 1.0 0.964
8 1.0 0.952
9 1.0 0.973
10 1.0 0.984
11 1.0 0.981
12 1.0 0.98
13 1.0 0.981
14 1.0 0.981
15 1.0 0.985
16 1.0 0.972
17 1.0 0.982
18 1.0 0.976
19 1.0 0.978
20 1.0 0.973
Early stopping criteria met.

学習の収束は早く、結果も優秀です。検出の結果も800*600より高くなっています。

(4) 160 * 120

epoch Train-accuracy Validation-accuracy
-----------------------------------------
0 0.358 0.17
1 0.509 0.551
2 0.865 0.917
3 0.981 0.954
4 1.0 0.953
5 1.0 0.961
6 1.0 0.963
7 1.0 0.964
8 1.0 0.972
9 1.0 0.975
10 1.0 0.973
11 1.0 0.978
12 1.0 0.979
13 1.0 0.981
14 1.0 0.979
15 1.0 0.984
16 1.0 0.984
17 1.0 0.985
18 1.0 0.984
19 1.0 0.985
20 1.0 0.987
Early stopping criteria met.

学習の経過を見ると、結果が上がっているように見えますが、検出の結果は、800*600よりも少し落ちています。

(5) 80 * 60

epoch Train-accuracy Validation-accuracy
-----------------------------------------
0 0.424 0.207
1 0.523 0.433
2 0.84 0.993
3 0.968 0.996
4 0.994 1.0
5 1.0 1.0
6 1.0 1.0
7 1.0 1.0
8 1.0 1.0
9 1.0 1.0
10 1.0 1.0
11 1.0 1.0
12 1.0 1.0
13 1.0 1.0
14 1.0 1.0
15 1.0 1.0
16 1.0 1.0
17 1.0 1.0
18 1.0 1.0
19 1.0 1.0
20 1.0 1.0
Early stopping criteria met.

学習は、非常に簡単に収束し、数値だけみると素晴らしいモデルが出来上がっているように見えますが、実際に検出してみると、まったく精度がでません。このような解像度のデータセットは、止めた方が安全のようです。

5 評価プログラム

モデルを評価するために、商品の検出を表示したプログラムです。

70%以上の確率で正しく検出できたものを水色、確率が80%以下のものを黄色で表示しています。 また、誤検知は、赤色です。

下部に、それぞれの検出結果をテキストで表示しています。

"""
[イメージ分類] データセットの解像度による検出の状況について確認する
"""

import json
import datetime
import cv2
from boto3.session import Session

PROFILE = 'developer'
END_POINT = 'sampleEndPoint'
CLASSES = ['PORIPPY(GREEN)', 'OREO', 'CUNTRY_MAM', 'PORIPPY(RED)', 'BANANA'
, 'CHEDDER_CHEESE', 'PRETZEL(YELLOW)', 'FURUGURA(BROWN)', 'NOIR'
, 'PRIME', 'CRATZ(RED)', 'CRATZ(GREEN)', 'PRETZEL(BLACK)', 'CRATZ(ORANGE)'
, 'ASPARA', 'FURUGURA(RED)', 'PRETZEL(GREEN)']
DEVICE_ID = 0 # Webカメラ

HEIGHT = 600
WIDTH = 800
AREA_WIDTH = 106
AREA_HEIGHT = 170
H_BIAS = 190
W_BIAS = -2

class SageMaker():
def __init__(self, profile, endPoint):
self.__end_point = endPoint
self.__client = Session(profile_name=profile).client('sagemaker-runtime')

def invoke(self, image):
data = self.__client.invoke_endpoint(
EndpointName=self.__end_point,
Body=image,
ContentType='image/jpeg'
)
return json.loads(data['Body'].read())

def putText(frame, x, y, text):
font_size = 0.4
font_color = (255, 255, 255)
cv2.putText(frame, text, (x, y), cv2.FONT_HERSHEY_SIMPLEX, font_size, font_color, 1, cv2.LINE_AA)

def onClick(event, x, y, flags, frame):
if event == cv2.EVENT_LBUTTONUP:
now = datetime.datetime.now()
cv2.imwrite(str(now) + '.jpg', frame)
print("Saved.")

def createArea(width, height, w, h, h_bias, w_bias):
w_bias = (w_bias) * 110
x_1 = int(width/2-w/2) - w_bias
x_2 = int(width/2+w/2) - w_bias
y_1 = int(height/2-h/2) - h_bias
y_2 = int(height/2+h/2) - h_bias
return [x_1, y_1, x_2, y_2]

def main():

cap = cv2.VideoCapture(DEVICE_ID)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, WIDTH)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, HEIGHT)

fps = cap.get(cv2.CAP_PROP_FPS)
width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
print("FPS:{} WIDTH:{} HEIGHT:{}".format(fps, width, height))

sageMake = SageMaker(PROFILE, END_POINT)

# 商品のエリア
areas = []
for w_bias in [2, 1, 0, -1, -2]:
areas.append(createArea(width, height, AREA_WIDTH, AREA_HEIGHT, H_BIAS, w_bias))
for w_bias in [2, 1, 0, -1, -2]:
areas.append(createArea(width, height, AREA_WIDTH, AREA_HEIGHT, H_BIAS - 220, w_bias))
# 正解ラベル
labels = ['BANANA', 'PRIME', 'CHEDDER_CHEESE', 'CUNTRY_MAM', 'OREO', 'CRATZ(ORANGE)', 'NOIR', 'CRATZ(GREEN)', 'ASPARA', 'PORIPPY(RED)']

while True:

# カメラ画像取得
_, frame = cap.read()
if(frame is None):
continue

# 対象範囲の枠表示
for i, area in enumerate(areas):
# 対象エリアの抽出
img = frame[area[1]: area[3], area[0]: area[2]]
_, jpg = cv2.imencode('.jpg', img)

# 推論
results = sageMake.invoke(jpg.tostring())
probabilitys = {}
for e, result in enumerate(results):
probabilitys[CLASSES[e]] = result
probabilitys = sorted(probabilitys.items(), key = lambda x:x[1], reverse=True)
(name, probability) = probabilitys[0]

# 枠の描画
color = (0, 0, 255) # ラベルが正しくない場合
if(name == labels[i]):
if(probability > 0.7):
color = (255, 255, 0) # ラベルが正しく、確率が0.7を超える場合
else:
color = (0, 255, 255) # ラベルは正しいが、確率が0.7以下の場合
frame = cv2.rectangle(frame, (area[0], area[1]), (area[2], area[3]), color, 3)

# 結果(テキスト)表示
text = "{} {}".format(name, probability)
putText(frame, 80, int(height) - 170 + i*18, text)

# フレーム表示
cv2.imshow('frame', frame)

# マウスクリックでスナップ撮影
cv2.setMouseCallback('frame', onClick, frame)

cv2.waitKey(1)

cap.release()
cv2.destroyAllWindows()

main()

6 最後に

結果、320✕240が、最も学習が収束しやすく、検出結果も良い数値となりました。 これは、畳み込みニューラルネットワークの入力レイヤで、デフォルト値の(3, 224, 224)となっているのと関係があるのでしょうか・・・

なお、一定以下の解像度(今回は、80✕60)では、学習は、簡単に収束して高いスコアとなりますが、検出自体は、誤検出となってしまっているので注意が必要です。

データを扱うという観点では、解像度が低いほど、データサイズは小さくなり、変換やコピーなどが楽になるので、224✕224あたりを基準にすると良いのかも知れません。