Amazon Sage Makerで画像分類をしてみた(jpgから.lstファイルを生成)
概要
こんにちは、yoshimです。 今回はSageMakerでビルトインアルゴリズムとして実装されている「Image-classification-lst-format」について、チュートリアルを実際にやってみます。 今回は「jpg」ファイルから「.lst」ファイルを生成し、学習を進めました。
目次
1.最初に
今回実施するチュートリアルは「画像分類」のアルゴリズムです。 「画像分類」とは、大雑把に言いますと「画像」を投入すると「その画像がなんの画像か」を推測するものです。
例えば、今回実行するチュートリアルですと、下記のような画像を投入すると
Result: label - bathtub, probability - 0.9949665665626526
こんな感じで「この画像に写っているのはバスタブだよ」、「確率は99.49..%だよ」といった結果を返してくれます。 このような学習を「自分のデータでやりたい」という方には、今回ご紹介する「jpg」から「.lst」ファイルを生成して学習する流れが、「とりあえずやってみる」という意味ではいいのではないでしょうか? (AWSは学習に投入するデータフォーマットとして「.lst」よりも「RecordIO」フォーマットを推奨しています。) (「image classfication」では、「.lst」を「RecordIO」に内部的に変換した後に学習を進めます)
2.「.lst」ファイルって何?
恥ずかしながら、私はこのチュートリアルに触れて初めて「.lst」ファイルというものを知りました。 そこで調べてみたところ、「.lst」ファイルとは下記のようなものだとわかりました。
・「.lst」ファイルは「タブ区切り」で「3つのカラム」からなる ・1つ目のカラムは「画像のインデックス」、2つ目のカラムは「画像のクラスのインデックス」、3つ目のカラムは「画像の相対パス」である
実際に画像データを保持するのではなく、あくまでも「どの画像がどのクラスか」といった情報を保持するファイルのようです。 この後チュートリアルを進めればわかるのですが、学習時には「.lst」ファイルと「jpg」の両方を投入します。
「.lst」ファイルの詳細については.lstファイルの説明をご参照ください。
3.実際にやってみた
今回試してみたチュートリアルは「Image-classification-lst-format.ipynb」というファイルです。
ここにあるものと同じです。
基本的に上から順に実行していくだけで大丈夫です。 (S3のバケットだけ適宜修正してください)
3-1.事前準備
トレーニング、各データソースにアクセスする権限を取得します。 また、今回利用するS3バケットを指定します。
%%time import boto3 from sagemaker import get_execution_role from sagemaker.amazon.amazon_estimator import get_image_uri role = get_execution_role() bucket='hogehoge' # customize to your bucket training_image = get_image_uri(boto3.Session().region_name, 'image-classification')
3-2.今回使うデータをダウンロード
今回のチュートリアルで利用するデータをダウンロードします。 「image-classification」では、INPUTデータとして「RecordIO (application/x-recordio)」、「イメージ (application/x-image)」のいずれかが利用可能ですが、今回は「application/x-image」のデータを使ってトレーニングを進めます。 (推論時には「イメージ (application/x-image)」のみがサポートされる点にご注意ください。)
import os import urllib.request def download(url): filename = url.split("/")[-1] if not os.path.exists(filename): urllib.request.urlretrieve(url, filename) # Caltech-256 image files download('http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar') !tar -xf 256_ObjectCategories.tar # Tool for creating lst file download('https://raw.githubusercontent.com/apache/incubator-mxnet/master/tools/im2rec.py')
ここでダウンロードしたtarファイルを解凍するとこんな状態です。 こんな感じで階層構造になってます。
/home/ec2-user/SageMaker/imageclassification_caltech_2018-08-08
|--256_ObjectCategories
| |--001.ak47
| | |--001_0001.jpg
| | |--001_0002.jpg
| | |--001_0003.jpg
| |--002.american-flag
| | |--002_0001.jpg
| | |--002_0002.jpg
| | |--002_0003.jpg
上記は下記のコマンドで確認しました。 (カレントディレクトリ配下のファイルの層構造を出力するものですので、適宜「cd」等で移動するか「pwd」部分を「cd」等で対象ディレクトリまで移動してください)
! pwd;find . | sort | sed '1d;s/^\.//;s/\/\([^/]*\)$/|--\1/;s/\/[^/|]*/| /g'
また、ここでは画像データを「.lst」ファイルに変換するツールもダウンロードしています。
3-3.画像データの前処理
「jpg」ファイルから「.lst」ファイルを生成します。 p2.xlargeインスタンスで5分くらいで終わりました。
%%bash mkdir -p caltech_256_train_60 for i in 256_ObjectCategories/*; do c=`basename $i` mkdir -p caltech_256_train_60/$c for j in `ls $i/*.jpg | shuf | head -n 60`; do mv $j caltech_256_train_60/$c/ done done python im2rec.py --list --recursive caltech-256-60-train caltech_256_train_60/ python im2rec.py --list --recursive caltech-256-60-val 256_ObjectCategories/
「train」、「val」用の2つの「.lst」ファイルができました。
「.lst」ファイルの中身を確認してみましょう。
!head -n 3 ./caltech-256-60-train.lst > example.lst f = open('example.lst','r') lst_content = f.read() print(lst_content)
「.lst」ファイルはこんな感じです。(実際に使うものではなく、exampleのものです)
14204 236.000000 237.vcr/237_0070.jpg
14861 247.000000 248.yarmulke/248_0062.jpg
10123 168.000000 169.radio-telescope/169_0061.jpg
「2.「.lst」ファイルって何?」では文字だけでしたが、実物を見ると理解が進みますね。
3-4.「jpg」,「.lst」ファイルをS3にアップロード
続いて、「画像ファイル」と「.lst」ファイルをS3にアップロードします。
# Four channels: train, validation, train_lst, and validation_lst s3train = 's3://{}/train/'.format(bucket) s3validation = 's3://{}/validation/'.format(bucket) s3train_lst = 's3://{}/train_lst/'.format(bucket) s3validation_lst = 's3://{}/validation_lst/'.format(bucket) # upload the image files to train and validation channels !aws s3 cp caltech_256_train_60 $s3train --recursive --quiet !aws s3 cp 256_ObjectCategories $s3validation --recursive --quiet # upload the lst files to train_lst and validation_lst channels !aws s3 cp caltech-256-60-train.lst $s3train_lst --quiet !aws s3 cp caltech-256-60-val.lst $s3validation_lst --quiet
余談ですが、チュートリアル中に下記のようなコメントがありました。
Now we have all the data stored in S3 bucket. The image and lst files will be converted to RecordIO file internelly by the image classification algorithm.
どうやら、「.lst」ファイルは「image classification」アルゴリズムを使っている際は内部的に「RecordIO」に変換されるみたいです。
3-5.トレーニング用のパラメータを指定
トレーニング用のハイパーパラメータを指定します。 ハイパーパラメータの一覧はこちらをご参照ください。
# The algorithm supports multiple network depth (number of layers). They are 18, 34, 50, 101, 152 and 200 # For this training, we will use 18 layers num_layers = 18 # we need to specify the input image shape for the training data image_shape = "3,224,224" # we also need to specify the number of training samples in the training set num_training_samples = 15240 # specify the number of output classes num_classes = 257 # batch size for training mini_batch_size = 128 # number of epochs epochs = 6 # learning rate learning_rate = 0.01 # report top_5 accuracy top_k = 5 # resize image before training resize = 256 # period to store model parameters (in number of epochs), in this case, we will save parameters from epoch 2, 4, and 6 checkpoint_frequency = 2 # Since we are using transfer learning, we set use_pretrained_model to 1 so that weights can be # initialized with pre-trained weights use_pretrained_model = 1
今回はチュートリアルなので、層の深さ、エポック数等が小さめに設定されています。 SageMakerではログから「検証データセットでの精度推移」が確認できるので、とりあえずエポック数を増やして推移を見てみるのもいいかもしれません。
3-6.トレーニング用のパラメータを設定
「3-5.トレーニング用のパラメータを指定」で指定したハイパーパラメータ、学習に使うデータを格納しているS3パス、データフォーマット等を指定します。
「InputDataConfig」の「ContentType」が「application/x-image」なのがミソですね。
%%time import time import boto3 from time import gmtime, strftime s3 = boto3.client('s3') # create unique job name job_name_prefix = 'sagemaker-imageclassification-notebook' timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime()) job_name = job_name_prefix + timestamp training_params = \ { # specify the training docker image "AlgorithmSpecification": { "TrainingImage": training_image, "TrainingInputMode": "File" }, "RoleArn": role, "OutputDataConfig": { "S3OutputPath": 's3://{}/{}/output'.format(bucket, job_name_prefix) }, "ResourceConfig": { "InstanceCount": 1, "InstanceType": "ml.p2.xlarge", "VolumeSizeInGB": 50 }, "TrainingJobName": job_name, "HyperParameters": { "image_shape": image_shape, "num_layers": str(num_layers), "num_training_samples": str(num_training_samples), "num_classes": str(num_classes), "mini_batch_size": str(mini_batch_size), "epochs": str(epochs), "learning_rate": str(learning_rate), "top_k": str(top_k), "resize": str(resize), "checkpoint_frequency": str(checkpoint_frequency), "use_pretrained_model": str(use_pretrained_model) }, "StoppingCondition": { "MaxRuntimeInSeconds": 360000 }, #Training data should be inside a subdirectory called "train" #Validation data should be inside a subdirectory called "validation" #The algorithm currently only supports fullyreplicated model (where data is copied onto each machine) "InputDataConfig": [ { "ChannelName": "train", "DataSource": { "S3DataSource": { "S3DataType": "S3Prefix", "S3Uri": 's3://{}/train/'.format(bucket), "S3DataDistributionType": "FullyReplicated" } }, "ContentType": "application/x-image", "CompressionType": "None" }, { "ChannelName": "validation", "DataSource": { "S3DataSource": { "S3DataType": "S3Prefix", "S3Uri": 's3://{}/validation/'.format(bucket), "S3DataDistributionType": "FullyReplicated" } }, "ContentType": "application/x-image", "CompressionType": "None" }, { "ChannelName": "train_lst", "DataSource": { "S3DataSource": { "S3DataType": "S3Prefix", "S3Uri": 's3://{}/train_lst/'.format(bucket), "S3DataDistributionType": "FullyReplicated" } }, "ContentType": "application/x-image", "CompressionType": "None" }, { "ChannelName": "validation_lst", "DataSource": { "S3DataSource": { "S3DataType": "S3Prefix", "S3Uri": 's3://{}/validation_lst/'.format(bucket), "S3DataDistributionType": "FullyReplicated" } }, "ContentType": "application/x-image", "CompressionType": "None" } ] } print('Training job name: {}'.format(job_name)) print('\nInput Data Location: {}'.format(training_params['InputDataConfig'][0]['DataSource']['S3DataSource']))
3-7.トレーニングの開始
上記で指定した設定でトレーニングを開始します。 私の場合は、p2.xlargeで23分ほどかかりました。
# create the Amazon SageMaker training job sagemaker = boto3.client(service_name='sagemaker') sagemaker.create_training_job(**training_params) # confirm that the training job has started status = sagemaker.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus'] print('Training job current status: {}'.format(status)) try: # wait for the job to finish and report the ending status sagemaker.get_waiter('training_job_completed_or_stopped').wait(TrainingJobName=job_name) training_info = sagemaker.describe_training_job(TrainingJobName=job_name) status = training_info['TrainingJobStatus'] print("Training job ended with status: " + status) except: print('Training failed to start') # if exception is raised, that means it has failed message = sagemaker.describe_training_job(TrainingJobName=job_name)['FailureReason'] print('Training failed with the following error: {}'.format(message))
学習が終わったら、推移をコンソール画面の「ログ」から確認しましょう。 トレーニングデータでの精度推移や検証用データでの精度が確認できるので面白いです。
ログ画面を開き
ここをクリック
対象JOBのログが出てくるのでクリック
検証データ、エポックごとの精度が確認できました!
モニョモニョ精度が上がっていく推移が確認できて面白いですね。
3-8.エンドポイントに使うモデルやロールの設定
エンドポイントに利用するモデル(今回の場合は先ほどS3に出力したモデル)や、ロールを指定します。
%%time import boto3 from time import gmtime, strftime sage = boto3.Session().client(service_name='sagemaker') timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime()) model_name="image-classification-model" + timestamp print(model_name) info = sage.describe_training_job(TrainingJobName=job_name) model_data = info['ModelArtifacts']['S3ModelArtifacts'] print(model_data) hosting_image = get_image_uri(boto3.Session().region_name, 'image-classification') primary_container = { 'Image': hosting_image, 'ModelDataUrl': model_data, } create_model_response = sage.create_model( ModelName = model_name, ExecutionRoleArn = role, PrimaryContainer = primary_container) print(create_model_response['ModelArn'])
3-9.エンドポイントの設定
立ち上げるエンドポイントの設定を指定します。 具体的には、インスタンスタイプやインスタンスの台数を指定します。 画像分類では、推論用エンドポイントもGPUインスタンスにしないといけません。
from time import gmtime, strftime timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime()) endpoint_config_name = job_name_prefix + '-epc-' + timestamp endpoint_config_response = sage.create_endpoint_config( EndpointConfigName = endpoint_config_name, ProductionVariants=[{ 'InstanceType':'ml.p2.xlarge', 'InitialInstanceCount':1, 'ModelName':model_name, 'VariantName':'AllTraffic'}]) print('Endpoint configuration name: {}'.format(endpoint_config_name)) print('Endpoint configuration arn: {}'.format(endpoint_config_response['EndpointConfigArn']))
3-10.エンドポイントの作成
「3-8.エンドポイントに使うモデルやロールの設定」、「3-9.エンドポイントの設定」で指定した設定に基づいて、推論用エンドポイントを作成します。
%%time import time timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime()) endpoint_name = job_name_prefix + '-ep-' + timestamp print('Endpoint name: {}'.format(endpoint_name)) endpoint_params = { 'EndpointName': endpoint_name, 'EndpointConfigName': endpoint_config_name, } endpoint_response = sagemaker.create_endpoint(**endpoint_params) print('EndpointArn = {}'.format(endpoint_response['EndpointArn']))
JOB自体はすぐに完了しますが、エンドポイントが立ち上がるまでは少し時間がかかります。コンソール画面から確認するか、下記を実行してステータスを確認しましょう。
# get the status of the endpoint response = sagemaker.describe_endpoint(EndpointName=endpoint_name) status = response['EndpointStatus'] print('EndpointStatus = {}'.format(status)) try: sagemaker.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name) finally: resp = sagemaker.describe_endpoint(EndpointName=endpoint_name) status = resp['EndpointStatus'] print("Arn: " + resp['EndpointArn']) print("Create endpoint ended with status: " + status) if status != 'InService': message = sagemaker.describe_endpoint(EndpointName=endpoint_name)['FailureReason'] print('Training failed with the following error: {}'.format(message)) raise Exception('Endpoint creation did not succeed')
3-11.実際に推論してみる
作成したエンドポイントに実際に画像を投げてみて、どのように分類されるかを確認してみましょう。
セッションをはって
import boto3 runtime = boto3.Session().client(service_name='runtime.sagemaker')
画像をダウンロードしてみましょう。
!wget -O /tmp/test.jpg http://www.vision.caltech.edu/Image_Datasets/Caltech256/images/008.bathtub/008_0007.jpg file_name = '/tmp/test.jpg' # test image from IPython.display import Image Image(file_name)
こんな画像が表示されます。
この画像をエンドポイントに投げてみると、
import json import numpy as np with open(file_name, 'rb') as f: payload = f.read() payload = bytearray(payload) response = runtime.invoke_endpoint(EndpointName=endpoint_name, ContentType='application/x-image', Body=payload) result = response['Body'].read() # result will be in json format and convert it to ndarray result = json.loads(result) # the result will output the probabilities for all classes # find the class with maximum probability and print the class index index = np.argmax(result) object_categories = ['ak47', 'american-flag', 'backpack', 'baseball-bat', 'baseball-glove', 'basketball-hoop', 'bat', 'bathtub', 'bear', 'beer-mug', 'billiards', 'binoculars', 'birdbath', 'blimp', 'bonsai-101', 'boom-box', 'bowling-ball', 'bowling-pin', 'boxing-glove', 'brain-101', 'breadmaker', 'buddha-101', 'bulldozer', 'butterfly', 'cactus', 'cake', 'calculator', 'camel', 'cannon', 'canoe', 'car-tire', 'cartman', 'cd', 'centipede', 'cereal-box', 'chandelier-101', 'chess-board', 'chimp', 'chopsticks', 'cockroach', 'coffee-mug', 'coffin', 'coin', 'comet', 'computer-keyboard', 'computer-monitor', 'computer-mouse', 'conch', 'cormorant', 'covered-wagon', 'cowboy-hat', 'crab-101', 'desk-globe', 'diamond-ring', 'dice', 'dog', 'dolphin-101', 'doorknob', 'drinking-straw', 'duck', 'dumb-bell', 'eiffel-tower', 'electric-guitar-101', 'elephant-101', 'elk', 'ewer-101', 'eyeglasses', 'fern', 'fighter-jet', 'fire-extinguisher', 'fire-hydrant', 'fire-truck', 'fireworks', 'flashlight', 'floppy-disk', 'football-helmet', 'french-horn', 'fried-egg', 'frisbee', 'frog', 'frying-pan', 'galaxy', 'gas-pump', 'giraffe', 'goat', 'golden-gate-bridge', 'goldfish', 'golf-ball', 'goose', 'gorilla', 'grand-piano-101', 'grapes', 'grasshopper', 'guitar-pick', 'hamburger', 'hammock', 'harmonica', 'harp', 'harpsichord', 'hawksbill-101', 'head-phones', 'helicopter-101', 'hibiscus', 'homer-simpson', 'horse', 'horseshoe-crab', 'hot-air-balloon', 'hot-dog', 'hot-tub', 'hourglass', 'house-fly', 'human-skeleton', 'hummingbird', 'ibis-101', 'ice-cream-cone', 'iguana', 'ipod', 'iris', 'jesus-christ', 'joy-stick', 'kangaroo-101', 'kayak', 'ketch-101', 'killer-whale', 'knife', 'ladder', 'laptop-101', 'lathe', 'leopards-101', 'license-plate', 'lightbulb', 'light-house', 'lightning', 'llama-101', 'mailbox', 'mandolin', 'mars', 'mattress', 'megaphone', 'menorah-101', 'microscope', 'microwave', 'minaret', 'minotaur', 'motorbikes-101', 'mountain-bike', 'mushroom', 'mussels', 'necktie', 'octopus', 'ostrich', 'owl', 'palm-pilot', 'palm-tree', 'paperclip', 'paper-shredder', 'pci-card', 'penguin', 'people', 'pez-dispenser', 'photocopier', 'picnic-table', 'playing-card', 'porcupine', 'pram', 'praying-mantis', 'pyramid', 'raccoon', 'radio-telescope', 'rainbow', 'refrigerator', 'revolver-101', 'rifle', 'rotary-phone', 'roulette-wheel', 'saddle', 'saturn', 'school-bus', 'scorpion-101', 'screwdriver', 'segway', 'self-propelled-lawn-mower', 'sextant', 'sheet-music', 'skateboard', 'skunk', 'skyscraper', 'smokestack', 'snail', 'snake', 'sneaker', 'snowmobile', 'soccer-ball', 'socks', 'soda-can', 'spaghetti', 'speed-boat', 'spider', 'spoon', 'stained-glass', 'starfish-101', 'steering-wheel', 'stirrups', 'sunflower-101', 'superman', 'sushi', 'swan', 'swiss-army-knife', 'sword', 'syringe', 'tambourine', 'teapot', 'teddy-bear', 'teepee', 'telephone-box', 'tennis-ball', 'tennis-court', 'tennis-racket', 'theodolite', 'toaster', 'tomato', 'tombstone', 'top-hat', 'touring-bike', 'tower-pisa', 'traffic-light', 'treadmill', 'triceratops', 'tricycle', 'trilobite-101', 'tripod', 't-shirt', 'tuning-fork', 'tweezer', 'umbrella-101', 'unicorn', 'vcr', 'video-projector', 'washing-machine', 'watch-101', 'waterfall', 'watermelon', 'welding-mask', 'wheelbarrow', 'windmill', 'wine-bottle', 'xylophone', 'yarmulke', 'yo-yo', 'zebra', 'airplanes-101', 'car-side-101', 'faces-easy-101', 'greyhound', 'tennis-shoes', 'toad', 'clutter'] print("Result: label - " + object_categories[index] + ", probability - " + str(result[index]))
下記の通り、「バスタブである確率、99.49...%」と表示されます。
Result: label - bathtub, probability - 0.9949665665626526
いい感じですね。
3-12.推論用エンドポイントを削除
最後に、推論用エンドポイントを削除します。 画像分類で使う推論用エンドポイントはGPUなので、消し忘れると結構なコストになってしまいますから注意が必要ですね。
sage.delete_endpoint(EndpointName=endpoint_name)
4.まとめ
今回は、「SageMaker」のチュートリアルを通して「jpg」から「.lst」ファイルを作成して「画像分類」をやってみました。 手元にあるデータでとりあえず「画像分類をやってみたい」という場合は、今回ご紹介したチュートリアルを参考にやってみてもいいのではないでしょうか?