Amazon SageMakerで画像分類をしてみた

2018.08.10

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

概要

こんにちは、yoshimです。 今回はSageMakerでビルトインアルゴリズムとして実装されている「Image classification transfer learning demo」について、チュートリアルを実際にやってみます。 「転移学習」をすることで、「優秀な画像分類モデルを手軽に実装」できました。

目次

1.画像分類とは

まず、そもそも「画像分類」ってどのようなことをするのでしょうか? 大雑把に言いますと「画像」を投入すると「その画像がなんの画像か」を推測するものです。

例えば、今回実行するチュートリアルですと、下記のような画像を投入すると

Result: label - umbrella-101, probability - 0.9996694326400757

こんな感じで「この画像に写っているのは傘だよ」、「確率は99.9..%だよ」といった結果を返してくれます。 最近は「自動運転技術」や「工場での不良品検出」等、様々なところで画像認識技術が研究されていますが、その一端に触れることができるのでワクワクしますね。

余談ですが、このビルトインアルゴリズムでは「ResNet」を使います。 「ResNet」はobject detectionでもbaseで使うことができます。(こちらでは層の深さは50)

2.転移学習とは

ニューラルネットワークベースのアルゴリズムはパラメータ数がとても多いため、チューニングするのがとても大変です。 「転移学習」は「ハイパーパラメータを1から学習する」のではなく、既にチューニングされているパラメータを利用して学習の手間を減らすという、とても素敵な技術なのです。

とはいえ、「転移学習を使えば新しく学習することは全くない」という訳ではないのでご注意ください。

3.実際にやってみた

今回試してみたチュートリアルは「Image-classification-transfer-learning.ipynb」というファイルです。

ここにあるものと同じです。

基本的に上から順に実行していくだけで大丈夫です。

3-1.事前準備

トレーニング、各データソースにアクセスする権限を取得します。 また、今回利用するS3バケットを指定します。

%%time
import boto3
import re
from sagemaker import get_execution_role
from sagemaker.amazon.amazon_estimator import get_image_uri

role = get_execution_role()

bucket='<<bucket-name>>' # customize to your bucket

training_image = get_image_uri(boto3.Session().region_name, 'image-classification') # ビルトインアルゴリズムである「image-classification」を指定

print(training_image)
3-2.今回使うデータをダウンロード

今回のチュートリアルで利用するデータをダウンロードします。 「image-classification」では、INPUTデータとして「RecordIO (application/x-recordio)」、「イメージ (application/x-image)」のいずれかが利用可能ですが、今回は「recordIO」のデータを使ってトレーニングを進めます。 ただし、推論時には「イメージ (application/x-image)」のみがサポートされる点にご注意ください。

input,outputの仕様

import os
import urllib.request
import boto3

def download(url):
    filename = url.split("/")[-1]
    if not os.path.exists(filename):
        urllib.request.urlretrieve(url, filename)

        
def upload_to_s3(channel, file):
    s3 = boto3.resource('s3')
    data = open(file, "rb")
    key = channel + '/' + file
    s3.Bucket(bucket).put_object(Key=key, Body=data)


# # caltech-256
download('http://data.mxnet.io/data/caltech-256/caltech-256-60-train.rec')
download('http://data.mxnet.io/data/caltech-256/caltech-256-60-val.rec')
upload_to_s3('validation', 'caltech-256-60-val.rec')
upload_to_s3('train', 'caltech-256-60-train.rec')
3-3.トレーニング用のパラメータを指定

トレーニング用のハイパーパラメータを指定します。 ハイパーパラメータの一覧はこちらをご参照ください。 ハイパーパラメータのチューニングについてのドキュメントをみると、もっともチューニングの効果が出そうなのは「mini_batch_size」、「learning_rate」、「optimizer」の3つらしいです。

%%time

# The algorithm supports multiple network depth (number of layers). They are 18, 34, 50, 101, 152 and 200
# For this training, we will use 18 layers
num_layers = 18
# we need to specify the input image shape for the training data
image_shape = "3,224,224"
# we also need to specify the number of training samples in the training set
# for caltech it is 15420
num_training_samples = 15420
# specify the number of output classes
num_classes = 257
# batch size for training
mini_batch_size =  128
# number of epochs
epochs = 2
# learning rate
learning_rate = 0.01
top_k=2
# Since we are using transfer learning, we set use_pretrained_model to 1 so that weights can be 
# initialized with pre-trained weights
use_pretrained_model = 1

今回はチュートリアルなので、層の深さ、エポック数等が最小限に設定されています。 SageMakerではログから「検証データセットでの精度推移」が確認できるので、とりあえずエポック数を増やして推移を見てみるのもいいかもしれません。

3-4.トレーニング用のパラメータを設定

「3-3.トレーニング用のパラメータを指定」で指定したハイパーパラメータ、学習に使うデータを格納しているS3パス、データフォーマット等を指定します。 このチュートリアルのコードをそのまま使うと、指定したバケットの直下にいきなり「bucket/train」とか「bucket/test」といったパスを作ることになってしまうので少し気になりました。

もし嫌だという場合は「prefix」とかを変数として指定して、下記のように変更するといいです。

"S3Uri": 's3://{}/train/'.format(bucket)
を
"S3Uri": 's3://{0}/{1}/train/'.format(bucket,prefix)
"S3Uri": 's3://{}/validation/'.format(bucket),
を
"S3Uri": 's3://{0}/{1}/validation/'.format(bucket,prefix),
%%time
import time
import boto3
from time import gmtime, strftime

# prefix = 'hogehoge' # 出力先S3パスを変更するなら指定

s3 = boto3.client('s3')
# create unique job name 
job_name_prefix = 'DEMO-imageclassification'
timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())
job_name = job_name_prefix + timestamp
training_params = \
{
    # specify the training docker image
    "AlgorithmSpecification": {
        "TrainingImage": training_image,
        "TrainingInputMode": "File"
    },
    "RoleArn": role,
    "OutputDataConfig": {
        "S3OutputPath": 's3://{}/{}/output'.format(bucket, job_name_prefix)
    },
    "ResourceConfig": {
        "InstanceCount": 1,
        "InstanceType": "ml.p2.xlarge",
        "VolumeSizeInGB": 50 
    },
    "TrainingJobName": job_name,
    "HyperParameters": {
        "image_shape": image_shape,
        "num_layers": str(num_layers),
        "num_training_samples": str(num_training_samples),
        "num_classes": str(num_classes),
        "mini_batch_size": str(mini_batch_size),
        "epochs": str(epochs),
        "learning_rate": str(learning_rate),
        "use_pretrained_model": str(use_pretrained_model)
    },
    "StoppingCondition": {
        "MaxRuntimeInSeconds": 360000
    },
#Training data should be inside a subdirectory called "train"
#Validation data should be inside a subdirectory called "validation"
#The algorithm currently only supports fullyreplicated model (where data is copied onto each machine)
    "InputDataConfig": [
        {
            "ChannelName": "train",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": 's3://{}/train/'.format(bucket),
                    "S3DataDistributionType": "FullyReplicated"
                }
            },
            "ContentType": "application/x-recordio",
            "CompressionType": "None"
        },
        {
            "ChannelName": "validation",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": 's3://{}/validation/'.format(bucket),
                    "S3DataDistributionType": "FullyReplicated"
                }
            },
            "ContentType": "application/x-recordio",
            "CompressionType": "None"
        }
    ]
}
print('Training job name: {}'.format(job_name))
print('\nInput Data Location: {}'.format(training_params['InputDataConfig'][0]['DataSource']['S3DataSource']))
3-5.トレーニングの開始

上記で指定した設定でトレーニングを開始します。 私の場合は、p2.xlargeで10分ほどかかりました。 注意点として、下記のJOBをjupyterの画面上から停止しても裏ではまだ実行されている場合があります。 その場合は、コンソール画面の「トレーニングジョブ」から対象のJOBを停止してください。

# create the Amazon SageMaker training job
sagemaker = boto3.client(service_name='sagemaker')
sagemaker.create_training_job(**training_params)

# confirm that the training job has started
status = sagemaker.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']
print('Training job current status: {}'.format(status))

try:
    # wait for the job to finish and report the ending status
    sagemaker.get_waiter('training_job_completed_or_stopped').wait(TrainingJobName=job_name)
    training_info = sagemaker.describe_training_job(TrainingJobName=job_name)
    status = training_info['TrainingJobStatus']
    print("Training job ended with status: " + status)
except:
    print('Training failed to start')
     # if exception is raised, that means it has failed
    message = sagemaker.describe_training_job(TrainingJobName=job_name)['FailureReason']
    print('Training failed with the following error: {}'.format(message))

また、学習の推移をコンソール画面の「ログ」から確認できます。 トレーニングデータでの精度推移が確認できるので、見ておくといいと思います。 また、最後には検証用データでの精度も計算してくれています。

ログ画面を開き

ここをクリック

対象JOBのログが出てくるのでクリック

検証データ、エポックごとの精度が確認できました!

3-6.エンドポイントに使うモデルやロールの設定

エンドポイントに利用するモデル(今回の場合は先ほどS3に出力したモデル)や、ロールを指定します。

%%time

import boto3
from time import gmtime, strftime

sage = boto3.Session().client(service_name='sagemaker') 

model_name="DEMO-image-classification-model"
print(model_name)
info = sage.describe_training_job(TrainingJobName=job_name)
model_data = info['ModelArtifacts']['S3ModelArtifacts']
print(model_data)

hosting_image = get_image_uri(boto3.Session().region_name, 'image-classification')

primary_container = {
    'Image': hosting_image,
    'ModelDataUrl': model_data,
}

create_model_response = sage.create_model(
    ModelName = model_name,
    ExecutionRoleArn = role,
    PrimaryContainer = primary_container)

print(create_model_response['ModelArn'])
3-7.エンドポイントの設定

立ち上げるエンドポイントの設定を指定します。 具体的には、インスタンスタイプやインスタンスの台数を指定します。 A/Bテストをする場合は「ModelName」もしっかり指定してあげた方がテストがしやすいかと思います。

%%time

from time import gmtime, strftime

timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())
endpoint_config_name = job_name_prefix + '-epc-' + timestamp
endpoint_config_response = sage.create_endpoint_config(
    EndpointConfigName = endpoint_config_name,
    ProductionVariants=[{
        'InstanceType':'ml.m4.xlarge',
        'InitialInstanceCount':1,
        'ModelName':model_name,
        'VariantName':'AllTraffic'}])

print('Endpoint configuration name: {}'.format(endpoint_config_name))
print('Endpoint configuration arn:  {}'.format(endpoint_config_response['EndpointConfigArn']))
3-8.エンドポイントの作成

「3-6.エンドポイントに使うモデルやロールの設定」、「3-7.エンドポイントの設定」で指定した設定に基づいて、推論用エンドポイントを作成します。

%%time
import time

timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())
endpoint_name = job_name_prefix + '-ep-' + timestamp
print('Endpoint name: {}'.format(endpoint_name))

endpoint_params = {
    'EndpointName': endpoint_name,
    'EndpointConfigName': endpoint_config_name,
}
endpoint_response = sagemaker.create_endpoint(**endpoint_params)
print('EndpointArn = {}'.format(endpoint_response['EndpointArn']))

JOB自体はすぐに完了しますが、エンドポイントが立ち上がるまでは少し時間がかかります。コンソール画面から確認するか、下記を実行してステータスを確認しましょう。

# get the status of the endpoint
response = sagemaker.describe_endpoint(EndpointName=endpoint_name)
status = response['EndpointStatus']
print('EndpointStatus = {}'.format(status))


# wait until the status has changed
sagemaker.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name)


# print the status of the endpoint
endpoint_response = sagemaker.describe_endpoint(EndpointName=endpoint_name)
status = endpoint_response['EndpointStatus']
print('Endpoint creation ended with EndpointStatus = {}'.format(status))

if status != 'InService':
    raise Exception('Endpoint creation failed.')
3-9.実際に推論してみる

作成したエンドポイントに実際に画像を投げてみて、どのように分類されるかを確認してみましょう。

セッションをはって

import boto3
runtime = boto3.Session().client(service_name='runtime.sagemaker')

画像をダウンロードしてみましょう。

!wget -O /tmp/test.jpg http://www.vision.caltech.edu/Image_Datasets/Caltech256/images/008.bathtub/008_0007.jpg
file_name = '/tmp/test.jpg'
# test image
from IPython.display import Image
Image(file_name)

こんな画像が表示されます。

この画像をエンドポイントに投げてみると、

import json
import numpy as np
with open(file_name, 'rb') as f:
    payload = f.read()
    payload = bytearray(payload)
response = runtime.invoke_endpoint(EndpointName=endpoint_name, 
                                   ContentType='application/x-image', 
                                   Body=payload)
result = response['Body'].read()
# result will be in json format and convert it to ndarray
result = json.loads(result)
# the result will output the probabilities for all classes
# find the class with maximum probability and print the class index
index = np.argmax(result)
object_categories = ['ak47', 'american-flag', 'backpack', 'baseball-bat', 'baseball-glove', 'basketball-hoop', 'bat', 'bathtub', 'bear', 'beer-mug', 'billiards', 'binoculars', 'birdbath', 'blimp', 'bonsai-101', 'boom-box', 'bowling-ball', 'bowling-pin', 'boxing-glove', 'brain-101', 'breadmaker', 'buddha-101', 'bulldozer', 'butterfly', 'cactus', 'cake', 'calculator', 'camel', 'cannon', 'canoe', 'car-tire', 'cartman', 'cd', 'centipede', 'cereal-box', 'chandelier-101', 'chess-board', 'chimp', 'chopsticks', 'cockroach', 'coffee-mug', 'coffin', 'coin', 'comet', 'computer-keyboard', 'computer-monitor', 'computer-mouse', 'conch', 'cormorant', 'covered-wagon', 'cowboy-hat', 'crab-101', 'desk-globe', 'diamond-ring', 'dice', 'dog', 'dolphin-101', 'doorknob', 'drinking-straw', 'duck', 'dumb-bell', 'eiffel-tower', 'electric-guitar-101', 'elephant-101', 'elk', 'ewer-101', 'eyeglasses', 'fern', 'fighter-jet', 'fire-extinguisher', 'fire-hydrant', 'fire-truck', 'fireworks', 'flashlight', 'floppy-disk', 'football-helmet', 'french-horn', 'fried-egg', 'frisbee', 'frog', 'frying-pan', 'galaxy', 'gas-pump', 'giraffe', 'goat', 'golden-gate-bridge', 'goldfish', 'golf-ball', 'goose', 'gorilla', 'grand-piano-101', 'grapes', 'grasshopper', 'guitar-pick', 'hamburger', 'hammock', 'harmonica', 'harp', 'harpsichord', 'hawksbill-101', 'head-phones', 'helicopter-101', 'hibiscus', 'homer-simpson', 'horse', 'horseshoe-crab', 'hot-air-balloon', 'hot-dog', 'hot-tub', 'hourglass', 'house-fly', 'human-skeleton', 'hummingbird', 'ibis-101', 'ice-cream-cone', 'iguana', 'ipod', 'iris', 'jesus-christ', 'joy-stick', 'kangaroo-101', 'kayak', 'ketch-101', 'killer-whale', 'knife', 'ladder', 'laptop-101', 'lathe', 'leopards-101', 'license-plate', 'lightbulb', 'light-house', 'lightning', 'llama-101', 'mailbox', 'mandolin', 'mars', 'mattress', 'megaphone', 'menorah-101', 'microscope', 'microwave', 'minaret', 'minotaur', 'motorbikes-101', 'mountain-bike', 'mushroom', 'mussels', 'necktie', 'octopus', 'ostrich', 'owl', 'palm-pilot', 'palm-tree', 'paperclip', 'paper-shredder', 'pci-card', 'penguin', 'people', 'pez-dispenser', 'photocopier', 'picnic-table', 'playing-card', 'porcupine', 'pram', 'praying-mantis', 'pyramid', 'raccoon', 'radio-telescope', 'rainbow', 'refrigerator', 'revolver-101', 'rifle', 'rotary-phone', 'roulette-wheel', 'saddle', 'saturn', 'school-bus', 'scorpion-101', 'screwdriver', 'segway', 'self-propelled-lawn-mower', 'sextant', 'sheet-music', 'skateboard', 'skunk', 'skyscraper', 'smokestack', 'snail', 'snake', 'sneaker', 'snowmobile', 'soccer-ball', 'socks', 'soda-can', 'spaghetti', 'speed-boat', 'spider', 'spoon', 'stained-glass', 'starfish-101', 'steering-wheel', 'stirrups', 'sunflower-101', 'superman', 'sushi', 'swan', 'swiss-army-knife', 'sword', 'syringe', 'tambourine', 'teapot', 'teddy-bear', 'teepee', 'telephone-box', 'tennis-ball', 'tennis-court', 'tennis-racket', 'theodolite', 'toaster', 'tomato', 'tombstone', 'top-hat', 'touring-bike', 'tower-pisa', 'traffic-light', 'treadmill', 'triceratops', 'tricycle', 'trilobite-101', 'tripod', 't-shirt', 'tuning-fork', 'tweezer', 'umbrella-101', 'unicorn', 'vcr', 'video-projector', 'washing-machine', 'watch-101', 'waterfall', 'watermelon', 'welding-mask', 'wheelbarrow', 'windmill', 'wine-bottle', 'xylophone', 'yarmulke', 'yo-yo', 'zebra', 'airplanes-101', 'car-side-101', 'faces-easy-101', 'greyhound', 'tennis-shoes', 'toad', 'clutter']
print("Result: label - " + object_categories[index] + ", probability - " + str(result[index]))

下記の通り、「バスタブである確率、89.7%」と表示されます。

Result: label - bathtub, probability - 0.897666871547699

試しに適当な画像を投入してみようと思います。 ブラウザ画面上から画像をアップロードした後に、

# 適当に画像をアップロードして試してみる

file_name = '/home/ec2-user/SageMaker/blog/image_classification_transfer_learning/imageclassification_caltech_2018-08-06/image_0001.jpg'# test image
from IPython.display import Image
Image(file_name)

この傘の画像を投入してみます。

import json
import numpy as np
with open(file_name, 'rb') as f:
    payload = f.read()
    payload = bytearray(payload)
response = runtime.invoke_endpoint(EndpointName=endpoint_name, 
                                   ContentType='application/x-image', 
                                   Body=payload)
result = response['Body'].read()
# result will be in json format and convert it to ndarray
result = json.loads(result)
# the result will output the probabilities for all classes
# find the class with maximum probability and print the class index
index = np.argmax(result)
print("Result: label - " + object_categories[index] + ", probability - " + str(result[index]))

ちゃんと傘だと認識しました。

Result: label - umbrella-101, probability - 0.9996694326400757

なお、今回学習したモデルで分類できるのは「object_categories」に列挙した物体だけ(256種)です。 (バスタブも傘もちゃんと「object_categories」に列挙されています。) 今回使ったデータセットでは、この256種類+「その他」といったくくりの全257種類の画像で学習をしていたのですが、推論のタイミングでは「その他」は利用しません。

とりあえず、256種類に含まれない画像を入れたらどんな感じかを確かめてみましょう。

file_name = '/home/ec2-user/SageMaker/blog/image_classification_transfer_learning/imageclassification_caltech_2018-08-06/image2.png'# test image
from IPython.display import Image
Image(file_name)

このアメフトボールで試してみます。

import json
import numpy as np
with open(file_name, 'rb') as f:
    payload = f.read()
    payload = bytearray(payload)
response = runtime.invoke_endpoint(EndpointName=endpoint_name, 
                                   ContentType='application/x-image', 
                                   Body=payload)
result = response['Body'].read()
# result will be in json format and convert it to ndarray
result = json.loads(result)
# the result will output the probabilities for all classes
# find the class with maximum probability and print the class index
index = np.argmax(result)
print("Result: label - " + object_categories[index] + ", probability - " + str(result[index]))
Result: label - computer-mouse, probability - 0.173785001039505

コンピュータのマウス、と推測されてしまいました。 ただ、確率は17.3%程度とのことなので実際に利用する場合には切り捨てることになりそうですね。 (実際に利用する場合は、APIを受け取る側で閾値を設定して異なる挙動を取らせることになるかと思います) 折角なので、一番確率が高いラベルだけでなく、「上から5ラベル」くらい表示させてみましょう。

%%time

'''original
上記では一番確率が高いクラスを返していたけど、トップ5くらいまで返してみよう。

'''

import json
import numpy as np
with open(file_name, 'rb') as f:
    payload = f.read()
    payload = bytearray(payload)
response = runtime.invoke_endpoint(EndpointName=endpoint_name, 
                                   ContentType='application/x-image', 
                                   Body=payload)
result = response['Body'].read()

# result will be in json format and convert it to ndarray
result = json.loads(result)
# the result will output the probabilities for all classes
# find the class with maximum probability and print the class index
indexs = np.argsort(result)[::-1][:5] # ここを修正

# top5を表示
for index in indexs:
    print("Result: label - " + object_categories[index] + ", probability - " + str(result[index]))
Result: label - computer-mouse, probability - 0.173785001039505
Result: label - tennis-ball, probability - 0.07102835923433304
Result: label - soccer-ball, probability - 0.05962461978197098
Result: label - yo-yo, probability - 0.05829168111085892
Result: label - frisbee, probability - 0.04761170595884323
CPU times: user 16 ms, sys: 0 ns, total: 16 ms
Wall time: 144 ms

コンピューターマウス、テニスボール、サッカーボール、ヨーヨー、フリスビー、という結果でした...。

楕円、もしくは丸いもの、ってことで似ていると判断しているのかもしれませんね...。

4.まとめ

転移学習を使うことで、1から学習せずに手早く優秀なモデルを利用することができます。 (美味い、安い、早い、牛丼みたいです) とはいえトレーニングやパラメータのチューニングが不要、という訳ではないのでご注意ください。

チューニングするべきパラメータは下記をご参照ください。 ハイパーパラメータ

「augmentation_type」がなんとなく便利そうでいいですね。

5.引用

ソースコード input,outputの仕様 ハイパーパラメータ caltech-256