Amazon SageMakerで画像分類をしてみた
概要
こんにちは、yoshimです。 今回はSageMakerでビルトインアルゴリズムとして実装されている「Image classification transfer learning demo」について、チュートリアルを実際にやってみます。 「転移学習」をすることで、「優秀な画像分類モデルを手軽に実装」できました。
目次
1.画像分類とは
まず、そもそも「画像分類」ってどのようなことをするのでしょうか? 大雑把に言いますと「画像」を投入すると「その画像がなんの画像か」を推測するものです。
例えば、今回実行するチュートリアルですと、下記のような画像を投入すると
Result: label - umbrella-101, probability - 0.9996694326400757
こんな感じで「この画像に写っているのは傘だよ」、「確率は99.9..%だよ」といった結果を返してくれます。 最近は「自動運転技術」や「工場での不良品検出」等、様々なところで画像認識技術が研究されていますが、その一端に触れることができるのでワクワクしますね。
余談ですが、このビルトインアルゴリズムでは「ResNet」を使います。 「ResNet」はobject detectionでもbaseで使うことができます。(こちらでは層の深さは50)
2.転移学習とは
ニューラルネットワークベースのアルゴリズムはパラメータ数がとても多いため、チューニングするのがとても大変です。 「転移学習」は「ハイパーパラメータを1から学習する」のではなく、既にチューニングされているパラメータを利用して学習の手間を減らすという、とても素敵な技術なのです。
とはいえ、「転移学習を使えば新しく学習することは全くない」という訳ではないのでご注意ください。
3.実際にやってみた
今回試してみたチュートリアルは「Image-classification-transfer-learning.ipynb」というファイルです。
ここにあるものと同じです。
基本的に上から順に実行していくだけで大丈夫です。
3-1.事前準備
トレーニング、各データソースにアクセスする権限を取得します。 また、今回利用するS3バケットを指定します。
%%time import boto3 import re from sagemaker import get_execution_role from sagemaker.amazon.amazon_estimator import get_image_uri role = get_execution_role() bucket='<<bucket-name>>' # customize to your bucket training_image = get_image_uri(boto3.Session().region_name, 'image-classification') # ビルトインアルゴリズムである「image-classification」を指定 print(training_image)
3-2.今回使うデータをダウンロード
今回のチュートリアルで利用するデータをダウンロードします。 「image-classification」では、INPUTデータとして「RecordIO (application/x-recordio)」、「イメージ (application/x-image)」のいずれかが利用可能ですが、今回は「recordIO」のデータを使ってトレーニングを進めます。 ただし、推論時には「イメージ (application/x-image)」のみがサポートされる点にご注意ください。
import os import urllib.request import boto3 def download(url): filename = url.split("/")[-1] if not os.path.exists(filename): urllib.request.urlretrieve(url, filename) def upload_to_s3(channel, file): s3 = boto3.resource('s3') data = open(file, "rb") key = channel + '/' + file s3.Bucket(bucket).put_object(Key=key, Body=data) # # caltech-256 download('http://data.mxnet.io/data/caltech-256/caltech-256-60-train.rec') download('http://data.mxnet.io/data/caltech-256/caltech-256-60-val.rec') upload_to_s3('validation', 'caltech-256-60-val.rec') upload_to_s3('train', 'caltech-256-60-train.rec')
3-3.トレーニング用のパラメータを指定
トレーニング用のハイパーパラメータを指定します。 ハイパーパラメータの一覧はこちらをご参照ください。 ハイパーパラメータのチューニングについてのドキュメントをみると、もっともチューニングの効果が出そうなのは「mini_batch_size」、「learning_rate」、「optimizer」の3つらしいです。
%%time # The algorithm supports multiple network depth (number of layers). They are 18, 34, 50, 101, 152 and 200 # For this training, we will use 18 layers num_layers = 18 # we need to specify the input image shape for the training data image_shape = "3,224,224" # we also need to specify the number of training samples in the training set # for caltech it is 15420 num_training_samples = 15420 # specify the number of output classes num_classes = 257 # batch size for training mini_batch_size = 128 # number of epochs epochs = 2 # learning rate learning_rate = 0.01 top_k=2 # Since we are using transfer learning, we set use_pretrained_model to 1 so that weights can be # initialized with pre-trained weights use_pretrained_model = 1
今回はチュートリアルなので、層の深さ、エポック数等が最小限に設定されています。 SageMakerではログから「検証データセットでの精度推移」が確認できるので、とりあえずエポック数を増やして推移を見てみるのもいいかもしれません。
3-4.トレーニング用のパラメータを設定
「3-3.トレーニング用のパラメータを指定」で指定したハイパーパラメータ、学習に使うデータを格納しているS3パス、データフォーマット等を指定します。 このチュートリアルのコードをそのまま使うと、指定したバケットの直下にいきなり「bucket/train」とか「bucket/test」といったパスを作ることになってしまうので少し気になりました。
もし嫌だという場合は「prefix」とかを変数として指定して、下記のように変更するといいです。
"S3Uri": 's3://{}/train/'.format(bucket) を "S3Uri": 's3://{0}/{1}/train/'.format(bucket,prefix)
"S3Uri": 's3://{}/validation/'.format(bucket), を "S3Uri": 's3://{0}/{1}/validation/'.format(bucket,prefix),
%%time import time import boto3 from time import gmtime, strftime # prefix = 'hogehoge' # 出力先S3パスを変更するなら指定 s3 = boto3.client('s3') # create unique job name job_name_prefix = 'DEMO-imageclassification' timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime()) job_name = job_name_prefix + timestamp training_params = \ { # specify the training docker image "AlgorithmSpecification": { "TrainingImage": training_image, "TrainingInputMode": "File" }, "RoleArn": role, "OutputDataConfig": { "S3OutputPath": 's3://{}/{}/output'.format(bucket, job_name_prefix) }, "ResourceConfig": { "InstanceCount": 1, "InstanceType": "ml.p2.xlarge", "VolumeSizeInGB": 50 }, "TrainingJobName": job_name, "HyperParameters": { "image_shape": image_shape, "num_layers": str(num_layers), "num_training_samples": str(num_training_samples), "num_classes": str(num_classes), "mini_batch_size": str(mini_batch_size), "epochs": str(epochs), "learning_rate": str(learning_rate), "use_pretrained_model": str(use_pretrained_model) }, "StoppingCondition": { "MaxRuntimeInSeconds": 360000 }, #Training data should be inside a subdirectory called "train" #Validation data should be inside a subdirectory called "validation" #The algorithm currently only supports fullyreplicated model (where data is copied onto each machine) "InputDataConfig": [ { "ChannelName": "train", "DataSource": { "S3DataSource": { "S3DataType": "S3Prefix", "S3Uri": 's3://{}/train/'.format(bucket), "S3DataDistributionType": "FullyReplicated" } }, "ContentType": "application/x-recordio", "CompressionType": "None" }, { "ChannelName": "validation", "DataSource": { "S3DataSource": { "S3DataType": "S3Prefix", "S3Uri": 's3://{}/validation/'.format(bucket), "S3DataDistributionType": "FullyReplicated" } }, "ContentType": "application/x-recordio", "CompressionType": "None" } ] } print('Training job name: {}'.format(job_name)) print('\nInput Data Location: {}'.format(training_params['InputDataConfig'][0]['DataSource']['S3DataSource']))
3-5.トレーニングの開始
上記で指定した設定でトレーニングを開始します。 私の場合は、p2.xlargeで10分ほどかかりました。 注意点として、下記のJOBをjupyterの画面上から停止しても裏ではまだ実行されている場合があります。 その場合は、コンソール画面の「トレーニングジョブ」から対象のJOBを停止してください。
# create the Amazon SageMaker training job sagemaker = boto3.client(service_name='sagemaker') sagemaker.create_training_job(**training_params) # confirm that the training job has started status = sagemaker.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus'] print('Training job current status: {}'.format(status)) try: # wait for the job to finish and report the ending status sagemaker.get_waiter('training_job_completed_or_stopped').wait(TrainingJobName=job_name) training_info = sagemaker.describe_training_job(TrainingJobName=job_name) status = training_info['TrainingJobStatus'] print("Training job ended with status: " + status) except: print('Training failed to start') # if exception is raised, that means it has failed message = sagemaker.describe_training_job(TrainingJobName=job_name)['FailureReason'] print('Training failed with the following error: {}'.format(message))
また、学習の推移をコンソール画面の「ログ」から確認できます。 トレーニングデータでの精度推移が確認できるので、見ておくといいと思います。 また、最後には検証用データでの精度も計算してくれています。
ログ画面を開き
ここをクリック
対象JOBのログが出てくるのでクリック
検証データ、エポックごとの精度が確認できました!
3-6.エンドポイントに使うモデルやロールの設定
エンドポイントに利用するモデル(今回の場合は先ほどS3に出力したモデル)や、ロールを指定します。
%%time import boto3 from time import gmtime, strftime sage = boto3.Session().client(service_name='sagemaker') model_name="DEMO-image-classification-model" print(model_name) info = sage.describe_training_job(TrainingJobName=job_name) model_data = info['ModelArtifacts']['S3ModelArtifacts'] print(model_data) hosting_image = get_image_uri(boto3.Session().region_name, 'image-classification') primary_container = { 'Image': hosting_image, 'ModelDataUrl': model_data, } create_model_response = sage.create_model( ModelName = model_name, ExecutionRoleArn = role, PrimaryContainer = primary_container) print(create_model_response['ModelArn'])
3-7.エンドポイントの設定
立ち上げるエンドポイントの設定を指定します。 具体的には、インスタンスタイプやインスタンスの台数を指定します。 A/Bテストをする場合は「ModelName」もしっかり指定してあげた方がテストがしやすいかと思います。
%%time from time import gmtime, strftime timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime()) endpoint_config_name = job_name_prefix + '-epc-' + timestamp endpoint_config_response = sage.create_endpoint_config( EndpointConfigName = endpoint_config_name, ProductionVariants=[{ 'InstanceType':'ml.m4.xlarge', 'InitialInstanceCount':1, 'ModelName':model_name, 'VariantName':'AllTraffic'}]) print('Endpoint configuration name: {}'.format(endpoint_config_name)) print('Endpoint configuration arn: {}'.format(endpoint_config_response['EndpointConfigArn']))
3-8.エンドポイントの作成
「3-6.エンドポイントに使うモデルやロールの設定」、「3-7.エンドポイントの設定」で指定した設定に基づいて、推論用エンドポイントを作成します。
%%time import time timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime()) endpoint_name = job_name_prefix + '-ep-' + timestamp print('Endpoint name: {}'.format(endpoint_name)) endpoint_params = { 'EndpointName': endpoint_name, 'EndpointConfigName': endpoint_config_name, } endpoint_response = sagemaker.create_endpoint(**endpoint_params) print('EndpointArn = {}'.format(endpoint_response['EndpointArn']))
JOB自体はすぐに完了しますが、エンドポイントが立ち上がるまでは少し時間がかかります。コンソール画面から確認するか、下記を実行してステータスを確認しましょう。
# get the status of the endpoint response = sagemaker.describe_endpoint(EndpointName=endpoint_name) status = response['EndpointStatus'] print('EndpointStatus = {}'.format(status)) # wait until the status has changed sagemaker.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name) # print the status of the endpoint endpoint_response = sagemaker.describe_endpoint(EndpointName=endpoint_name) status = endpoint_response['EndpointStatus'] print('Endpoint creation ended with EndpointStatus = {}'.format(status)) if status != 'InService': raise Exception('Endpoint creation failed.')
3-9.実際に推論してみる
作成したエンドポイントに実際に画像を投げてみて、どのように分類されるかを確認してみましょう。
セッションをはって
import boto3 runtime = boto3.Session().client(service_name='runtime.sagemaker')
画像をダウンロードしてみましょう。
!wget -O /tmp/test.jpg http://www.vision.caltech.edu/Image_Datasets/Caltech256/images/008.bathtub/008_0007.jpg file_name = '/tmp/test.jpg' # test image from IPython.display import Image Image(file_name)
こんな画像が表示されます。
この画像をエンドポイントに投げてみると、
import json import numpy as np with open(file_name, 'rb') as f: payload = f.read() payload = bytearray(payload) response = runtime.invoke_endpoint(EndpointName=endpoint_name, ContentType='application/x-image', Body=payload) result = response['Body'].read() # result will be in json format and convert it to ndarray result = json.loads(result) # the result will output the probabilities for all classes # find the class with maximum probability and print the class index index = np.argmax(result) object_categories = ['ak47', 'american-flag', 'backpack', 'baseball-bat', 'baseball-glove', 'basketball-hoop', 'bat', 'bathtub', 'bear', 'beer-mug', 'billiards', 'binoculars', 'birdbath', 'blimp', 'bonsai-101', 'boom-box', 'bowling-ball', 'bowling-pin', 'boxing-glove', 'brain-101', 'breadmaker', 'buddha-101', 'bulldozer', 'butterfly', 'cactus', 'cake', 'calculator', 'camel', 'cannon', 'canoe', 'car-tire', 'cartman', 'cd', 'centipede', 'cereal-box', 'chandelier-101', 'chess-board', 'chimp', 'chopsticks', 'cockroach', 'coffee-mug', 'coffin', 'coin', 'comet', 'computer-keyboard', 'computer-monitor', 'computer-mouse', 'conch', 'cormorant', 'covered-wagon', 'cowboy-hat', 'crab-101', 'desk-globe', 'diamond-ring', 'dice', 'dog', 'dolphin-101', 'doorknob', 'drinking-straw', 'duck', 'dumb-bell', 'eiffel-tower', 'electric-guitar-101', 'elephant-101', 'elk', 'ewer-101', 'eyeglasses', 'fern', 'fighter-jet', 'fire-extinguisher', 'fire-hydrant', 'fire-truck', 'fireworks', 'flashlight', 'floppy-disk', 'football-helmet', 'french-horn', 'fried-egg', 'frisbee', 'frog', 'frying-pan', 'galaxy', 'gas-pump', 'giraffe', 'goat', 'golden-gate-bridge', 'goldfish', 'golf-ball', 'goose', 'gorilla', 'grand-piano-101', 'grapes', 'grasshopper', 'guitar-pick', 'hamburger', 'hammock', 'harmonica', 'harp', 'harpsichord', 'hawksbill-101', 'head-phones', 'helicopter-101', 'hibiscus', 'homer-simpson', 'horse', 'horseshoe-crab', 'hot-air-balloon', 'hot-dog', 'hot-tub', 'hourglass', 'house-fly', 'human-skeleton', 'hummingbird', 'ibis-101', 'ice-cream-cone', 'iguana', 'ipod', 'iris', 'jesus-christ', 'joy-stick', 'kangaroo-101', 'kayak', 'ketch-101', 'killer-whale', 'knife', 'ladder', 'laptop-101', 'lathe', 'leopards-101', 'license-plate', 'lightbulb', 'light-house', 'lightning', 'llama-101', 'mailbox', 'mandolin', 'mars', 'mattress', 'megaphone', 'menorah-101', 'microscope', 'microwave', 'minaret', 'minotaur', 'motorbikes-101', 'mountain-bike', 'mushroom', 'mussels', 'necktie', 'octopus', 'ostrich', 'owl', 'palm-pilot', 'palm-tree', 'paperclip', 'paper-shredder', 'pci-card', 'penguin', 'people', 'pez-dispenser', 'photocopier', 'picnic-table', 'playing-card', 'porcupine', 'pram', 'praying-mantis', 'pyramid', 'raccoon', 'radio-telescope', 'rainbow', 'refrigerator', 'revolver-101', 'rifle', 'rotary-phone', 'roulette-wheel', 'saddle', 'saturn', 'school-bus', 'scorpion-101', 'screwdriver', 'segway', 'self-propelled-lawn-mower', 'sextant', 'sheet-music', 'skateboard', 'skunk', 'skyscraper', 'smokestack', 'snail', 'snake', 'sneaker', 'snowmobile', 'soccer-ball', 'socks', 'soda-can', 'spaghetti', 'speed-boat', 'spider', 'spoon', 'stained-glass', 'starfish-101', 'steering-wheel', 'stirrups', 'sunflower-101', 'superman', 'sushi', 'swan', 'swiss-army-knife', 'sword', 'syringe', 'tambourine', 'teapot', 'teddy-bear', 'teepee', 'telephone-box', 'tennis-ball', 'tennis-court', 'tennis-racket', 'theodolite', 'toaster', 'tomato', 'tombstone', 'top-hat', 'touring-bike', 'tower-pisa', 'traffic-light', 'treadmill', 'triceratops', 'tricycle', 'trilobite-101', 'tripod', 't-shirt', 'tuning-fork', 'tweezer', 'umbrella-101', 'unicorn', 'vcr', 'video-projector', 'washing-machine', 'watch-101', 'waterfall', 'watermelon', 'welding-mask', 'wheelbarrow', 'windmill', 'wine-bottle', 'xylophone', 'yarmulke', 'yo-yo', 'zebra', 'airplanes-101', 'car-side-101', 'faces-easy-101', 'greyhound', 'tennis-shoes', 'toad', 'clutter'] print("Result: label - " + object_categories[index] + ", probability - " + str(result[index]))
下記の通り、「バスタブである確率、89.7%」と表示されます。
Result: label - bathtub, probability - 0.897666871547699
試しに適当な画像を投入してみようと思います。 ブラウザ画面上から画像をアップロードした後に、
# 適当に画像をアップロードして試してみる file_name = '/home/ec2-user/SageMaker/blog/image_classification_transfer_learning/imageclassification_caltech_2018-08-06/image_0001.jpg'# test image from IPython.display import Image Image(file_name)
この傘の画像を投入してみます。
import json import numpy as np with open(file_name, 'rb') as f: payload = f.read() payload = bytearray(payload) response = runtime.invoke_endpoint(EndpointName=endpoint_name, ContentType='application/x-image', Body=payload) result = response['Body'].read() # result will be in json format and convert it to ndarray result = json.loads(result) # the result will output the probabilities for all classes # find the class with maximum probability and print the class index index = np.argmax(result) print("Result: label - " + object_categories[index] + ", probability - " + str(result[index]))
ちゃんと傘だと認識しました。
Result: label - umbrella-101, probability - 0.9996694326400757
なお、今回学習したモデルで分類できるのは「object_categories」に列挙した物体だけ(256種)です。 (バスタブも傘もちゃんと「object_categories」に列挙されています。) 今回使ったデータセットでは、この256種類+「その他」といったくくりの全257種類の画像で学習をしていたのですが、推論のタイミングでは「その他」は利用しません。
とりあえず、256種類に含まれない画像を入れたらどんな感じかを確かめてみましょう。
file_name = '/home/ec2-user/SageMaker/blog/image_classification_transfer_learning/imageclassification_caltech_2018-08-06/image2.png'# test image from IPython.display import Image Image(file_name)
このアメフトボールで試してみます。
import json import numpy as np with open(file_name, 'rb') as f: payload = f.read() payload = bytearray(payload) response = runtime.invoke_endpoint(EndpointName=endpoint_name, ContentType='application/x-image', Body=payload) result = response['Body'].read() # result will be in json format and convert it to ndarray result = json.loads(result) # the result will output the probabilities for all classes # find the class with maximum probability and print the class index index = np.argmax(result) print("Result: label - " + object_categories[index] + ", probability - " + str(result[index]))
Result: label - computer-mouse, probability - 0.173785001039505
コンピュータのマウス、と推測されてしまいました。 ただ、確率は17.3%程度とのことなので実際に利用する場合には切り捨てることになりそうですね。 (実際に利用する場合は、APIを受け取る側で閾値を設定して異なる挙動を取らせることになるかと思います) 折角なので、一番確率が高いラベルだけでなく、「上から5ラベル」くらい表示させてみましょう。
%%time '''original 上記では一番確率が高いクラスを返していたけど、トップ5くらいまで返してみよう。 ''' import json import numpy as np with open(file_name, 'rb') as f: payload = f.read() payload = bytearray(payload) response = runtime.invoke_endpoint(EndpointName=endpoint_name, ContentType='application/x-image', Body=payload) result = response['Body'].read() # result will be in json format and convert it to ndarray result = json.loads(result) # the result will output the probabilities for all classes # find the class with maximum probability and print the class index indexs = np.argsort(result)[::-1][:5] # ここを修正 # top5を表示 for index in indexs: print("Result: label - " + object_categories[index] + ", probability - " + str(result[index]))
Result: label - computer-mouse, probability - 0.173785001039505 Result: label - tennis-ball, probability - 0.07102835923433304 Result: label - soccer-ball, probability - 0.05962461978197098 Result: label - yo-yo, probability - 0.05829168111085892 Result: label - frisbee, probability - 0.04761170595884323 CPU times: user 16 ms, sys: 0 ns, total: 16 ms Wall time: 144 ms
コンピューターマウス、テニスボール、サッカーボール、ヨーヨー、フリスビー、という結果でした...。
楕円、もしくは丸いもの、ってことで似ていると判断しているのかもしれませんね...。
4.まとめ
転移学習を使うことで、1から学習せずに手早く優秀なモデルを利用することができます。 (美味い、安い、早い、牛丼みたいです) とはいえトレーニングやパラメータのチューニングが不要、という訳ではないのでご注意ください。
チューニングするべきパラメータは下記をご参照ください。 ハイパーパラメータ
「augmentation_type」がなんとなく便利そうでいいですね。