Amazon SageMakerの学習ジョブ完了通知bot(Slack)を作ってみた

こんにちは、大阪DI部の大澤です。

Amazon SageMakerの学習ジョブが完了した際にジョブの情報をSlackのチャンネルに投稿するbotを作ってみました。今回はその内容をご紹介します。

動機

  • 自分が回している学習ジョブの完了通知が欲しい
  • 他の人がどういう学習ジョブを回しているかを知りたい

仕組み

  1. 学習ジョブを開始します。
  2. 学習ジョブが完了し、モデルアーティファクトがs3://<bucket-name>/<prefix>/<job-name>/output/model.tar.gzに保存されます。
  3. output/model.tar.gzという接尾辞の条件を満たすキーのオブジェクトが保存されたことをトリガーとして、Lambda関数が実行されます。
  4. Lambda関数でモデルアーティファクトのオブジェクトキーに含まれる学習ジョブ名から学習ジョブ情報等を取得&整形し、SlackのIncoming Webhooksの専用URLへPOSTします。
  5. Slackに学習ジョブ情報が投稿されます。

※ 学習ジョブの完了をLambdaに通知する良い方法が思いつかなかったので、モデルアーティファクトの保存パスのルールを利用しました。なので、学習ジョブが失敗した時などモデルアーティファクトが保存されない場合には対応していません。
※ 全てのS3バケットに対応する訳ではありません。Lambda関数を実行するイベントトリガーを設定したS3バケットに対応します。複数のバケットに対応する場合はそれぞれイベントトリガーの設定が必要になります。

やってみる

Slack

SlackのワークスペースにIncoming Webhooksを追加し、メッセージを送信するために使用するエンドポイントURLを取得します。

Lambda関数の作成

次にマネジメントコンソールを開き、Lambda関数を作成します。 スクラッチを選択し、名前やランタイムや実行ロールを設定します。

  • ランタイムはPython3.7にします。
  • 実行ロールに今回は3つのポリシーAWSPriceListServiceFullAccessAWSLambdaBasicExecutionRoleAmazonSageMakerReadOnlyを付けています。
    • 実際にはもう少し権限は絞れますが、AWS管理ポリシーで楽をしました。

処理の紹介

Lambda関数で行う処理の流れは次のような感じです。

  • モデルアーティファクトが保存されたオブジェクトキーから学習ジョブ名を取得します。
    • <prefix>/<job名>/output/model.tar.gzという想定
  • 学習ジョブ情報等を取得し、投稿用テキストを作成します。
    • SageMakerのdescribe_training_jobで学習ジョブの詳細情報を取得し、Pricingのget_productsでインスタンスの価格情報を取得します。
    • 取得した情報をもとにステータスや概算費用、学習時間、インスタンスタイプ、入力データなどのデータを整形し、投稿用テキストを作成します。
  • Slack Incoming Webhooksの専用のエンドポイントURLに投稿データをPOSTします。
    • 投稿データには学習ジョブ情報を記載したメッセージテキスト、通知先チャンネル、投稿時の表示名が含まれます。

以下がそのスクリプトになります。

import os
import json
import urllib
import boto3
from datetime import timezone, timedelta

# パラメータを読み込む
BOT_USERNAME = os.environ['BOT_USERNAME'] 
SLACK_URL = os.environ['SLACK_URL'] 
SLACK_CHANNEL = os.environ['SLACK_CHANNEL'] 

# インスタンスタイプとリージョンに応じた価格取得時に指定する地域名(雑だけどもいい方法がないので、とりあえず辞書で対応)
REGION_MAP = {
    'us-gov-west-1' : 'AWS GovCloud (US)',
    'ap-south-1' : 'Asia Pacific (Mumbai)',
    'ap-northeast-2' : 'Asia Pacific (Seoul)',
    'ap-southeast-1' : 'Asia Pacific (Singapore)',
    'ap-southeast-2' : 'Asia Pacific (Sydney)',
    'ap-northeast-1' : 'Asia Pacific (Tokyo)',
    'ca-central-1' : 'Canada (Central)',
    'eu-central-1' : 'EU (Frankfurt)',
    'eu-west-1' : 'EU (Ireland)',
    'eu-west-2' : 'EU (London)',
    'us-east-1' : 'US East (N. Virginia)',
    'us-east-2' : 'US East (Ohio)',
    'us-west-1' : 'US West (N. California)',
    'us-west-2' : 'US West (Oregon)',
}

# JSTを定義する
JST = timezone(timedelta(hours=9), 'JST')


def lambda_handler(event, context):
    """Lambdaから呼ばれるコールバック関数"""
    object_key = event['Records'][0]['s3']['object']['key']
    
    # "<prefix>/<job名>/output/model.tar.gz"となる想定で、学習ジョブ名を取得する
    job_name = object_key.split('/')[-3]
    
    # バケットのあるリージョン
    region = event['Records'][0]['awsRegion']
    
    # ジョブ情報を作成
    job_info = create_training_job_info(job_name, region)
    
    # Slackにメッセージを送信する
    send_slack_message(job_info)
    
    return {
        'statusCode': 200,
        'body': json.dumps('Hello from Lambda!')
    }

def send_slack_message(text):
    """Slackにメッセージを送信する"""
    data = {
        'username':BOT_USERNAME, # 表示名
        'text':text, # 内容
        'channel': SLACK_CHANNEL # 送信先チャンネル
    }
    method = "POST"
    headers = {"Content-Type" : "application/json"}
    req = urllib.request.Request(SLACK_URL, method=method, data=json.dumps(data).encode(), headers=headers)
    with urllib.request.urlopen(req) as res:
        body = res.read()

    return body


def create_training_job_info(job_name, region):
    """ジョブ名とリージョンから学習ジョブ情報を作成する"""
    sm = boto3.client('sagemaker', region_name=region)
    
    # ジョブデータを取得する
    job_data = sm.describe_training_job(TrainingJobName=job_name)
    
    # ジョブ状態
    job_status = job_data['TrainingJobStatus']
    
    # インスタンスタイプ
    instance_type = job_data['ResourceConfig']['InstanceType']
    
    # ジョブ実行時間
    job_total_seconds = (job_data['TrainingEndTime'] - job_data['TrainingStartTime']).total_seconds()
    job_time_text = create_time_text(job_total_seconds)
    
    # インスタンス価格を取得する
    instance_price = get_instance_price(region, instance_type)
    
    # メトリクス(lambdaのbotocoreのバージョンが古いため、未対応。今後に期待)
    metric_data = create_metric_text(job_data['FinalMetricDataList']) if 'FinalMetricDataList' in job_data else 'データなし'

    # 入力データ
    input_data = create_input_data_text(job_data['InputDataConfig'])

    
    # 学習コストの概算
    job_cost = job_total_seconds / 3600 * float(instance_price)
    
    job_info = '''
ジョブ名: {job_name}
ステータス: {job_status}
概算費用(USD): ${job_cost:,.3f}
学習時間: {job_time}({job_start_time}〜{job_end_time})
インスタンスタイプ: {instance_type} ({instance_price:,.3f} USD/h)
モデルアーティファクト: {model_artifacts_uri}

入力データ:
{input_data}

メトリクス(未対応):
{metric_data}

ジョブ詳細: <https://{region}.console.aws.amazon.com/sagemaker/home?region={region}#/jobs/{job_name}>
    '''.format(
        region = region,
        job_status = job_status,
        job_cost = job_cost,
        job_name = job_name,
        instance_type = instance_type,
        instance_price = float(instance_price),
        job_start_time = job_data['TrainingStartTime'].astimezone(JST).strftime('%Y/%m/%d %H:%M:%S'),
        job_end_time = job_data['TrainingEndTime'].astimezone(JST).strftime('%Y/%m/%d %H:%M:%S'),
        job_time = job_time_text,
        input_data = input_data,
        model_artifacts_uri = job_data['ModelArtifacts']['S3ModelArtifacts'],
        metric_data = metric_data
    )
    return job_info


def get_first_element(dict_obj):
    """辞書の最初の要素を取得する"""
    
    return next(iter(dict_obj.values()))
    
def get_instance_price(region, instance_type):
    """リージョンとインスタンスタイプに応じた価格を取得する"""
    
    # 対応していないリージョンがあるため、決め打ちでus-east-1を使う
    pricing = boto3.client('pricing', region_name='us-east-1')
    

    # 対応するリージョンとインスタンスタイプの料金を取得する
    response = pricing.get_products(
        ServiceCode='AmazonSageMaker',
        Filters=[
            {
                'Type': 'TERM_MATCH',
                'Field': 'location',
                'Value': REGION_MAP[region]
            },
            {
                'Type': 'TERM_MATCH',
                'Field': 'instanceType',
                'Value': instance_type + '-Training'
            }
        ]
    )
    return get_first_element(get_first_element(json.loads(response['PriceList'][0])['terms']['OnDemand'])['priceDimensions'])['pricePerUnit']['USD']

def create_time_text(total_seconds):
    """秒数から読みやすい時間表記を作る"""

    days, remainder = divmod(total_seconds, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, seconds = divmod(remainder, 60)
    time_text = ''
    if days > 0:
        time_text += str(int(days))+'日'
    if hours > 0:
        time_text += str(int(hours))+'時間'
    if minutes > 0:
        time_text += str(int(minutes))+'分'

    time_text += str(int(seconds))+'秒'
    
    return time_text

def create_metric_text(metric_list):
    """メトリクスデータを作成する"""
    text = ''
    for metric_data in metric_list:
        text += '  {metric_name}: {metric_value}\n'.format(
            metric_name = metric_data['MetricName'],
            metric_value = metric_data['Value']
        )
    return text

def create_input_data_text(input_data_config):
    """入力データ情報を作成する"""

    if len(input_data_config) == 0:
        # 強化学習の場合などはデータが無い
        return 'データなし'

    text = ''
    for input_data in input_data_config:
        text += '  {channel_name}: {s3uri}\n'.format(
            channel_name = input_data['ChannelName'],
            s3uri = input_data['DataSource']['S3DataSource']['S3Uri']
        )
    return text


※ スクリプト中に学習の最後のメトリクスデータを表示できるような処理を含んでいますが、機能しません。Lambdaのランタイム環境のbotocoreのバージョンが低いため、describe_training_jobがFinalMetricDataListに対応していなかったためです。そのうち更新されるだろうという希望のもと、参考程度に残しています。(2019/1/22現在)

環境変数の設定

スクリプトで使用するパラメータを環境変数として設定します。

  • BOT_USERNAME: Slackへ投稿時に表示される名前です。
  • SLACK_CHANNEL: 投稿先チャンネル名です。
  • SLACK_URL: Incoming Webhooks設定時に取得したエンドポイントURLです。このURLにデータをPOSTすることでSlackに投稿することができます。

トリガーの設定

次にLambda関数のトリガーを設定します。 S3の特定のバケットに接尾辞がoutput/model.tar.gzのオブジェクトが保存されるとLambda関数が実行されるように設定し、イベントとして追加します。イベントの追加は対応させるバケットの数だけ必要になります。

イベントの追加が終われば必要な設定は終わりです。右上のSaveでLambda関数を保存します。

動作確認

これまでに動かした学習ジョブのモデルアーティファクトをダウンロードし、同じ場所へアップロードし直すと学習時を再現できます。 実際にやってみると、こんな感じでSlackへ投稿されます。

情報が不足していたり、多かったり、通知内容が見にくかったりすると思うので、必要に応じてスクリプトを修正してください〜。

さいごに

Amazon SageMakerで学習ジョブが完了した際に学習ジョブ情報をSlackに通知するbotを作ってみました。少々強引な方法になってしまいましたが、それっぽいものは出来ました。これでマネジメントコンソールに入らなくても、どういうジョブが行われたかが分かります。Slackへの通知部分を書き換えることで別のサービスとの連携も可能です。使い方によっては便利なものになるかもしれないですね。

お読みくださり、ありがとうございました〜!