[備忘録] EC2に立てた機械学習の推論サーバをライフサイクルフックでVPC LambdaからウォームアップしてからAuto Scalingで増やしてみた。

2022.05.17

せーのでございます。
最近、ちょっとややこしい構築をしたので備忘録として残しておこうと思います。

経緯

機械学習のモデルをEC2上にデプロイしFlask+uWSGI+nginxでAPIサーバ化しました(Sagemakerを使わなかったのはC6iインスタンスを使いたかったから)。
さて、こちらのモデルですが初回アクセス時にモデルがロードされ、また初回の推論のみやや時間がかかってしまうので、ウォームアップとしてLambdaから数回事前に推論を走らせておいてから運用したいと考えました。

そしてテストも終わり、本番運用に向けてアーキテクチャを考えることになりました。

リクエスト数から考えてこの推論サーバを3-4台は並べたいです。各AZに振り分けて並べ、上にALBを置いてロードバランスしてみましょう。
またEC2側は障害対策やスポットインスタンスの期限切れなども考え、Auto Scalingでくるむことにしました。ここらへんまでは至ってよくあるAWS的な構成かと思います。

考えたのはここからです。

このAPIサーバはnginxで受けているのでLambdaから叩くには80ポートを開けておく必要があります。ただ、Lambdaが設置されるIPレンジは広く、不定期に変更するため「Lambdaからのリクエストのみ受け付ける」ということができず、80ポートを公開している状態になっています。こうするとSPAM的なアクセスも結構くるようになるので、セキュリティ的にこの80ポートを塞いで、Lambdaからのみアクセスされるようにしたい

まず、これを解決するために、LambdaをVPCの中に入れ、同じセキュリティグループをつけて、同セキュリティグループ内の通信をすべて許可してみましょう。これで外からは叩かれずにLambdaからのみアクセスされるようになります。

次にAuto Scalingを組んでいるので、Golden AMIを通じて自動でサーバーは起動してALBにアタッチされるわけですが、その前にウォームアップ処理を追加して、モデルがロードされた状態でALBにアタッチしないと、初回リクエストに時間がかかってしまいます。
確かAuto Scalingには「ライフサイクルフック」という機能があって、EC2の起動や終了時に処理を挟むことができたはずです。私も昔ブログを書いた記憶があります。

に、2014年。。。絶対色々変わってるでしょうね。まあ、でもできるはずです。

と、ここまで頭で考えて、これを実装しなければいけないことに気が付きました。

なかなかのボリュームですので、ここにメモを残しながらやってみましょう。

やってみた

Auto Scaling自体はただ画面の指示にしたがってAMI、インスタンス、セキュリティグループなどを選んで起動設定を作り、Auto Scalingグループで台数を指定すればOKなので、特に迷いません。起動設定、Auto ScalingグループはそれぞれマネージメントコンソールのEC2のページにあります。

ライフサイクルフックの仕組みを理解する

2014から何がどれくらい変わっているのか、ブログを漁っていたところ最近ライフサイクルフックを構築した記事を見つけました。

このユースケースはスケールイン、つまりサーバが減る時に、ログを取ってから潰すというものですが、基本のライフサイクルフックの仕組みは変わりません。つまり

  • Auto Scaling グループにライフサイクルフックを設定して、インスタンスの起動時にイベントを発出
  • Amazon EventBridge(旧CloudWatch Event)でAuto Scalingのイベントを取得してLambdaを起動。この時インスタンスIDをLambdaに渡す。
  • Lambdaにて対象のインスタンスIDに対してboto3でPublicDNSを取得し、ウォームアップ処理を行う

という順番ですね。

Amazon EventBridgeの構築

まず、EventBridgeを作っていきましょう。

EventBridgeのページからルールを作成するボタンをクリックします。

名前と説明を入れ、「イベントパターンを持つルール」を選択して次へ。

イベントソースを「AWSイベント」とし、

EC2のLaunch(起動) Lifecycle Action時のサンプルイベントをメモ帳などにコピーしておきます。この内容がLambdaに投げられることになります。

イベントパターンにAuto Scalingの起動時のライフサイクルアクションを指定します。

ターゲットにウォームアップ処理を行うLambdaを指定して作成すればOKです。

これでEventBridgeはOKです。

Lambdaの改修

Lambdaのウォームアップ処理はrequestsモジュールを使ってAPIサーバを叩いています。APIサーバはuWSGIにてプロセスをコア数と同じ8まで増やしているので、モデルはそのコネクション分アクセスしてロードしておかなくてはいかないので、スレッドプールを使って同時にリクエストを投げるようにしています。

import json
import time
import os
import requests
from concurrent.futures import ThreadPoolExecutor

from aws_xray_sdk.core import patch
patch(['requests'])

def lambda_handler(event, context):
        
    endpoint_name = "http://ec2publicDNSname.com/predict"

    # Request.
    start = time.time()
    i = 0
    
    while i < 1:
        
        process_list = []
        with ThreadPoolExecutor() as executor:
            for i in range(15):
                process_list.append(executor.submit(child_process_func, endpoint_name))
            for process in process_list:
                print(process.result())
            
    t = time.time() - start
    

    return {
        'statusCode': 200,
        'body': json.dumps(t)
    }
    
def child_process_func(endpoint_name):
    i = 0
    filename = 'test/test.jpg'

    
    while i < 5:
        payload = { "image": open(filename, 'rb') }
        start = time.time()
        response = requests.post(
            endpoint_name,
            files = payload)
        t = time.time() - start
        print(str(os.getpid()) + ": " + str(t))
        i += 1
        
        response_dict = json.dumps(response.json(), indent=2)

        #print(response_dict)
    
    return os.getpid()

さきほどのEventBridgeでのサンプルイベントを確認してみましょう。

{
  "version": "0",
  "id": "6a7e8feb-b491-4cf7-a9f1-bf3703467718",
  "detail-type": "EC2 Instance-launch Lifecycle Action",
  "source": "aws.autoscaling",
  "account": "123456789012",
  "time": "2015-12-22T18:43:48Z",
  "region": "us-east-1",
  "resources": ["arn:aws:autoscaling:us-east-1:123456789012:autoScalingGroup:59fcbb81-bd02-485d-80ce-563ef5b237bf:autoScalingGroupName/sampleASG"],
  "detail": {
    "LifecycleActionToken": "c613620e-07e2-4ed2-a9e2-ef8258911ade",
    "AutoScalingGroupName": "sampleASG",
    "LifecycleHookName": "SampleLifecycleHook-12345",
    "EC2InstanceId": "i-12345678",
    "LifecycleTransition": "autoscaling:EC2_INSTANCE_LAUNCHING"
  }
}

起動しているインスタンスのIDは['detail']['EC2InstanceId']にあるようですので、ここを元にboto3でPublicDNSNameを取得してendpoint_nameを動的に取得するように改変します。

import json
import time
import os
import requests
from concurrent.futures import ThreadPoolExecutor
import boto3

from aws_xray_sdk.core import patch
patch(['requests'])

def lambda_handler(event, context):
    
    instanceID = [event["detail"]["EC2InstanceId"]]
    
    client = boto3.client('ec2')
    
    instance_public_ip = client.describe_instances(InstanceIds=instanceID)
    
    ec2_public_ip_address = instance_public_ip['Reservations'][0]['Instances'][0]['PublicDnsName']
    
    endpoint_name = "http://" + ec2_public_ip_address + "/predict"   

    # Request.
    start = time.time()
    i = 0
    
    while i < 1:
        
        process_list = []
        with ThreadPoolExecutor() as executor:
            for i in range(15):
                process_list.append(executor.submit(child_process_func, endpoint_name))
            for process in process_list:
                print(process.result())
            
    t = time.time() - start
    

    return {
        'statusCode': 200,
        'body': json.dumps(t)
    }
    
def child_process_func(endpoint_name):
    i = 0
    filename = 'test/test.jpg'

    
    while i < 5:
        payload = { "image": open(filename, 'rb') }
        start = time.time()
        response = requests.post(
            endpoint_name,
            files = payload)
        t = time.time() - start
        print(str(os.getpid()) + ": " + str(t))
        i += 1
        
        response_dict = json.dumps(response.json(), indent=2)

        #print(response_dict)
    
    return os.getpid()

しかし、このままLambdaを叩いてもタイムアウトまで何も返ってきません。

VPC Lambdaからboto3を通じてEC2を操作する場合、VPC Endpointを作成してルーティングを通してあげる必要があります。通信を一旦外に出してまた内部に戻してくるのはなんだか複雑な気持ちになりますが、とりあえず受け入れることにします。

IAM Roleの追加とVPC Endpointの作成

まず、Lambdaの実行でEC2の状態を取得できるようにReadonlyのポリシーを実行ロールに追加します。

次に、VPC Endpointを作成します。マネージメントコンソールのVPCページよりエンドポイントを作成します。

対象のAWSサービスはEC2(com.amazonaws.ap-northeast-1.ec2)を選択します。

あとはエンドポイントをつけるVPCとサブネットを選択して作成すれば適用されます。

これでboto3からEC2の情報を取得することができるようになりました。

ただ、ここにはもう一つ落とし穴があります。ライフサイクルフックのlaunch-actionは文字通り「起動」した瞬間に通知されます。インスタンスはステータスチェックをクリアして、StatusがEnabledにならないと使えないので、Lambda側でインスタンスがイニシャライズされるまで待つ必要があります。boto3にはwaiterというオブジェクトがあるので、それを利用して

waiter = client.get_waiter('instance_status_ok')
waiter.wait(InstanceIds=instanceID)

のように書けばインスタンスがイニシャライズされるまでこの行で待ち、StatusがEnbabledになったら次の行に進むようになります。こちらを先程のロジックに組み込みます。インスタンスの起動は3-4分なので、Lambdaの実行時間を5分くらいまで伸ばすのを忘れないようにしてください

import json
import time
import os
import requests
from concurrent.futures import ThreadPoolExecutor
import boto3

from aws_xray_sdk.core import patch
patch(['requests'])

def lambda_handler(event, context):
    
    instanceID = [event["detail"]["EC2InstanceId"]]
    
    client = boto3.client('ec2')
    
    instance_public_ip = client.describe_instances(InstanceIds=instanceID)
    
    ec2_public_ip_address = instance_public_ip['Reservations'][0]['Instances'][0]['PublicDnsName']
    
    endpoint_name = "http://" + ec2_public_ip_address + "/predict"
    print("endpoint_name")
    print(endpoint_name)
    
    # wait until target instances states is enabled
    waiter = client.get_waiter('instance_status_ok')
    print("waiting until status is OK...")
    waiter.wait(InstanceIds=instanceID)
    print("Instance: " + instanceID[0] + "is green status.")

    # Request.
    start = time.time()
    i = 0
    
    while i < 1:
        
        process_list = []
        with ThreadPoolExecutor() as executor:
            for i in range(15):
                process_list.append(executor.submit(child_process_func, endpoint_name))
            for process in process_list:
                print(process.result())
            
    t = time.time() - start
    

    return {
        'statusCode': 200,
        'body': json.dumps(t)
    }
    
def child_process_func(endpoint_name):
    i = 0
    filename = 'test/test.jpg'

    
    while i < 5:
        payload = { "image": open(filename, 'rb') }
        start = time.time()
        response = requests.post(
            endpoint_name,
            files = payload)
        t = time.time() - start
        print(str(os.getpid()) + ": " + str(t))
        i += 1
        
        response_dict = json.dumps(response.json(), indent=2)

        #print(response_dict)
    
    return os.getpid()

これでLambda側の準備はOKです。

ライフサイクルフックの設定

最後にAuto Scalingにライフサイクルフックの設定を追加します。Auto Scaling グループの「インスタンスの管理」から「ライフサイクルフックの作成」をクリックします。

名前を入れたら「インスタンス起動」を選択し、ハートビートを設定します。イニシャライズを待つのにLambdaで5分待っているので、同じ様にハートビートも5分(300秒)を設定します。結果はABANDON(放棄)でいいと思います。

これで設定は終了です!お疲れさまでした。

テスト

動作確認のためにテストします。Auto Scalingを0台から1台に増やしてみます。

インスタンスが立ち上がり、Statusが「Wait」に変わります。

インスタンスの状況を見てみると、まだ立ち上がり中で使える状態まではいっていません。

しかし起動はしているのでEventBridgeからLambdaに通知はいっています。Lambdaのログを見てみると、仕込んでおいたwaiterで止まっているのがわかります。

やがてステータスチェックが終わり、インスタンスがオールグリーンになると

waiterが切れてウォームアップ処理が始まりました。ざっと3分くらい待ちました。

ウォームアップ処理が終わるとAuto Scalingグループのステータスがハートビート終了後に「In Service」になり、準備完了となります。

最後にALBへのアタッチが終われば完了です。

バッチリです!!

まとめ

以上、今回はAuto ScalingのライフサイクルフックでVPC Lambdaからウォームアップ処理を行ってみました。
ユースケースとしてはいくつかの課題が重なっているので少し複雑ですが、要件は満たせたので安心しました。