[レポート] MLモデルのデプロイを自動化するライフサイクル Automated ML model development life cycle #ARC318 #reinvent

[レポート] MLモデルのデプロイを自動化するライフサイクル Automated ML model development life cycle #ARC318 #reinvent

Clock Icon2019.12.03

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

最初に

こんにちはデータアナリティクス事業本部のyoshimです。
今日はre:Invent2019にて行われた「Automated ML model development life cycle」という「Builders Session」の内容についてご紹介するエントリーを書こうと思います。

ワークショップ概要

本ワークショップの概要は下記の通りです。

The machine-learning workflow is an iterative cross-functional process. In this session, you integrate a machine-learning model into a data pipeline. Learn how to automate data preparation, feature engineering, and periodic model tuning. Then see how to incorporate your model into a production workload and monitor model performance. Finally, integrate the model into a continuous deployment system to complete the model development life cycle (MDLC). Please bring your laptop.

一応、Google翻訳したものも載せておきます。

機械学習のワークフローは、部門を超えた反復的なプロセスです。このセッションでは、機械学習モデルをデータパイプラインに統合します。データの準備、機能エンジニアリング、および定期的なモデル調整を自動化する方法を学びます。次に、モデルを運用ワークロードに組み込み、モデルのパフォーマンスを監視する方法を確認します。最後に、モデルを継続的な展開システムに統合して、モデル開発ライフサイクル(MDLC)を完了します。ラップトップを持参してください。

機械学習システムをプロダクションとして運用する際のTipsを得られないか、と思い参加してきました。

目次

1.概要

今回は「ビルダーズセッション」というものに初めて参加してみました。
一般参加者7名+メインスピーカ1名の合計8名が、1つのテーブルに座って、メインスピーカーが定義した課題に沿って一般参加者が開発をしながら疑問点について質問する、というものでした。
英語全然できないマンの私は、最初の自己紹介の時点で心が折れかけたのですが、なんとか最後までやりきることができましたので堂々とエントリーに書いてご紹介できます。

開発内容はタイトルの通り、「SageMakerのモデル更新を自動化する」ために「Lambda」と「SageMaker」の処理を「StepFunctions」を使って制御する、というものです。
具体的には、「モデルのトレーニング」、「モデルの性能評価」をする2つのステートマシーンを作成しました。
下記は最終的に作成するステートマシーンのイメージです。

・モデルのトレーニング

・モデルの性能評価

なお、セッション中はこちらのリポジトリを参考に作業を進めましたが、事前に環境中に用意されていたリソースが多いので、今回は「どういった考え方でどういうことをするのか」と「参考程度のスクリプトの記述」にとどめます。

2.処理内容とセッションで使ったコード

「トレーニング用」、「モデルの性能評価」のステートマシーンでそれぞれどんなことをしているのか、とサンプルコードを記述する前に、それぞれどんなことをしているのかを整理します。

トレーニング用ステートマシーン

このステートマシーンでは、SageMakerのトレーニングJOBを実行した後に、モデルの情報をDynamoDBに格納しています。
SageMakerのトレーニングJOBを実行するまではおきまりの流れかと思いますが、「DynamoDBにモデルの情報を格納している」といった点が特徴的です。
なんでこんなことをしているのかというと、「モデルの性能評価」用のステートマシーンでモデルの情報を参照する際に利用しています。

「トレーニング用ステートマシーン」、「モデルの性能評価用ステートマシーン」それぞれを「モデル名」変数をキーとして結びつけることで、それぞれ異なったタイミングでステートマシーンを実行しても必要な情報が取得できます。

今回のセッションでは「モデル名」とそれに付随して作成された「SageMakerのモデル名」、「モデル作成日時」をDynamoDBに格納していますが、ユースケースによって格納する値が変わってくるのかと思います。 (テスト用データを別に保持しておいて、そのパスを参照して「モデルの性能評価用ステートマシーンで評価する」等)

モデルの性能評価用ステートマシーン

SageMakerのバッチ変換機能を使って、モデルの評価をします。
評価の結果があらかじめ決めた閾値よりも悪いようなら「トレーニング用ステートマシーン」を再実行します。
シンプルですが、「運用が始まったらモデル更新も自動化したいけど、評価指標が悪いようなら更新したくない」といった場合には良さそうな仕組みです。

肝となるのは、後述する下記の2点です。

3.処理内容詳細

以下、「トレーニング用ステートマシーン」、「モデルの評価用ステートマシーン」で実施している処理内容についてもう少しだけ言及します。

3-1.トレーニング用のステートマシーンで何をしているのか

下記に各処理の詳細を箇条書きで記述しておきます。
目立った処理内容は特にないです。

  • Initialize
    • 「SageMakerが参照するパス」や「トレーニングJOBの情報」等を生成します。
  • Train Model
    • 「Initialize」のパラメータを引き継いでSageMakerのトレーニングJOBを実行
  • Create Model
    • トレーニングJOBの結果から、モデルを生成します。
  • Register Model
    • DynamoDBにモデルの情報を格納します。

下記は実際に利用したコードです。

Initialize

// node.js
const AWS = require("aws-sdk");
const s3 = new AWS.S3();

exports.handler = async (event, context, callback) => {
    console.log('Received event', JSON.stringify(event, null, 2));

    const initializeWorkflow = async (event) => {
        // get event parameters
        const modelName = event.WorkflowInput.ModelName;
        const trainingImage = event.WorkflowInput.TrainingImage;
        const s3RootBucket = event.S3RootBucket;
        const dataDate = DateUtil.getDateFromUtcString(event.WorkflowInput.DataDate);

        // generate training job name
        const runDate = new Date();
        const trainingJobName = `${modelName}-${DateUtil.formatDateTimeStringShort(runDate)}`;

        // generate training input folder
        const trainingInput = {
            S3Key: `input-data/${event.WorkflowInput.DataDate}/train`,
            S3Uri: `s3://${s3RootBucket}/input-data/${event.WorkflowInput.DataDate}/train`
        };

        // generate training validation folder
        const validationInput = {
            S3Key: `input-data/${event.WorkflowInput.DataDate}/validation`,
            S3Uri: `s3://${s3RootBucket}/input-data/${event.WorkflowInput.DataDate}/validation`
        };

        // generate training output folder
        const trainingOutput = {
            S3Key: `model/${event.WorkflowInput.DataDate}/${DateUtil.formatDateTimeStringShort(runDate)}`,
            S3Uri: `s3://${s3RootBucket}/model/${event.WorkflowInput.DataDate}/${DateUtil.formatDateTimeStringShort(runDate)}`
        }

        // generate training config
        const trainingConfig = {
            TrainingJobName: trainingJobName,
            TrainingInput: trainingInput,
            ValidationInput: validationInput,
            TrainingOutput: trainingOutput
        };

        // upload input request to S3
        const s3UploadParams = {
            Bucket: s3RootBucket,
            Key: 'workflow_request.json',
            Body: JSON.stringify(event, null, 2),
            ContentType: 'application/json'
        };
        await s3.putObject(s3UploadParams);

        // build return object
        const ret = {
            WorkflowRequest: event.WorkflowInput,
            S3RootBucket: s3RootBucket,
            DataDate: dataDate,
            RunDate: runDate,
            TrainingConfig: trainingConfig
        };

        return ret;
    };

    return initializeWorkflow(event).then((result) => {
        callback(null, result);
    });
};

class DateUtil {
    static getDateFromUtcString(dateString) {
        return new Date(dateString);
    }

    static formatDateString(date) {
        return date.toISOString().split('T')[0];
    }

    static formatDateTimeString(date) {
        return date.toISOString();
    }

    static formatDateStringShort(date) {
        const year = date.getUTCFullYear().toString();
        const month = date.getUTCMonth() + 1;
        const dayOfMonth = date.getUTCDate();
        const monthFull = month < 10 ? '0' + month : month;
        const dayOfMonthFull = dayOfMonth < 10 ? '0' + dayOfMonth : dayOfMonth;
        return year + monthFull + dayOfMonthFull;
    }

    static formatDateTimeStringShort(date) {
        const year = date.getUTCFullYear().toString();
        const month = date.getUTCMonth() + 1;
        const dayOfMonth = date.getUTCDate();
        const hours = date.getUTCHours();
        const minutes = date.getUTCMinutes();
        const seconds = date.getUTCSeconds();

        const monthFull = month < 10 ? '0' + month : month;
        const dayOfMonthFull = dayOfMonth < 10 ? '0' + dayOfMonth : dayOfMonth;
        const hoursFull = hours < 10 ? '0' + hours : hours;
        const minutesFull = minutes < 10 ? '0' + minutes : minutes;
        const secondsFull = seconds < 10 ? '0' + seconds : seconds;

        return year + monthFull + dayOfMonthFull + 'T' + hoursFull + minutesFull + secondsFull;
    }
};

Train Model

単純に、Stepfunctionsでarn:aws:states:::sagemaker:createTrainingJob.syncを記述していく。
後述する「ステートマシーンのスニペット」をご参照ください。

Create Model

単純に、Stepfunctionsでarn:aws:states:::sagemaker:createModelを記述していく。
後述する「ステートマシーンのスニペット」をご参照ください。

Register model

// node.js

const AWS = require("aws-sdk");
const dynamodb = new AWS.DynamoDB.DocumentClient({ apiVersion: '2012-08-10' });

exports.handler = async (event, context, callback) => {
    console.log('Received event', JSON.stringify(event, null, 2));

    const registerModel = async (event) => {
        // get event parameters
        const modelName = event.ModelName;
        const sageMakerModelName = event.SageMakerModelName;
        const timestamp = DateUtil.formatDateTimeString(
            DateUtil.getDateFromUtcString(event.Timestamp));

        // build item
        var params = {
            TableName: "MODEL_REGISTRY",
            Item: {
                "MODEL_NAME": modelName,
                "SAGEMAKER_MODEL_NAME": sageMakerModelName,
                "MODEL_TIMESTAMP": timestamp
            }
        };

        // put item
        try {
            await dynamodb.put(params).promise();
        } catch (error) {
            return {
                statusCode: 400,
                error: `Could not post: ${error.stack}`
            };
        }

        // build return object
        const ret = {
            ModelName: modelName,
            SageMakerModelName: sageMakerModelName,
            Timestamp: timestamp
        };

        return ret;
    };

    return registerModel(event).then((result) => {
        callback(null, result);
    });
};

class DateUtil {
    static getDateFromUtcString(dateString) {
        return new Date(dateString);
    }

    static formatDateString(date) {
        return date.toISOString().split('T')[0];
    }

    static formatDateTimeString(date) {
        return date.toISOString();
    }

    static formatDateStringShort(date) {
        const year = date.getUTCFullYear().toString();
        const month = date.getUTCMonth() + 1;
        const dayOfMonth = date.getUTCDate();
        const monthFull = month < 10 ? '0' + month : month;
        const dayOfMonthFull = dayOfMonth < 10 ? '0' + dayOfMonth : dayOfMonth;
        return year + monthFull + dayOfMonthFull;
    }

    static formatDateTimeStringShort(date) {
        const year = date.getUTCFullYear().toString();
        const month = date.getUTCMonth() + 1;
        const dayOfMonth = date.getUTCDate();
        const hours = date.getUTCHours();
        const minutes = date.getUTCMinutes();
        const seconds = date.getUTCSeconds();

        const monthFull = month < 10 ? '0' + month : month;
        const dayOfMonthFull = dayOfMonth < 10 ? '0' + dayOfMonth : dayOfMonth;
        const hoursFull = hours < 10 ? '0' + hours : hours;
        const minutesFull = minutes < 10 ? '0' + minutes : minutes;
        const secondsFull = seconds < 10 ? '0' + seconds : seconds;

        return year + monthFull + dayOfMonthFull + 'T' + hoursFull + minutesFull + secondsFull;
    }
};

3-2.モデルの性能評価用のステートマシーンで何をしているのか

このステートマシーンでは、「モデルの評価」をして「結果が悪いようなら再トレーニング」をしています。 下記に各処理の詳細を箇条書きで記述しておきます。

  • Initialize
    • 評価対象のモデルや、s3ファイルのパス情報等を生成する
  • Find Model
    • 「Initialize」のパラメータを引き継いでモデルの情報を取得
  • Inference
  • Monitor Model Performance
    • バッチ変換で得た結果から評価指標を取得する
  • Is Retraining Needed?
    • 評価指標値があらかじめ決めた閾値よりも悪い場合は「Re-train」、そうでない場合は「Finalize」に処理を分岐する
  • Re-train
    • 「トレーニング用のステートマシーン」を実行する

下記は実際に利用したコードです。

Initialize

// node.js

const AWS = require("aws-sdk");
const s3 = new AWS.S3();

exports.handler = async (event, context, callback) => {
    console.log('Received event', JSON.stringify(event, null, 2));

    const initializeWorkflow = async (event) => {
        // get event parameters
        const modelName = event.WorkflowInput.ModelName;
        const dataDate = DateUtil.getDateFromUtcString(event.WorkflowInput.DataDate);
        const s3RootBucket = event.S3RootBucket;

        // generate inference job name
        const runDate = new Date();
        const inferenceJobName = `${modelName}-${DateUtil.formatDateTimeStringShort(runDate)}`;

        // generate inference input folder
        const inferenceInput = {
            S3Key: `input-data/${event.WorkflowInput.DataDate}/holdout_data`,
            S3Uri: `s3://${s3RootBucket}/input-data/${event.WorkflowInput.DataDate}/holdout_data`
        };

        // generate inference output folder
        const inferenceOutput = {
            S3Key: `output-data/${event.WorkflowInput.DataDate}/${DateUtil.formatDateTimeStringShort(runDate)}`,
            S3Uri: `s3://${s3RootBucket}/output-data/${event.WorkflowInput.DataDate}/${DateUtil.formatDateTimeStringShort(runDate)}`
        };

        // generate inference config
        const inferenceConfig = {
            InferenceJobName: inferenceJobName,
            InferenceInput: inferenceInput,
            InferenceOutput: inferenceOutput
        };

        // generate holdout target input folder
        const monitorInput = {
            TargetS3Uri: `s3://${s3RootBucket}/input-data/${event.WorkflowInput.DataDate}/holdout_target/holdout_target.csv`,
            PredictionS3Uri: `s3://${s3RootBucket}/output-data/${event.WorkflowInput.DataDate}/${DateUtil.formatDateTimeStringShort(runDate)}/holdout_data.csv.out`
        }

        // generate monitor config
        const monitorConfig = {
            MonitorInput: monitorInput
        };

        // upload input request to S3
        const s3UploadParams = {
            Bucket: s3RootBucket,
            Key: 'workflow_request.json',
            Body: JSON.stringify(event, null, 2),
            ContentType: 'application/json'
        };
        await s3.putObject(s3UploadParams);

        // build return object
        const ret = {
            WorkflowRequest: event.WorkflowInput,
            S3RootBucket: s3RootBucket,
            DataDate: dataDate,
            RunDate: runDate,
            InferenceConfig: inferenceConfig,
            MonitorConfig: monitorConfig
        };

        return ret;
    };

    return initializeWorkflow(event).then((result) => {
        callback(null, result);
    });
};

class DateUtil {
    static getDateFromUtcString(dateString) {
        return new Date(dateString);
    }

    static formatDateString(date) {
        return date.toISOString().split('T')[0];
    }

    static formatDateTimeString(date) {
        return date.toISOString();
    }

    static formatDateStringShort(date) {
        const year = date.getUTCFullYear().toString();
        const month = date.getUTCMonth() + 1;
        const dayOfMonth = date.getUTCDate();
        const monthFull = month < 10 ? '0' + month : month;
        const dayOfMonthFull = dayOfMonth < 10 ? '0' + dayOfMonth : dayOfMonth;
        return year + monthFull + dayOfMonthFull;
    }

    static formatDateTimeStringShort(date) {
        const year = date.getUTCFullYear().toString();
        const month = date.getUTCMonth() + 1;
        const dayOfMonth = date.getUTCDate();
        const hours = date.getUTCHours();
        const minutes = date.getUTCMinutes();
        const seconds = date.getUTCSeconds();

        const monthFull = month < 10 ? '0' + month : month;
        const dayOfMonthFull = dayOfMonth < 10 ? '0' + dayOfMonth : dayOfMonth;
        const hoursFull = hours < 10 ? '0' + hours : hours;
        const minutesFull = minutes < 10 ? '0' + minutes : minutes;
        const secondsFull = seconds < 10 ? '0' + seconds : seconds;

        return year + monthFull + dayOfMonthFull + 'T' + hoursFull + minutesFull + secondsFull;
    }
};

Find Model

// node.js

const AWS = require("aws-sdk");
const dynamodb = new AWS.DynamoDB.DocumentClient({ apiVersion: '2012-08-10' });

exports.handler = async (event, context, callback) => {
    console.log('Received event', JSON.stringify(event, null, 2));

    const findModel = async (event) => {
        // get event parameters
        const modelName = event.ModelName;

        // build item
        var params = {
            TableName: "MODEL_REGISTRY",
            KeyConditionExpression: 'MODEL_NAME = :ModelName',
            ProjectionExpression: 'SAGEMAKER_MODEL_NAME',
            ExpressionAttributeValues: {
                ':ModelName': modelName
            },
            ConsistentRead: true,
            ScanIndexForward: false,
            Limit: 1
        };

        // put item
        let data;
        try {
            data = await dynamodb.query(params).promise();
        } catch (error) {
            return {
                statusCode: 400,
                error: `Could not post: ${error.stack}`
            };
        }

        if (data.Items.length == 0) {
            throw "Unable to find model";
        }

        // build return object
        const ret = {
            SageMakerModelName: data.Items[0]['SAGEMAKER_MODEL_NAME']
        };

        return ret;
    };

    return findModel(event).then((result) => {
        callback(null, result);
    });
};

class DateUtil {
    static getDateFromUtcString(dateString) {
        return new Date(dateString);
    }

    static formatDateString(date) {
        return date.toISOString().split('T')[0];
    }

    static formatDateTimeString(date) {
        return date.toISOString();
    }

    static formatDateStringShort(date) {
        const year = date.getUTCFullYear().toString();
        const month = date.getUTCMonth() + 1;
        const dayOfMonth = date.getUTCDate();
        const monthFull = month < 10 ? '0' + month : month;
        const dayOfMonthFull = dayOfMonth < 10 ? '0' + dayOfMonth : dayOfMonth;
        return year + monthFull + dayOfMonthFull;
    }

    static formatDateTimeStringShort(date) {
        const year = date.getUTCFullYear().toString();
        const month = date.getUTCMonth() + 1;
        const dayOfMonth = date.getUTCDate();
        const hours = date.getUTCHours();
        const minutes = date.getUTCMinutes();
        const seconds = date.getUTCSeconds();

        const monthFull = month < 10 ? '0' + month : month;
        const dayOfMonthFull = dayOfMonth < 10 ? '0' + dayOfMonth : dayOfMonth;
        const hoursFull = hours < 10 ? '0' + hours : hours;
        const minutesFull = minutes < 10 ? '0' + minutes : minutes;
        const secondsFull = seconds < 10 ? '0' + seconds : seconds;

        return year + monthFull + dayOfMonthFull + 'T' + hoursFull + minutesFull + secondsFull;
    }
};

Inference

「arn:aws:states:::sagemaker:createTransformJob.sync」でバッチ変換JOBを実行します。 後述する「ステートマシーンのスニペット」をご参照ください。

Monitor Model Performance

評価用データとバッチ変換結果のデータを読み込んで、評価指標を計算しています。

# python 2.7
import boto3
import os
import numpy as np
import pandas as pd

from sklearn.metrics import roc_auc_score


s3_client = boto3.client('s3')


def lambda_handler(event, context):

    # load the dataset date and s3 input path
    # dataset_date = event['DatasetDate']
    # run_date = event['RunDate']
    preduction_path = event['PredictionInput']
    target_path = event['TargetInput']

    # read the target and inference result
    target = pd.read_csv(target_path, header=None, names=['target'])
    prediction = pd.read_csv(
        preduction_path, header=None, names=['prediction'])

    # compute the monitoring result
    monitoring_result = roc_auc_score(target.values, prediction.values)

    # print('By {}, for the dataset date {}, the AUC is {}'.format(run_date, dataset_date, monitoring_result))

    return {'ModelPerformance': monitoring_result}

Is Retraining Needed?

choiceを使って、「Monitor Model Performance」の結果の値に応じて、StepFunctionsの中で次の処理をどうするかを分岐させます。
後述する「ステートマシーンのスニペット」をご参照ください。

Re-train

「トレーニング用のステートマシーン」をStepfunctionsの「arn:aws:states:::states:startExecution.sync」で実行します。
Manage AWS Step Functions Executions as an Integrated Service 後述する「ステートマシーンのスニペット」をご参照ください。

4.ステートマシーンのスニペット

今回利用したステートマシーンのスニペットも記述しておきます。
もし何かの参考になったら幸いです。

トレーニング用ステートマシーン

{
  "StartAt": "Initialize",
  "States": {
    "Initialize": {
      "Type": "Task",
      "Resource": "<「reinvent-mdlc-training-initialize-workflow」Lambda関数のARN>",
      "Parameters": {
        "WorkflowInput.$": "$",
        "S3RootBucket": "<利用するS3バケットのバケット名>"
      },
      "ResultPath": "$",
      "Next": "Train Model",
      "Catch": [
      {
        "ErrorEquals": ["States.ALL"],
        "Next": "Handle Error",
        "ResultPath": "$.Error"
      }
      ]
    },
    "Train Model": {
      "Type": "Task",
      "Resource": "arn:aws:states:::sagemaker:createTrainingJob.sync",
      "Parameters": {
        "TrainingJobName.$": "$.TrainingConfig.TrainingJobName",
        "AlgorithmSpecification": {
          "TrainingImage": "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:1",
          "TrainingInputMode": "File"
        },
        "RoleArn": "<SageMakerが利用するARN>",
        "HyperParameters": {
          "objective": "binary:logistic",
          "colsample_bytree": "0.7",
          "max_depth": "4",
          "eta": "0.2",
          "gamma": "4",
          "min_child_weight": "6",
          "subsample": "0.7",
          "learning_rate": "0.075",
          "silent": "0",
          "num_round": "200",
          "seed": "0"
        },
        "InputDataConfig": [
          {
            "ChannelName": "train",
            "DataSource": {
              "S3DataSource": {
                "S3DataType": "S3Prefix",
                "S3Uri.$": "$.TrainingConfig.TrainingInput.S3Uri",
                "S3DataDistributionType": "FullyReplicated"
              }
            },
            "CompressionType": "None",
            "RecordWrapperType": "None",
            "ContentType": "text/csv"
          },
          {
            "ChannelName": "validation",
            "DataSource": {
              "S3DataSource": {
                "S3DataType": "S3Prefix",
                "S3Uri.$": "$.TrainingConfig.ValidationInput.S3Uri",
                "S3DataDistributionType": "FullyReplicated"
              }
            },
            "CompressionType": "None",
            "RecordWrapperType": "None",
            "ContentType": "text/csv"
          }
        ],
        "OutputDataConfig": {
          "S3OutputPath.$": "$.TrainingConfig.TrainingOutput.S3Uri"
        },
        "ResourceConfig": {
          "InstanceCount": 1,
          "InstanceType": "ml.m5.4xlarge",
          "VolumeSizeInGB": 50
        },
        "StoppingCondition": {
          "MaxRuntimeInSeconds": 3600
        }
      },
      "ResultPath": "$.TrainingOutput",
      "Next": "Create Model",
      "Catch": [
      {
        "ErrorEquals": ["States.ALL"],
        "Next": "Handle Error",
        "ResultPath": "$.Error"
      }
      ]
    },
    "Create Model": {
      "Type": "Task",
      "Resource": "arn:aws:states:::sagemaker:createModel",
      "Parameters": {
        "ModelName.$": "$.TrainingConfig.TrainingJobName",
        "ExecutionRoleArn": "<SageMakerが利用するIAMロールのARN>",
        "PrimaryContainer": {
          "Image": "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:1",
          "ModelDataUrl.$": "$.TrainingOutput.ModelArtifacts.S3ModelArtifacts"
        }
      },
      "ResultPath": "$.CreateModelOutput",
      "Next": "Register Model",
      "Catch": [
      {
        "ErrorEquals": ["States.ALL"],
        "Next": "Handle Error",
        "ResultPath": "$.Error"
      }
      ]
    },
    "Register Model": {
      "Type": "Task",
      "Resource": "<reinvent-mdlc-training-register-model>Lambda関数のARN",
      "Parameters": {
        "ModelName.$": "$.WorkflowRequest.ModelName",
        "SageMakerModelName.$": "$.TrainingConfig.TrainingJobName",
        "Timestamp.$": "$.RunDate"
      },
      "ResultPath": "$.RegisterModelOutput",
      "Next": "Finalize",
      "Catch": [
      {
        "ErrorEquals": ["States.ALL"],
        "Next": "Handle Error",
        "ResultPath": "$.Error"
      }
      ]
    },
    "Finalize": {
      "Type": "Pass",
      "End": true
    },
    "Handle Error": {
      "Type": "Pass",
      "Next": "Failure"
    },
    "Failure": {
      "Type": "Fail"
    }
  }
}

  • モデルの性能評価用ステートマシーン
{
  "StartAt": "Initialize",
  "States": {
    "Initialize": {
      "Type": "Task",
      "Resource": "<「reinvent-mdlc-batch-inference-initialize-workflow」Lambda関数のARN>",
      "Parameters": {
        "WorkflowInput.$": "$",
        "S3RootBucket": "<利用するS3バケットのバケット名>"
      },
      "ResultPath": "$",
      "Next": "Find Model",
      "Catch": [
      {
        "ErrorEquals": ["States.ALL"],
        "Next": "Handle Error",
        "ResultPath": "$.Error"
      }
      ]
    },
    "Find Model": {
      "Type": "Task",
      "Resource": "<「function:reinvent-mdlc-batch-inference-find-model」Lambda関数のARN>",
      "Parameters": {
        "ModelName.$": "$.WorkflowRequest.ModelName"
      },
      "ResultPath": "$.FindModelOutput",
      "Next": "Inference",
      "Catch": [
      {
        "ErrorEquals": ["States.ALL"],
        "Next": "Handle Error",
        "ResultPath": "$.Error"
      }
      ]
    },
    "Inference": {
      "Type": "Task",
      "Resource": "arn:aws:states:::sagemaker:createTransformJob.sync",
      "Parameters": {
        "TransformJobName.$": "$.InferenceConfig.InferenceJobName",
        "ModelName.$": "$.FindModelOutput.SageMakerModelName",
        "TransformInput": {
          "DataSource": {
              "S3DataSource": {
                  "S3DataType": "S3Prefix",
                  "S3Uri.$": "$.InferenceConfig.InferenceInput.S3Uri"
              }
          },
          "ContentType": "text/csv",
          "SplitType": "Line",
          "CompressionType": "None"
        },
        "TransformOutput": {
          "S3OutputPath.$": "$.InferenceConfig.InferenceOutput.S3Uri",
          "Accept": "text/csv",
          "AssembleWith": "Line"
        },
        "TransformResources": {
          "InstanceType": "ml.c4.2xlarge",
          "InstanceCount": 10
        }
      },
      "ResultPath": "$.InferenceOutput",
      "Next": "Monitor Model Performance",
      "Catch": [
      {
        "ErrorEquals": ["States.ALL"],
        "Next": "Handle Error",
        "ResultPath": "$.Error"
      }
      ]
    },
    "Monitor Model Performance": {
      "Type": "Task",
      "Resource": "<「reinvent-mdlc-monitoring-model-performance」Lambda関数のArn>",
      "Parameters": {
        "PredictionInput.$": "$.MonitorConfig.MonitorInput.PredictionS3Uri",
        "TargetInput.$": "$.MonitorConfig.MonitorInput.TargetS3Uri"
      },
      "ResultPath": "$.MonitoringOutput",
      "Next": "Is Retraining Needed?",
      "Catch": [
      {
        "ErrorEquals": ["States.ALL"],
        "Next": "Handle Error",
        "ResultPath": "$.Error"
      }
      ]
    },
    "Is Retraining Needed?": {
      "Type": "Choice",
      "Choices": [{
        "Variable": "$.MonitoringOutput.ModelPerformance",
        "NumericLessThan": 1.0,
        "Next": "Re-train"
      }, {
        "Variable": "$.MonitoringOutput.ModelPerformance",
        "NumericGreaterThanEquals": 1.0,
        "Next": "Re-train"
      }],
      "Default": "Finalize"
    },
    "Re-train": {
      "Type": "Task",
      "Resource": "arn:aws:states:::states:startExecution.sync",
      "Parameters": {
        "Input": {
          "ModelName.$": "$.WorkflowRequest.ModelName",
          "DataDate.$": "$.WorkflowRequest.DataDate"
        },
        "StateMachineArn": "<「reinvent-mdlc-training-workflow」>ステートマシーンのARN"
      },
      "Next": "Finalize",
      "Catch": [
      {
        "ErrorEquals": ["States.ALL"],
        "Next": "Handle Error",
        "ResultPath": "$.Error"
      }
      ]
    },
    "Finalize": {
      "Type": "Pass",
      "End": true
    },
    "Handle Error": {
      "Type": "Pass",
      "Next": "Failure"
    },
    "Failure": {
      "Type": "Fail"
    }
  }
}

5.まとめ

今回、初めてのビルダーズセッション参加ということで、「英語力が無いと厳しいんじゃないか...」とかなりビビっており、最初の自己紹介で「やっちまったか...」と焦ったのですが、自己紹介の後は作業がメインだったのでなんとかなりました。
(それはそれで良くないけど)

セッションの中で取り上げる技術領域にある程度の事前知識は必要となりますが結構面白かったので、もしre:Inventに参加する予定がある方にはビルダーズセッションの参加をお勧めします。

この記事をシェアする

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.