この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。
こんにちは、大阪DI部の大澤です。
Amazon SageMakerの学習ジョブが完了した際にジョブの情報をSlackのチャンネルに投稿するbotを作ってみました。今回はその内容をご紹介します。
動機
- 自分が回している学習ジョブの完了通知が欲しい
- 他の人がどういう学習ジョブを回しているかを知りたい
仕組み
- 学習ジョブを開始します。
- 学習ジョブが完了し、モデルアーティファクトが
s3://<bucket-name>/<prefix>/<job-name>/output/model.tar.gz
に保存されます。 output/model.tar.gz
という接尾辞の条件を満たすキーのオブジェクトが保存されたことをトリガーとして、Lambda関数が実行されます。- Lambda関数でモデルアーティファクトのオブジェクトキーに含まれる学習ジョブ名
から学習ジョブ情報等を取得&整形し、SlackのIncoming Webhooksの専用URLへPOSTします。 - Slackに学習ジョブ情報が投稿されます。
※ 学習ジョブの完了をLambdaに通知する良い方法が思いつかなかったので、モデルアーティファクトの保存パスのルールを利用しました。なので、学習ジョブが失敗した時などモデルアーティファクトが保存されない場合には対応していません。
※ 全てのS3バケットに対応する訳ではありません。Lambda関数を実行するイベントトリガーを設定したS3バケットに対応します。複数のバケットに対応する場合はそれぞれイベントトリガーの設定が必要になります。
やってみる
Slack
SlackのワークスペースにIncoming Webhooksを追加し、メッセージを送信するために使用するエンドポイントURLを取得します。
Lambda関数の作成
次にマネジメントコンソールを開き、Lambda関数を作成します。 スクラッチを選択し、名前やランタイムや実行ロールを設定します。
- ランタイムはPython3.7にします。
- 実行ロールに今回は3つのポリシー
AWSPriceListServiceFullAccess
、AWSLambdaBasicExecutionRole
、AmazonSageMakerReadOnly
を付けています。- 実際にはもう少し権限は絞れますが、AWS管理ポリシーで楽をしました。
処理の紹介
Lambda関数で行う処理の流れは次のような感じです。
- モデルアーティファクトが保存されたオブジェクトキーから学習ジョブ名を取得します。
<prefix>/<job名>/output/model.tar.gz
という想定
- 学習ジョブ情報等を取得し、投稿用テキストを作成します。
- SageMakerのdescribe_training_jobで学習ジョブの詳細情報を取得し、Pricingのget_productsでインスタンスの価格情報を取得します。
- 取得した情報をもとにステータスや概算費用、学習時間、インスタンスタイプ、入力データなどのデータを整形し、投稿用テキストを作成します。
- Slack Incoming Webhooksの専用のエンドポイントURLに投稿データをPOSTします。
- 投稿データには学習ジョブ情報を記載したメッセージテキスト、通知先チャンネル、投稿時の表示名が含まれます。
以下がそのスクリプトになります。
import os
import json
import urllib
import boto3
from datetime import timezone, timedelta
# パラメータを読み込む
BOT_USERNAME = os.environ['BOT_USERNAME']
SLACK_URL = os.environ['SLACK_URL']
SLACK_CHANNEL = os.environ['SLACK_CHANNEL']
# インスタンスタイプとリージョンに応じた価格取得時に指定する地域名(雑だけどもいい方法がないので、とりあえず辞書で対応)
REGION_MAP = {
'us-gov-west-1' : 'AWS GovCloud (US)',
'ap-south-1' : 'Asia Pacific (Mumbai)',
'ap-northeast-2' : 'Asia Pacific (Seoul)',
'ap-southeast-1' : 'Asia Pacific (Singapore)',
'ap-southeast-2' : 'Asia Pacific (Sydney)',
'ap-northeast-1' : 'Asia Pacific (Tokyo)',
'ca-central-1' : 'Canada (Central)',
'eu-central-1' : 'EU (Frankfurt)',
'eu-west-1' : 'EU (Ireland)',
'eu-west-2' : 'EU (London)',
'us-east-1' : 'US East (N. Virginia)',
'us-east-2' : 'US East (Ohio)',
'us-west-1' : 'US West (N. California)',
'us-west-2' : 'US West (Oregon)',
}
# JSTを定義する
JST = timezone(timedelta(hours=9), 'JST')
def lambda_handler(event, context):
"""Lambdaから呼ばれるコールバック関数"""
object_key = event['Records'][0]['s3']['object']['key']
# "<prefix>/<job名>/output/model.tar.gz"となる想定で、学習ジョブ名を取得する
job_name = object_key.split('/')[-3]
# バケットのあるリージョン
region = event['Records'][0]['awsRegion']
# ジョブ情報を作成
job_info = create_training_job_info(job_name, region)
# Slackにメッセージを送信する
send_slack_message(job_info)
return {
'statusCode': 200,
'body': json.dumps('Hello from Lambda!')
}
def send_slack_message(text):
"""Slackにメッセージを送信する"""
data = {
'username':BOT_USERNAME, # 表示名
'text':text, # 内容
'channel': SLACK_CHANNEL # 送信先チャンネル
}
method = "POST"
headers = {"Content-Type" : "application/json"}
req = urllib.request.Request(SLACK_URL, method=method, data=json.dumps(data).encode(), headers=headers)
with urllib.request.urlopen(req) as res:
body = res.read()
return body
def create_training_job_info(job_name, region):
"""ジョブ名とリージョンから学習ジョブ情報を作成する"""
sm = boto3.client('sagemaker', region_name=region)
# ジョブデータを取得する
job_data = sm.describe_training_job(TrainingJobName=job_name)
# ジョブ状態
job_status = job_data['TrainingJobStatus']
# インスタンスタイプ
instance_type = job_data['ResourceConfig']['InstanceType']
# ジョブ実行時間
job_total_seconds = (job_data['TrainingEndTime'] - job_data['TrainingStartTime']).total_seconds()
job_time_text = create_time_text(job_total_seconds)
# インスタンス価格を取得する
instance_price = get_instance_price(region, instance_type)
# メトリクス(lambdaのbotocoreのバージョンが古いため、未対応。今後に期待)
metric_data = create_metric_text(job_data['FinalMetricDataList']) if 'FinalMetricDataList' in job_data else 'データなし'
# 入力データ
input_data = create_input_data_text(job_data['InputDataConfig'])
# 学習コストの概算
job_cost = job_total_seconds / 3600 * float(instance_price)
job_info = '''
ジョブ名: {job_name}
ステータス: {job_status}
概算費用(USD): ${job_cost:,.3f}
学習時間: {job_time}({job_start_time}〜{job_end_time})
インスタンスタイプ: {instance_type} ({instance_price:,.3f} USD/h)
モデルアーティファクト: {model_artifacts_uri}
入力データ:
{input_data}
メトリクス(未対応):
{metric_data}
ジョブ詳細: <https://{region}.console.aws.amazon.com/sagemaker/home?region={region}#/jobs/{job_name}>
'''.format(
region = region,
job_status = job_status,
job_cost = job_cost,
job_name = job_name,
instance_type = instance_type,
instance_price = float(instance_price),
job_start_time = job_data['TrainingStartTime'].astimezone(JST).strftime('%Y/%m/%d %H:%M:%S'),
job_end_time = job_data['TrainingEndTime'].astimezone(JST).strftime('%Y/%m/%d %H:%M:%S'),
job_time = job_time_text,
input_data = input_data,
model_artifacts_uri = job_data['ModelArtifacts']['S3ModelArtifacts'],
metric_data = metric_data
)
return job_info
def get_first_element(dict_obj):
"""辞書の最初の要素を取得する"""
return next(iter(dict_obj.values()))
def get_instance_price(region, instance_type):
"""リージョンとインスタンスタイプに応じた価格を取得する"""
# 対応していないリージョンがあるため、決め打ちでus-east-1を使う
pricing = boto3.client('pricing', region_name='us-east-1')
# 対応するリージョンとインスタンスタイプの料金を取得する
response = pricing.get_products(
ServiceCode='AmazonSageMaker',
Filters=[
{
'Type': 'TERM_MATCH',
'Field': 'location',
'Value': REGION_MAP[region]
},
{
'Type': 'TERM_MATCH',
'Field': 'instanceType',
'Value': instance_type + '-Training'
}
]
)
return get_first_element(get_first_element(json.loads(response['PriceList'][0])['terms']['OnDemand'])['priceDimensions'])['pricePerUnit']['USD']
def create_time_text(total_seconds):
"""秒数から読みやすい時間表記を作る"""
days, remainder = divmod(total_seconds, 86400)
hours, remainder = divmod(remainder, 3600)
minutes, seconds = divmod(remainder, 60)
time_text = ''
if days > 0:
time_text += str(int(days))+'日'
if hours > 0:
time_text += str(int(hours))+'時間'
if minutes > 0:
time_text += str(int(minutes))+'分'
time_text += str(int(seconds))+'秒'
return time_text
def create_metric_text(metric_list):
"""メトリクスデータを作成する"""
text = ''
for metric_data in metric_list:
text += ' {metric_name}: {metric_value}\n'.format(
metric_name = metric_data['MetricName'],
metric_value = metric_data['Value']
)
return text
def create_input_data_text(input_data_config):
"""入力データ情報を作成する"""
if len(input_data_config) == 0:
# 強化学習の場合などはデータが無い
return 'データなし'
text = ''
for input_data in input_data_config:
text += ' {channel_name}: {s3uri}\n'.format(
channel_name = input_data['ChannelName'],
s3uri = input_data['DataSource']['S3DataSource']['S3Uri']
)
return text
※ スクリプト中に学習の最後のメトリクスデータを表示できるような処理を含んでいますが、機能しません。Lambdaのランタイム環境のbotocoreのバージョンが低いため、describe_training_jobがFinalMetricDataListに対応していなかったためです。そのうち更新されるだろうという希望のもと、参考程度に残しています。(2019/1/22現在)
環境変数の設定
スクリプトで使用するパラメータを環境変数として設定します。
- BOT_USERNAME: Slackへ投稿時に表示される名前です。
- SLACK_CHANNEL: 投稿先チャンネル名です。
- SLACK_URL: Incoming Webhooks設定時に取得したエンドポイントURLです。このURLにデータをPOSTすることでSlackに投稿することができます。
トリガーの設定
次にLambda関数のトリガーを設定します。
S3の特定のバケットに接尾辞がoutput/model.tar.gz
のオブジェクトが保存されるとLambda関数が実行されるように設定し、イベントとして追加します。イベントの追加は対応させるバケットの数だけ必要になります。
イベントの追加が終われば必要な設定は終わりです。右上のSaveでLambda関数を保存します。
動作確認
これまでに動かした学習ジョブのモデルアーティファクトをダウンロードし、同じ場所へアップロードし直すと学習時を再現できます。 実際にやってみると、こんな感じでSlackへ投稿されます。
情報が不足していたり、多かったり、通知内容が見にくかったりすると思うので、必要に応じてスクリプトを修正してください〜。
さいごに
Amazon SageMakerで学習ジョブが完了した際に学習ジョブ情報をSlackに通知するbotを作ってみました。少々強引な方法になってしまいましたが、それっぽいものは出来ました。これでマネジメントコンソールに入らなくても、どういうジョブが行われたかが分かります。Slackへの通知部分を書き換えることで別のサービスとの連携も可能です。使い方によっては便利なものになるかもしれないですね。
お読みくださり、ありがとうございました〜!