Amazon SageMakerの予測APIをboto3から叩く

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

こんにちは、小澤です。

Amazon SageMaker(以下SageMaker)では、学習したモデルに対して予測用のエンドポイントの作成まで行ってくれます。 このエンドポイントへのアクセスはHTTPのPOSTでのやり取りとなるのですが、Signature Version 4での認証が必要となります。

Pythonでのサンプルコードも記載されてはいるのですが、認証の仕組みに詳しくない私みたいな人間にとってはなかなかハードルが高いですね。。 boto3を使うことでAPIキーやIAMロールを使った認証をいい感じにラップしてくれているので、そちらの方法でAPIにアクセスして見たいと思います。

やってみる

やってみる、とは言っても非常に簡単です。 今回はアクセスすることが目的なので、すでにエンドポイントの作成までは終わってるものとします。

boto3を使ってSageMakerのAPIにアクセスするためのメソッドはinvoke_endpointになります。 sagemaker-runtimeのクライアントを取得して、以下のように実行します。

import boto3
client = boto3.client('sagemaker-runtime')

response = client.invoke_endpoint(
    EndpointName='<your-endpoint-name>',
    Body=b'5.1,3.5,1.4,0.2',
    ContentType='text/csv',
    Accept='application/json'
)

ここでは、CSV形式で4つの特徴量を持つデータに対する予測を行っています。 Bodyとして渡すデータやContentTypeにしているMIMEは利用するエンドポントに応じて適切なものを選択してやってください。

EndpointNameに関しては、エンドポイントのURLではなく、エンドポイント名を指定します。

結果は以下のように返ってきます。

{'Body': <botocore.response.StreamingBody at 0x104762550>,
 'ContentType': 'application/json',
 'InvokedProductionVariant': 'default-variant-name',
 'ResponseMetadata': {'HTTPHeaders': {'content-length': '73',
   'content-type': 'application/json',
   'date': 'Thu, 12 Jul 2018 02:35:32 GMT',
   'x-amzn-invoked-production-variant': 'default-variant-name',
   'x-amzn-requestid': '8948ed99-4f62-4dd2-a152-eb2e661398f7'},
  'HTTPStatusCode': 200,
  'RequestId': '8948ed99-4f62-4dd2-a152-eb2e661398f7',
  'RetryAttempts': 0}}

JSONとして返ってきた結果がdictに格納されています。 この中で実際に予測結果が入っているのがBodyとなりますので取り出して見ましょう。

import json
body = response['Body']
json.load(body)

以下のように予測ラベルとその確率が得られます。

{'predictions': [{'predicted_label': 0.0, 'score': 0.07713162899017334}]}

おわりに

今回はSageMakerのエンドポイントに対してboto3を使って予測を行ってみました。 実際に予測結果を利用するシステムにSageMakerで学習させたモデルを使うのが非常に簡単であることが確認できたかと思います。