AWS Step FunctionsのMapステートを使って全リージョン並列実行する(AWS CDK)
単純に全リージョン処理を行いたいときはLambda内でループ処理すればいいんですが、処理時間を短くしたいケースがあったのでMapステートを使って並列実行してみました。
Mapステートについて詳しく知りたい方は、以下エントリーを読むとわかりやすいです。
AWS Step Functions の Map ステートの挙動を調べてみた。 | DevelopersIO
Mapステートで全リージョン実行するサンプル
作成したステートマシンです。
最初のステップで全リージョンを取得し、そのアウトプットをもとにMapで並列実行します。Map内の処理はリージョンをログに出力するようにしました。
CDK(TypeScript)
ステートマシンを定義している箇所は以下の通りです。
import { Construct } from 'constructs'; import { Stack, StackProps, Duration, aws_iam as iam, aws_lambda as lambda, aws_stepfunctions as sfn, aws_stepfunctions_tasks as tasks, } from 'aws-cdk-lib'; export class AwsCdkMultiRegionSfnStack extends Stack { constructor(scope: Construct, id: string, props?: StackProps) { super(scope, id, props); //IAMロール const lambdaRole = new iam.Role(this, "lambdaRole", { assumedBy: new iam.ServicePrincipal("lambda.amazonaws.com"), }); lambdaRole.addManagedPolicy( iam.ManagedPolicy.fromAwsManagedPolicyName("service-role/AWSLambdaBasicExecutionRole") ); lambdaRole.addManagedPolicy( iam.ManagedPolicy.fromAwsManagedPolicyName("AmazonEC2ReadOnlyAccess") ); //リージョン一覧を取得するLambda const createGetRegionLambdaFunction = new lambda.Function(this, 'createGetRegionLambdaFunction', { code: new lambda.AssetCode("lambda"), runtime: lambda.Runtime.PYTHON_3_9, functionName: "GetRegionLambdaFunction", handler: 'get_regions.handler', role: lambdaRole, timeout: Duration.seconds(180), }); //リージョンを出力するLambda const createPrintRegionLambdaFunction = new lambda.Function(this, 'createPrintRegionLambdaFunction', { code: new lambda.AssetCode("lambda"), runtime: lambda.Runtime.PYTHON_3_9, functionName: "PrintRegionLambdaFunction", handler: 'print_region.handler', role: lambdaRole, timeout: Duration.seconds(180), }); const getRegionsTask = new tasks.LambdaInvoke(this, 'getRegionsTask', { lambdaFunction: createGetRegionLambdaFunction, resultPath: '$.getRegions', }); const printRegionTask = new tasks.LambdaInvoke(this, 'printRegionTask', { lambdaFunction: createPrintRegionLambdaFunction, }); //MAP内エラー取得 const printRegionTaskError = new sfn.Pass(this, 'printRegionTaskError') printRegionTask.addCatch(printRegionTaskError, { resultPath: '$.printRegionTaskError' }) //Mapステート const printRegionMap = new sfn.Map(this, 'printRegionMap', { inputPath: "$.getRegions.Payload.Regions", resultPath: "$.PrintRegion" }); //MAPのイテレーターにRegion出力タスクを指定 printRegionMap.iterator(printRegionTask) //ステートマシン const definition = getRegionsTask.next(printRegionMap); new sfn.StateMachine(this, 'RegionMapStateMachine', { stateMachineName: "RegionMapStateMachine", definition: definition, }); } }
Lambda(Python)
Lambdaはリージョン取得と、出力用で2つを用意しました。
リージョン取得用(getRegionsTask)
リージョンを取得したあと、リストの形式だとステートマシン内で扱いにくいため、オブジェクト(辞書型)の形式にしてreturnしています。
import boto3 def handler(event, context): ec2 = boto3.client('ec2') regions = list(map(lambda x: x['RegionName'], ec2.describe_regions()['Regions'])) ret = [] for region in regions: ret.append( { "Region": region, } ) return {"Regions": ret}
リージョン出力用(printRegionTask)
こちらはインプットとしてRegionを取得してprintだけしています。
def handler(event, context): region = event.get("Region") print(region) return region
実行してみる
そのまま実行してみると、ステートマシンは正常終了します。
CloudWatch Logsを見ると、ほぼ同タイミングで実行されていることが確認できます。
ログを見ると各リージョンがインプットとして渡っていることが確認できました。
このリージョンごとの文字列を使うことで、Map内の処理をリージョンごとに実装できそうですね。
Mapステート内で処理が失敗した時
Mapステートデフォルトの動作として、並列実行している処理が1つでも失敗したときは処理が中断され、ステートマシンが失敗します。今回は全リージョン処理した上で失敗したリージョンを特定したかったので、Mapステート内にエラーをキャッチするステートを入れています。
参考:AWS Step FunctionsステートマシンのMapステート内でエラーをキャッチする(AWS CDK v2) | DevelopersIO
このエラーキャッチのおかげで、Mapステートが途中で失敗しなくなるため全リージョンの処理結果が確認できるようになりました。
注意点としては、この実装するとエラーキャッチした場合でもステートマシン自体は成功したことになる点です。失敗したか気づけるようにエラーキャッチ処理時に通知する実装を入れておきましょう。
おわりに
AWS Step FunctionsステートマシンのMapステートを使って全リージョンに実行できるタスクを作ってみました。全リージョンで並列化したいケースなどでご利用ください。