Amazon Sage Makerで画像分類をしてみた(jpgから.lstファイルを生成) | Developers.IO

Amazon Sage Makerで画像分類をしてみた(jpgから.lstファイルを生成)

概要

こんにちは、yoshimです。 今回はSageMakerでビルトインアルゴリズムとして実装されている「Image-classification-lst-format」について、チュートリアルを実際にやってみます。
今回は「jpg」ファイルから「.lst」ファイルを生成し、学習を進めました。

目次

1.最初に

今回実施するチュートリアルは「画像分類」のアルゴリズムです。
「画像分類」とは、大雑把に言いますと「画像」を投入すると「その画像がなんの画像か」を推測するものです。

例えば、今回実行するチュートリアルですと、下記のような画像を投入すると

Result: label - bathtub, probability - 0.9949665665626526

こんな感じで「この画像に写っているのはバスタブだよ」、「確率は99.49..%だよ」といった結果を返してくれます。
このような学習を「自分のデータでやりたい」という方には、今回ご紹介する「jpg」から「.lst」ファイルを生成して学習する流れが、「とりあえずやってみる」という意味ではいいのではないでしょうか?
(AWSは学習に投入するデータフォーマットとして「.lst」よりも「RecordIO」フォーマットを推奨しています。)
(「image classfication」では、「.lst」を「RecordIO」に内部的に変換した後に学習を進めます)

2.「.lst」ファイルって何?

恥ずかしながら、私はこのチュートリアルに触れて初めて「.lst」ファイルというものを知りました。 そこで調べてみたところ、「.lst」ファイルとは下記のようなものだとわかりました。

・「.lst」ファイルは「タブ区切り」で「3つのカラム」からなる
・1つ目のカラムは「画像のインデックス」、2つ目のカラムは「画像のクラスのインデックス」、3つ目のカラムは「画像の相対パス」である

実際に画像データを保持するのではなく、あくまでも「どの画像がどのクラスか」といった情報を保持するファイルのようです。
この後チュートリアルを進めればわかるのですが、学習時には「.lst」ファイルと「jpg」の両方を投入します。

「.lst」ファイルの詳細については.lstファイルの説明をご参照ください。

3.実際にやってみた

今回試してみたチュートリアルは「Image-classification-lst-format.ipynb」というファイルです。  

ここにあるものと同じです。

基本的に上から順に実行していくだけで大丈夫です。
(S3のバケットだけ適宜修正してください)

3-1.事前準備

トレーニング、各データソースにアクセスする権限を取得します。
また、今回利用するS3バケットを指定します。

%%time
import boto3
from sagemaker import get_execution_role
from sagemaker.amazon.amazon_estimator import get_image_uri

role = get_execution_role()

bucket='hogehoge' # customize to your bucket

training_image = get_image_uri(boto3.Session().region_name, 'image-classification')
3-2.今回使うデータをダウンロード

今回のチュートリアルで利用するデータをダウンロードします。
「image-classification」では、INPUTデータとして「RecordIO (application/x-recordio)」、「イメージ (application/x-image)」のいずれかが利用可能ですが、今回は「application/x-image」のデータを使ってトレーニングを進めます。
(推論時には「イメージ (application/x-image)」のみがサポートされる点にご注意ください。)

input,outputの仕様

import os
import urllib.request

def download(url):
    filename = url.split("/")[-1]
    if not os.path.exists(filename):
        urllib.request.urlretrieve(url, filename)


# Caltech-256 image files
download('http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar')
!tar -xf 256_ObjectCategories.tar

# Tool for creating lst file
download('https://raw.githubusercontent.com/apache/incubator-mxnet/master/tools/im2rec.py')

ここでダウンロードしたtarファイルを解凍するとこんな状態です。
こんな感じで階層構造になってます。

/home/ec2-user/SageMaker/imageclassification_caltech_2018-08-08
|--256_ObjectCategories
| |--001.ak47
| | |--001_0001.jpg
| | |--001_0002.jpg
| | |--001_0003.jpg
| |--002.american-flag
| | |--002_0001.jpg
| | |--002_0002.jpg
| | |--002_0003.jpg

上記は下記のコマンドで確認しました。
(カレントディレクトリ配下のファイルの層構造を出力するものですので、適宜「cd」等で移動するか「pwd」部分を「cd」等で対象ディレクトリまで移動してください)

! pwd;find . | sort | sed '1d;s/^\.//;s/\/\([^/]*\)$/|--\1/;s/\/[^/|]*/|  /g'

また、ここでは画像データを「.lst」ファイルに変換するツールもダウンロードしています。

3-3.画像データの前処理

「jpg」ファイルから「.lst」ファイルを生成します。 p2.xlargeインスタンスで5分くらいで終わりました。

%%bash

mkdir -p caltech_256_train_60
for i in 256_ObjectCategories/*; do
    c=`basename $i`
    mkdir -p caltech_256_train_60/$c
    for j in `ls $i/*.jpg | shuf | head -n 60`; do
        mv $j caltech_256_train_60/$c/
    done
done

python im2rec.py --list --recursive caltech-256-60-train caltech_256_train_60/
python im2rec.py --list --recursive caltech-256-60-val 256_ObjectCategories/

「train」、「val」用の2つの「.lst」ファイルができました。

「.lst」ファイルの中身を確認してみましょう。

!head -n 3 ./caltech-256-60-train.lst > example.lst
f = open('example.lst','r')
lst_content = f.read()
print(lst_content)

「.lst」ファイルはこんな感じです。(実際に使うものではなく、exampleのものです)

14204 236.000000 237.vcr/237_0070.jpg
14861 247.000000 248.yarmulke/248_0062.jpg
10123 168.000000 169.radio-telescope/169_0061.jpg

「2.「.lst」ファイルって何?」では文字だけでしたが、実物を見ると理解が進みますね。

3-4.「jpg」,「.lst」ファイルをS3にアップロード

続いて、「画像ファイル」と「.lst」ファイルをS3にアップロードします。

# Four channels: train, validation, train_lst, and validation_lst
s3train = 's3://{}/train/'.format(bucket)
s3validation = 's3://{}/validation/'.format(bucket)
s3train_lst = 's3://{}/train_lst/'.format(bucket)
s3validation_lst = 's3://{}/validation_lst/'.format(bucket)

# upload the image files to train and validation channels
!aws s3 cp caltech_256_train_60 $s3train --recursive --quiet
!aws s3 cp 256_ObjectCategories $s3validation --recursive --quiet

# upload the lst files to train_lst and validation_lst channels
!aws s3 cp caltech-256-60-train.lst $s3train_lst --quiet
!aws s3 cp caltech-256-60-val.lst $s3validation_lst --quiet

余談ですが、チュートリアル中に下記のようなコメントがありました。

Now we have all the data stored in S3 bucket. The image and lst files will be converted to RecordIO file internelly by the image classification algorithm.

どうやら、「.lst」ファイルは「image classification」アルゴリズムを使っている際は内部的に「RecordIO」に変換されるみたいです。

3-5.トレーニング用のパラメータを指定

トレーニング用のハイパーパラメータを指定します。
ハイパーパラメータの一覧はこちらをご参照ください。

# The algorithm supports multiple network depth (number of layers). They are 18, 34, 50, 101, 152 and 200
# For this training, we will use 18 layers
num_layers = 18
# we need to specify the input image shape for the training data
image_shape = "3,224,224"
# we also need to specify the number of training samples in the training set
num_training_samples = 15240
# specify the number of output classes
num_classes = 257
# batch size for training
mini_batch_size = 128
# number of epochs
epochs = 6
# learning rate
learning_rate = 0.01
# report top_5 accuracy
top_k = 5
# resize image before training
resize = 256
# period to store model parameters (in number of epochs), in this case, we will save parameters from epoch 2, 4, and 6
checkpoint_frequency = 2
# Since we are using transfer learning, we set use_pretrained_model to 1 so that weights can be 
# initialized with pre-trained weights
use_pretrained_model = 1

今回はチュートリアルなので、層の深さ、エポック数等が小さめに設定されています。
SageMakerではログから「検証データセットでの精度推移」が確認できるので、とりあえずエポック数を増やして推移を見てみるのもいいかもしれません。

3-6.トレーニング用のパラメータを設定

「3-5.トレーニング用のパラメータを指定」で指定したハイパーパラメータ、学習に使うデータを格納しているS3パス、データフォーマット等を指定します。

「InputDataConfig」の「ContentType」が「application/x-image」なのがミソですね。

%%time
import time
import boto3
from time import gmtime, strftime


s3 = boto3.client('s3')
# create unique job name 
job_name_prefix = 'sagemaker-imageclassification-notebook'
timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())
job_name = job_name_prefix + timestamp
training_params = \
{
    # specify the training docker image
    "AlgorithmSpecification": {
        "TrainingImage": training_image,
        "TrainingInputMode": "File"
    },
    "RoleArn": role,
    "OutputDataConfig": {
        "S3OutputPath": 's3://{}/{}/output'.format(bucket, job_name_prefix)
    },
    "ResourceConfig": {
        "InstanceCount": 1,
        "InstanceType": "ml.p2.xlarge",
        "VolumeSizeInGB": 50
    },
    "TrainingJobName": job_name,
    "HyperParameters": {
        "image_shape": image_shape,
        "num_layers": str(num_layers),
        "num_training_samples": str(num_training_samples),
        "num_classes": str(num_classes),
        "mini_batch_size": str(mini_batch_size),
        "epochs": str(epochs),
        "learning_rate": str(learning_rate),
        "top_k": str(top_k),
        "resize": str(resize),
        "checkpoint_frequency": str(checkpoint_frequency),
        "use_pretrained_model": str(use_pretrained_model)    
    },
    "StoppingCondition": {
        "MaxRuntimeInSeconds": 360000
    },
#Training data should be inside a subdirectory called "train"
#Validation data should be inside a subdirectory called "validation"
#The algorithm currently only supports fullyreplicated model (where data is copied onto each machine)
    "InputDataConfig": [
        {
            "ChannelName": "train",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": 's3://{}/train/'.format(bucket),
                    "S3DataDistributionType": "FullyReplicated"
                }
            },
            "ContentType": "application/x-image",
            "CompressionType": "None"
        },
        {
            "ChannelName": "validation",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": 's3://{}/validation/'.format(bucket),
                    "S3DataDistributionType": "FullyReplicated"
                }
            },
            "ContentType": "application/x-image",
            "CompressionType": "None"
        },
        {
            "ChannelName": "train_lst",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": 's3://{}/train_lst/'.format(bucket),
                    "S3DataDistributionType": "FullyReplicated"
                }
            },
            "ContentType": "application/x-image",
            "CompressionType": "None"
        },
        {
            "ChannelName": "validation_lst",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": 's3://{}/validation_lst/'.format(bucket),
                    "S3DataDistributionType": "FullyReplicated"
                }
            },
            "ContentType": "application/x-image",
            "CompressionType": "None"
        }
    ]
}
print('Training job name: {}'.format(job_name))
print('\nInput Data Location: {}'.format(training_params['InputDataConfig'][0]['DataSource']['S3DataSource']))
3-7.トレーニングの開始

上記で指定した設定でトレーニングを開始します。
私の場合は、p2.xlargeで23分ほどかかりました。

# create the Amazon SageMaker training job
sagemaker = boto3.client(service_name='sagemaker')
sagemaker.create_training_job(**training_params)

# confirm that the training job has started
status = sagemaker.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']
print('Training job current status: {}'.format(status))

try:
    # wait for the job to finish and report the ending status
    sagemaker.get_waiter('training_job_completed_or_stopped').wait(TrainingJobName=job_name)
    training_info = sagemaker.describe_training_job(TrainingJobName=job_name)
    status = training_info['TrainingJobStatus']
    print("Training job ended with status: " + status)
except:
    print('Training failed to start')
     # if exception is raised, that means it has failed
    message = sagemaker.describe_training_job(TrainingJobName=job_name)['FailureReason']
    print('Training failed with the following error: {}'.format(message))

学習が終わったら、推移をコンソール画面の「ログ」から確認しましょう。
トレーニングデータでの精度推移や検証用データでの精度が確認できるので面白いです。

ログ画面を開き

ここをクリック

対象JOBのログが出てくるのでクリック

検証データ、エポックごとの精度が確認できました!

モニョモニョ精度が上がっていく推移が確認できて面白いですね。

3-8.エンドポイントに使うモデルやロールの設定

エンドポイントに利用するモデル(今回の場合は先ほどS3に出力したモデル)や、ロールを指定します。

%%time
import boto3
from time import gmtime, strftime

sage = boto3.Session().client(service_name='sagemaker') 

timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())
model_name="image-classification-model" + timestamp
print(model_name)
info = sage.describe_training_job(TrainingJobName=job_name)
model_data = info['ModelArtifacts']['S3ModelArtifacts']
print(model_data)

hosting_image = get_image_uri(boto3.Session().region_name, 'image-classification')

primary_container = {
    'Image': hosting_image,
    'ModelDataUrl': model_data,
}

create_model_response = sage.create_model(
    ModelName = model_name,
    ExecutionRoleArn = role,
    PrimaryContainer = primary_container)

print(create_model_response['ModelArn'])
3-9.エンドポイントの設定

立ち上げるエンドポイントの設定を指定します。
具体的には、インスタンスタイプやインスタンスの台数を指定します。
画像分類では、推論用エンドポイントもGPUインスタンスにしないといけません。

from time import gmtime, strftime

timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())
endpoint_config_name = job_name_prefix + '-epc-' + timestamp
endpoint_config_response = sage.create_endpoint_config(
    EndpointConfigName = endpoint_config_name,
    ProductionVariants=[{
        'InstanceType':'ml.p2.xlarge',
        'InitialInstanceCount':1,
        'ModelName':model_name,
        'VariantName':'AllTraffic'}])

print('Endpoint configuration name: {}'.format(endpoint_config_name))
print('Endpoint configuration arn:  {}'.format(endpoint_config_response['EndpointConfigArn']))
3-10.エンドポイントの作成

「3-8.エンドポイントに使うモデルやロールの設定」、「3-9.エンドポイントの設定」で指定した設定に基づいて、推論用エンドポイントを作成します。

%%time
import time

timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())
endpoint_name = job_name_prefix + '-ep-' + timestamp
print('Endpoint name: {}'.format(endpoint_name))

endpoint_params = {
    'EndpointName': endpoint_name,
    'EndpointConfigName': endpoint_config_name,
}
endpoint_response = sagemaker.create_endpoint(**endpoint_params)
print('EndpointArn = {}'.format(endpoint_response['EndpointArn']))

JOB自体はすぐに完了しますが、エンドポイントが立ち上がるまでは少し時間がかかります。コンソール画面から確認するか、下記を実行してステータスを確認しましょう。

# get the status of the endpoint
response = sagemaker.describe_endpoint(EndpointName=endpoint_name)
status = response['EndpointStatus']
print('EndpointStatus = {}'.format(status))
    
try:
    sagemaker.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name)
finally:
    resp = sagemaker.describe_endpoint(EndpointName=endpoint_name)
    status = resp['EndpointStatus']
    print("Arn: " + resp['EndpointArn'])
    print("Create endpoint ended with status: " + status)

    if status != 'InService':
        message = sagemaker.describe_endpoint(EndpointName=endpoint_name)['FailureReason']
        print('Training failed with the following error: {}'.format(message))
        raise Exception('Endpoint creation did not succeed')
3-11.実際に推論してみる

作成したエンドポイントに実際に画像を投げてみて、どのように分類されるかを確認してみましょう。

セッションをはって

import boto3
runtime = boto3.Session().client(service_name='runtime.sagemaker') 

画像をダウンロードしてみましょう。

!wget -O /tmp/test.jpg http://www.vision.caltech.edu/Image_Datasets/Caltech256/images/008.bathtub/008_0007.jpg
file_name = '/tmp/test.jpg'
# test image
from IPython.display import Image
Image(file_name)  

こんな画像が表示されます。

この画像をエンドポイントに投げてみると、

import json
import numpy as np
with open(file_name, 'rb') as f:
    payload = f.read()
    payload = bytearray(payload)
response = runtime.invoke_endpoint(EndpointName=endpoint_name, 
                                   ContentType='application/x-image', 
                                   Body=payload)
result = response['Body'].read()
# result will be in json format and convert it to ndarray
result = json.loads(result)
# the result will output the probabilities for all classes
# find the class with maximum probability and print the class index
index = np.argmax(result)
object_categories = ['ak47', 'american-flag', 'backpack', 'baseball-bat', 'baseball-glove', 'basketball-hoop', 'bat', 'bathtub', 'bear', 'beer-mug', 'billiards', 'binoculars', 'birdbath', 'blimp', 'bonsai-101', 'boom-box', 'bowling-ball', 'bowling-pin', 'boxing-glove', 'brain-101', 'breadmaker', 'buddha-101', 'bulldozer', 'butterfly', 'cactus', 'cake', 'calculator', 'camel', 'cannon', 'canoe', 'car-tire', 'cartman', 'cd', 'centipede', 'cereal-box', 'chandelier-101', 'chess-board', 'chimp', 'chopsticks', 'cockroach', 'coffee-mug', 'coffin', 'coin', 'comet', 'computer-keyboard', 'computer-monitor', 'computer-mouse', 'conch', 'cormorant', 'covered-wagon', 'cowboy-hat', 'crab-101', 'desk-globe', 'diamond-ring', 'dice', 'dog', 'dolphin-101', 'doorknob', 'drinking-straw', 'duck', 'dumb-bell', 'eiffel-tower', 'electric-guitar-101', 'elephant-101', 'elk', 'ewer-101', 'eyeglasses', 'fern', 'fighter-jet', 'fire-extinguisher', 'fire-hydrant', 'fire-truck', 'fireworks', 'flashlight', 'floppy-disk', 'football-helmet', 'french-horn', 'fried-egg', 'frisbee', 'frog', 'frying-pan', 'galaxy', 'gas-pump', 'giraffe', 'goat', 'golden-gate-bridge', 'goldfish', 'golf-ball', 'goose', 'gorilla', 'grand-piano-101', 'grapes', 'grasshopper', 'guitar-pick', 'hamburger', 'hammock', 'harmonica', 'harp', 'harpsichord', 'hawksbill-101', 'head-phones', 'helicopter-101', 'hibiscus', 'homer-simpson', 'horse', 'horseshoe-crab', 'hot-air-balloon', 'hot-dog', 'hot-tub', 'hourglass', 'house-fly', 'human-skeleton', 'hummingbird', 'ibis-101', 'ice-cream-cone', 'iguana', 'ipod', 'iris', 'jesus-christ', 'joy-stick', 'kangaroo-101', 'kayak', 'ketch-101', 'killer-whale', 'knife', 'ladder', 'laptop-101', 'lathe', 'leopards-101', 'license-plate', 'lightbulb', 'light-house', 'lightning', 'llama-101', 'mailbox', 'mandolin', 'mars', 'mattress', 'megaphone', 'menorah-101', 'microscope', 'microwave', 'minaret', 'minotaur', 'motorbikes-101', 'mountain-bike', 'mushroom', 'mussels', 'necktie', 'octopus', 'ostrich', 'owl', 'palm-pilot', 'palm-tree', 'paperclip', 'paper-shredder', 'pci-card', 'penguin', 'people', 'pez-dispenser', 'photocopier', 'picnic-table', 'playing-card', 'porcupine', 'pram', 'praying-mantis', 'pyramid', 'raccoon', 'radio-telescope', 'rainbow', 'refrigerator', 'revolver-101', 'rifle', 'rotary-phone', 'roulette-wheel', 'saddle', 'saturn', 'school-bus', 'scorpion-101', 'screwdriver', 'segway', 'self-propelled-lawn-mower', 'sextant', 'sheet-music', 'skateboard', 'skunk', 'skyscraper', 'smokestack', 'snail', 'snake', 'sneaker', 'snowmobile', 'soccer-ball', 'socks', 'soda-can', 'spaghetti', 'speed-boat', 'spider', 'spoon', 'stained-glass', 'starfish-101', 'steering-wheel', 'stirrups', 'sunflower-101', 'superman', 'sushi', 'swan', 'swiss-army-knife', 'sword', 'syringe', 'tambourine', 'teapot', 'teddy-bear', 'teepee', 'telephone-box', 'tennis-ball', 'tennis-court', 'tennis-racket', 'theodolite', 'toaster', 'tomato', 'tombstone', 'top-hat', 'touring-bike', 'tower-pisa', 'traffic-light', 'treadmill', 'triceratops', 'tricycle', 'trilobite-101', 'tripod', 't-shirt', 'tuning-fork', 'tweezer', 'umbrella-101', 'unicorn', 'vcr', 'video-projector', 'washing-machine', 'watch-101', 'waterfall', 'watermelon', 'welding-mask', 'wheelbarrow', 'windmill', 'wine-bottle', 'xylophone', 'yarmulke', 'yo-yo', 'zebra', 'airplanes-101', 'car-side-101', 'faces-easy-101', 'greyhound', 'tennis-shoes', 'toad', 'clutter']
print("Result: label - " + object_categories[index] + ", probability - " + str(result[index]))

下記の通り、「バスタブである確率、99.49...%」と表示されます。

Result: label - bathtub, probability - 0.9949665665626526

いい感じですね。

3-12.推論用エンドポイントを削除

最後に、推論用エンドポイントを削除します。
画像分類で使う推論用エンドポイントはGPUなので、消し忘れると結構なコストになってしまいますから注意が必要ですね。

sage.delete_endpoint(EndpointName=endpoint_name)

4.まとめ

今回は、「SageMaker」のチュートリアルを通して「jpg」から「.lst」ファイルを作成して「画像分類」をやってみました。
手元にあるデータでとりあえず「画像分類をやってみたい」という場合は、今回ご紹介したチュートリアルを参考にやってみてもいいのではないでしょうか?

5.引用

.lstファイルの説明
caltech-256
ソースコード
input,outputの仕様
ハイパーパラメータ