CloudWatch EventsとLambdaを使って、SageMaker の学習ジョブ終了通知bot(Slack)を作ってみた – 機械学習 on AWS Advent Calendar 2019

『機械学習 on AWS Advent Calendar 2019』の12日目

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

当エントリは『機械学習 on AWS Advent Calendar 2019』の12日目のエントリです。

今回は「Amazon CloudWatch EventsとLambdaを使って、SageMaker の学習ジョブ終了通知botを作ってみた」についてご紹介します。 re:Invent関連の何かにしようかと思ってたんですが、CloudWatch EventsがSageMakerに対応していることを見つけてテンションが上がり、この内容に変更しました。

以前、以下のエントリで同様の通知 bot を作成しました。

Amazon SageMakerの学習ジョブ完了通知bot(Slack)を作ってみた

しかし、その際には CloudWatch Events が SageMaker に対応していたかどうかを知らず、学習ジョブ完了時のモデルアーティファクトが S3 に保存される際の S3 イベントを利用して実装しました。 今回は、CloudWatch Events のSageMaker Training Job State Changeを利用して、学習ジョブの完了通知 bot を再実装したいと思います。また、前回はできなかった学習ジョブの失敗にも対応させます。

SageMaker Debuggerのルールを設定した学習ジョブを実行した際に一件の学習ジョブに対して複数の完了通知が投稿されるというバグが見つかったため、スクリプトを修正しました。(2020年1月10日)

SageMakerのイベントタイプ

CloudWatch Events が対応している SageMaker のイベントタイプは次の通りです。

  • SageMaker Training Job State Change
  • SageMaker Hyperparameter Tuning Job State Change
  • SageMaker Transform Job State Change
  • SageMaker Endpoint Config State Change
  • SageMaker Endpoint State Change
  • SageMaker Model State Change
  • SageMaker Notebook Instance State Change
  • SageMaker Notebook Lifecycle Config State Change
  • SageMaker Algorithm State Change
  • SageMaker Model Package State Change

これら以外のイベントについてもAWS API Call via CloudTrailを用いることで、List/Get/Describe以外の API 操作であればイベントとして利用できます。

SageMaker Training Job State Changeの場合だと、イベントデータは次のようになります。

{
    "version": "0",
    "id": "844e2571-85d4-695f-b930-0153b71dcb42",
    "detail-type": "SageMaker Training Job State Change",
    "source": "aws.sagemaker",
    "account": "123456789012",
    "time": "2018-10-06T12:26:13Z",
    "region": "us-east-1",
    "resources": [
        "arn:aws:sagemaker:us-east-1:123456789012:training-job/kmeans-1"
    ],
    "detail": {
        "TrainingJobName": "89c96cc8-dded-4739-afcc-6f1dc936701d",
        "TrainingJobArn": "arn:aws:sagemaker:us-east-1:123456789012:training-job/kmeans-1",
        "TrainingJobStatus": "Completed",
        "SecondaryStatus": "Completed",
        "HyperParameters": {
            "Hyper": "Parameters"
        },
        "AlgorithmSpecification": {
            "TrainingImage": "TrainingImage",
            "TrainingInputMode": "TrainingInputMode"
        },
        "RoleArn": "a little teapot, some little teapot",
        "InputDataConfig": [
            {
                "ChannelName": "Train",
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3DataType",
                        "S3Uri": "S3Uri",
                        "S3DataDistributionType": "S3DataDistributionType"
                    }
                },
                "ContentType": "ContentType",
                "CompressionType": "CompressionType",
                "RecordWrapperType": "RecordWrapperType"
            }
        ],
        "OutputDataConfig": {
            "KmsKeyId": "KmsKeyId",
            "S3OutputPath": "S3OutputPath"
        },
        "ResourceConfig": {
            "InstanceType": "InstanceType",
            "InstanceCount": 3,
            "VolumeSizeInGB": 20,
            "VolumeKmsKeyId": "VolumeKmsKeyId"
        },
        "VpcConfig": {
            
        },
        "StoppingCondition": {
            "MaxRuntimeInSeconds": 60
        },
        "CreationTime": "2018-10-06T12:26:13Z",
        "TrainingStartTime": "2018-10-06T12:26:13Z",
        "TrainingEndTime": "2018-10-06T12:26:13Z",
        "LastModifiedTime": "2018-10-06T12:26:13Z",
        "SecondaryStatusTransitions": [
            
        ],
        "Tags": {
            
        }
    }
}

やってみる

では、実際に学習ジョブの終了通知botを作成してみます。

Slack

SlackのワークスペースにIncoming Webhooksを追加し、メッセージを送信するために使用するWebhook URLを取得します。

Lambda関数の作成

次にLambda関数を作成します。ランタイムはPython 3.8を選択します。

スクリプト

Lambda関数の編集画面に移動したら、次は学習ジョブのステータスが変更された際に実行するスクリプトを入力します。 スクリプトは学習ジョブ情報を取得し、その情報を加工したものをSlackに通知するという内容です。

※使用したコンテナイメージ名など欲しい情報は色々あるかと思います。必要に応じて、スクリプトを修正してご利用ください。

import os
import json
import urllib
import boto3
from datetime import timezone, timedelta
 
# パラメータを読み込む
BOT_USERNAME = os.environ['BOT_USERNAME'] 
SLACK_URL = os.environ['SLACK_URL'] 
SLACK_CHANNEL = os.environ['SLACK_CHANNEL'] 
 
# インスタンスタイプとリージョンに応じた価格取得時に指定する地域名(雑だけどもいい方法がないので、とりあえず辞書で対応)
REGION_MAP = {
    'us-gov-west-1' : 'AWS GovCloud (US)',
    'ap-south-1' : 'Asia Pacific (Mumbai)',
    'ap-northeast-2' : 'Asia Pacific (Seoul)',
    'ap-southeast-1' : 'Asia Pacific (Singapore)',
    'ap-southeast-2' : 'Asia Pacific (Sydney)',
    'ap-northeast-1' : 'Asia Pacific (Tokyo)',
    'ca-central-1' : 'Canada (Central)',
    'eu-central-1' : 'EU (Frankfurt)',
    'eu-west-1' : 'EU (Ireland)',
    'eu-west-2' : 'EU (London)',
    'us-east-1' : 'US East (N. Virginia)',
    'us-east-2' : 'US East (Ohio)',
    'us-west-1' : 'US West (N. California)',
    'us-west-2' : 'US West (Oregon)',
}
 
# JSTを定義する
JST = timezone(timedelta(hours=9), 'JST')
 
 
def lambda_handler(event, context):
    """Lambdaから呼ばれるコールバック関数"""
    job_status = event['detail']['TrainingJobStatus']

    if job_status not in ['Completed', 'Failed', 'Stopped']:
      return {
        'statusCode': 200,
        'body': json.dumps('Nothing to do')
      }
      
    # 学習ジョブに紐づいたSageMaker Debuggerのルール評価ジョブのステータス変化によってもイベントは発火されるので、完了時以外は無視する
    eval_statuses = event['detail'].get("DebugRuleEvaluationStatuses", None)
    if eval_statuses and len(eval_statuses) > 0:
        for status in eval_statuses:
            if status['RuleEvaluationStatus'] in ['InProgress', 'Stopping']:
                return {
                    'statusCode': 200,
                    'body': json.dumps('Nothing to do')
                }

    job_name = event['detail']['TrainingJobName']
     
    # ジョブ情報を作成
    job_info = create_training_job_info(job_name)
     
    # Slackにメッセージを送信する
    send_slack_message(job_info)
     
    return {
        'statusCode': 200,
        'body': json.dumps('Done')
    }
 
def send_slack_message(text):
    """Slackにメッセージを送信する"""
    data = {
        'username':BOT_USERNAME, # 表示名
        'text':text, # 内容
        'channel': SLACK_CHANNEL # 送信先チャンネル
    }
    method = "POST"
    headers = {"Content-Type" : "application/json"}
    req = urllib.request.Request(SLACK_URL, method=method, data=json.dumps(data).encode(), headers=headers)
    with urllib.request.urlopen(req) as res:
        body = res.read()
 
    return body
 
 
def create_training_job_info(job_name):
    """ジョブ名とリージョンから学習ジョブ情報を作成する"""
    sess = boto3.Session()
    sm = sess.client('sagemaker')
    region = sess.region_name
     
    # ジョブデータを取得する
    job_data = sm.describe_training_job(TrainingJobName=job_name)

    # インスタンスタイプ
    instance_type = job_data['ResourceConfig']['InstanceType']
     
    # ジョブ実行時間
    job_total_seconds = job_data['TrainingTimeInSeconds']
    job_billable_seconds = job_data.get('BillableTimeInSeconds')  # スポットインスタンス利用を考慮した課金対象時間
    if not job_billable_seconds:
        job_billable_seconds = job_total_seconds
     
    # インスタンス価格を取得する
    instance_price = get_instance_price(region, instance_type)

     
    job_info = '''
[*{job_status}*] *{job_name}*

概算費用(USD): ${job_cost:,.3f}
学習時間: {job_time}({job_start_time}〜{job_end_time})
課金時間: {job_billable_time_text}
インスタンスタイプ: {instance_type} ({instance_price:,.3f} USD/h)
モデルアーティファクト: {model_artifacts_uri}

入力データ:
{input_data}
メトリクス:
{metric_data}
ジョブ詳細: <https://{region}.console.aws.amazon.com/sagemaker/home?region={region}#/jobs/{job_name}>

ステータス: {job_status_text}
    '''.format(
        region = region,
        job_status = job_data['TrainingJobStatus'],
        job_status_text = create_job_status_text(job_data['TrainingJobStatus'], job_data['SecondaryStatus'], job_data.get('FailureReason', None)),
        job_cost = job_billable_seconds / 3600 * float(instance_price),
        job_name = job_name,
        instance_type = instance_type,
        instance_price = float(instance_price),
        job_start_time = job_data['TrainingStartTime'].astimezone(JST).strftime('%Y/%m/%d %H:%M:%S'),
        job_end_time = job_data['TrainingEndTime'].astimezone(JST).strftime('%Y/%m/%d %H:%M:%S'),
        job_time = create_time_text(job_total_seconds),
        input_data = create_input_data_text(job_data['InputDataConfig']),
        model_artifacts_uri = job_data['ModelArtifacts']['S3ModelArtifacts'],
        metric_data = create_metric_text(job_data['FinalMetricDataList']) if 'FinalMetricDataList' in job_data else 'データなし\n',
        job_billable_time_text = create_time_text(job_billable_seconds)
    )
    return job_info
 
 
def get_first_element(dict_obj):
    """辞書の最初の要素を取得する"""
     
    return next(iter(dict_obj.values()))
     
def get_instance_price(region, instance_type):
    """リージョンとインスタンスタイプに応じた価格を取得する"""
     
    # 対応していないリージョンがあるため、決め打ちでus-east-1を使う
    pricing = boto3.client('pricing', region_name='us-east-1')
     
 
    # 対応するリージョンとインスタンスタイプの料金を取得する
    response = pricing.get_products(
        ServiceCode='AmazonSageMaker',
        Filters=[
            {
                'Type': 'TERM_MATCH',
                'Field': 'location',
                'Value': REGION_MAP[region]
            },
            {
                'Type': 'TERM_MATCH',
                'Field': 'instanceType',
                'Value': instance_type + '-Training'
            }
        ]
    )
    return get_first_element(get_first_element(json.loads(response['PriceList'][0])['terms']['OnDemand'])['priceDimensions'])['pricePerUnit']['USD']
 
def create_job_status_text(status, secondary_status, failed_reason):
    """学習ジョブのステータス情報"""
    if status == 'Completed':
        return status
    elif status == 'Failed':
        return f'{status}\n```{failed_reason}```'

    return f'{status}({secondary_status})'

def create_time_text(total_seconds):
    """秒数から読みやすい時間表記を作る"""
 
    days, remainder = divmod(total_seconds, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, seconds = divmod(remainder, 60)
    time_text = ''
    if days > 0:
        time_text += str(int(days))+'日'
    if hours > 0:
        time_text += str(int(hours))+'時間'
    if minutes > 0:
        time_text += str(int(minutes))+'分'
 
    time_text += str(int(seconds))+'秒'
     
    return time_text
 
def create_metric_text(metric_list):
    """メトリクスデータを作成する"""
    text = ''
    for metric_data in metric_list:
        text += '  {metric_name}: {metric_value}\n'.format(
            metric_name = metric_data['MetricName'],
            metric_value = metric_data['Value']
        )
    return text
 
def create_input_data_text(input_data_config):
    """入力データ情報を作成する"""
 
    if len(input_data_config) == 0:
        # 強化学習の場合などはデータが無い
        return 'データなし'
 
    text = ''
    for input_data in input_data_config:
        text += '  {channel_name}: {s3uri}\n'.format(
            channel_name = input_data['ChannelName'],
            s3uri = input_data['DataSource']['S3DataSource']['S3Uri']
        )
    return text

上のスクリプトをFunction codeとして入力します。

環境変数としてSlackのIncoming webhookの通知先URLやチャンネル名等を設定します。

CloudWatch Eventsを使ったトリガー作成

ここで一旦Lambda関数はSaveし、CloudWatch Eventsでのトリガー作成に移ります。 どういうイベントが来た時にLambda関数を実行するかのルールを作成します。

イベントタイプはSageMaker Training Job State Changeを利用します。

ターゲットは先ほど作成したLambda関数を指定します。

次に名前や説明を入力し、ルールの作成を完了します。

改めてLambda関数の画面に戻ると、トリガーとしてCloudWatch Eventsが設定されています。

確認

次のようなテスト用のイベントを作成し、動作テストしてみます。 ジョブ名は実際に存在する学習ジョブ名を指定する必要があります。

テストを実行すると、、、

Slackに通知がきました。 学習ジョブが正常に終了した場合は次のように通知が来ます。

さいごに

『機械学習 on AWS Advent Calendar 2019』 12日目として、CloudWatch Events と Lambda を利用した、SageMaker の学習ジョブの終了通知botの作成についてお伝えしました。CloudWatch Eventsを利用することで、SageMaker上で様々な機械学習ワークフローを構築することができそうです。今後も色々試してその内容をご紹介していきたいと思います。

参考