SageMaker Python SDKを使った画像分類

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

こんにちは、小澤です。

Amazon SageMaker(以下SageMaker)のビルドインアルゴリズムには、Image Classification(画像分類)があります。 これは、ResNet[1]というDeep Learningを使った手法になっており、とても高い精度での分類が期待できます。

SageMakerではこのImage Classificationの利用方法を解説したexampleが存在してます。

このexampleはCaltech256[2]の学習を行うものになっています。 なお、この他にもImage Classificationではこの他にImageNet[3]を利用した転移学習も対応しており、同じディレクトリにそのexampleも存在しています。

さて、このexampleはboto3を利用した実装になっています。 私としてはどーっしてもscikit-learn[4]ライクなインターフェースが提供されているSageMaker Python SDK[5][6]を使いたかったので、このexampleをそちらに置き換えてみたいと思います。

置き換え内容

さて、では早速どのようなコードに置き換わるのかみていきましょう。

「Prequisites and Preprocessing」で解説されている以下の要素についてはそのまま利用しています。

  • データの保存先となるS3のバケット指定
  • 学習に利用するコンテナイメージの選択
  • 学習に利用するデータの取得とS3への保存

なお、データはあらかじめMXNetで利用可能なrec形式のファイルに変換したものを利用しています。 Caltech256のデータに関しては提供されているものがあるようなのでそれをそのままダウンロードしています。 独自のデータをrec形式に変換するにはMXNetが提供する機能を利用する必要があるようです。

学習とエンドポイントの作成

では、ここからが同様の処理をSageMaker Python SDKを使った処理への置き換えになります。

SageMakerの学習処理の実装内容は以下のような流れになっています。

  1. sagemaker.estimator.Estimatorのインスタンスを作成する
  2. set_hyperparametersでハイパーパラメータの値を設定する
  3. fitで学習処理を実行する

加えて、Image Classificationでは学習時に"train"と"validation"の2種類のデータが必要となるのでその設定も行います。

import sagemaker
session = sagemaker.Session()

resnet = sagemaker.estimator.Estimator(
    training_image,                        # 利用する機械学習手法用のコンテナイメージ
    role,                                  # SageMakerのロール
    train_instance_count=1,                # 学習に利用するインスタンス数
    train_instance_type='ml.p2.xlarge',    # 学習に利用するインスタンスタイプ
    output_path=output_path,               # モデルの出力先のS3のパス
    sagemaker_session=session              # SageMakerのSession
)

Estimatorのインスタンス作成はほぼ定型文ですね。 ちなみに、ここでtrain_instance_type='ml.m4.xlarge'とかにすると「Image ClassificationはGPUインスタンスじゃないとダメだよ」と怒られましたw

続いて、ハイパーパラメータの設定を行います。

resnet.set_hyperparameters(
    num_layers = "18" ,
    image_shape = "3,224,224",
    num_training_samples = "15420",
    num_classes = "257",
    mini_batch_size =  "64",
    epochs = "2",
    learning_rate = "0.01"
)

ここでは、exampleで利用している値をそのまま使っています。 なお、Image Classificationで利用可能なハイパーパラメータについては以下を参照してください。

最後に学習を行います。

from sagemaker.session import s3_input

# Image ClassificationではContent Typeとして"application/x-recordio"か"application/x-image"である必要がある
# 今回学習に利用するのはrecファイルのなので、"application/x-recordio"を指定している
train = s3_input(s3_data='s3://{}/train/'.format(bucket), content_type='application/x-recordio')
valid = s3_input(s3_data='s3://{}/validation/'.format(bucket), content_type='application/x-recordio')

# 学習はscikit-learnライクに学習データを指定してfitを呼び出す
# Image Classificationではtrainとvalidationの2つのデータを学習時にdict形式で指定する
resnet.fit({'train' : train, 'validation' : valid})

予測する

予測は、精度評価の際にはSDKから利用したいところですが「せっかくエンドポイントのAPIが提供されるんだし、curlでやってみよう!」と思い立ちました。

そこで、予測対象のデータはexample内でダウンロードしているものをそのまま使いつつ、やってみたのですが...

curl  \
  -X POST \
  -H "Content-Type: application/x-image" \
  https://runtime.sagemaker.us-east-1.amazonaws.com/endpoints/<エンドポイント名>/invocations \
  -d @/tmp/test.jpg

以下のようなエラーが返ってきました。 認証とな...

{"message":"Missing Authentication Token"}

これに関してフォーラムに以下のような情報がありました。

内容を見ると全く同じことで質問しているものになっています。まさにドンピシャです! 回答を意訳すると「認証が必要だからAWS CLI使うのがオススメだよ!」とのことです。 やってみましょう。

aws sagemaker-runtime invoke-endpoint \
  --endpoint-name <エンドポイント名> \
  --body fileb:///tmp/test.jpg \
  --content-type "application/x-image"  \
  /tmp/output.json

無事、/tmp/output.jsonに結果が出力されました。 出力結果は、分類対象のクラス数分の確率値が記載されたJsonの配列となっているので、あとはargmaxを取ってやるなどしてどのクラスと予測されたかを判断できそうです。

[0.0015657523181289434, 0.0009123011259362102, 0.0002704315702430904, 
...
...
, 0.0005209097289480269, 0.042547550052404404]

おわりに

今回は、SageMakerのexampleに含まれるboto3を使ったImage ClassificationをSageMaker Python SDKに置き換えてみました。 かなり簡単にかける仕組みになっていますね。

Deep Learningはパラメータチューニングが辛かったりするので、実際に利用する際にはここから乗り越えなければならないハードルはまだまだあるかと思います。 それらに関しては転移学習が利用できますし、SageMakerの仕組みを利用すれば複数の異なるパラメータで並列に学習させることも可能ですのでかなりトライアンドエラーを行いやすい環境になっているかと思います。

参考

  1. Deep Residual Learning for Image Recognition
  2. Caltech256
  3. ImageNet
  4. scikit-learn: machine learning in Python — scikit-learn 0.19.1 documentation
  5. GitHub - aws/sagemaker-python-sdk
  6. Amazon SageMaker Python SDK — sagemaker 1.2.4 documentation