Vertex AIのカスタムコンテナでバッチ推論をする

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

おはこんハロチャオ~!なにもんなんじゃ?じょんすみすです。

当エントリー『クラスメソッド 機械学習チーム アドベントカレンダー 2022』は3日目となります。

Vertex AIでカスタムコンテナを使って学習・推論を行う方法を以前紹介しました。

上記記事では、カスタムコンテを使ってエンドポイントを作成して推論を行っていましたが、 今回はバッチ推論を行う場合の方法を紹介していきます。

バッチ推論のための準備

カスタムコンテナを使ったバッチ推論をする場合には以下の2点を気にしておく必要があります。

  • AIP_HEALTH_ROUTE, AIP_PREDICT_ROUTEの設定
  • 入出力のフォーマットへの対応

これらを順にみてみましょう。

AIP_HEALTH_ROUTE, AIP_PREDICT_ROUTEの設定

エンドポイントを作成する際には必須ではありませんでしたが、 バッチ推論を行うにはModel Registryに登録するコンテナにこの値が設定されている必要があります。

この設定は、トレーニングジョブでの設定やモデルのアップロード時に行います。

トレーニングジョブの結果からモデルを直接Model Registryに登録する際は以下のようにジョブ作成指示に指定します。

job = aiplatform.CustomContainerTrainingJob(
    display_name="linear-regression",
    container_uri=TRAIN_CONTAINER_URI,
    command=["python", "train.py", "--target", "median_house_value"],
    model_serving_container_command=["python", "serve.py"],
    model_serving_container_image_uri=MODEL_CONTAINER_URI,
    # 以下2つの設定を追加
    model_serving_container_predict_route='/predict',
    model_serving_container_health_route='/health'
)

モデルをアップロードする場合もほぼ同様に引数で指定します。

model = aiplatform.Model.upload(
        display_name="linear-regression",
        serving_container_image_uri=MODEL_CONTAINER_URI,n
        artifact_uri="gs://<モデルファイルのあるCloud Storageのパス>",
        serving_container_command=["python", "serve.py"],
        # 以下2つの設定を追加
        serving_container_predict_route='/predict',
        serving_container_health_route='/health'
)

また、推論を行うコンテナ内でこの値を利用できるようにしておきます。

AIP_HEALTH_ROUTE = os.environ.get('AIP_HEALTH_ROUTE', '/health')
AIP_PREDICT_ROUTE = os.environ.get('AIP_PREDICT_ROUTE', '/predict')

@app.route(AIP_HEALTH_ROUTE, methods=['GET'])
def ping():
    ... # 死活監用の関数


@app.route(AIP_PREDICT_ROUTE, methods=['POST'])
def transformation():
    ... # 推論処理の実装

これで、バッチ推論ジョブとして実行可能な推論用コンテナとなりました。

入出力のフォーマットへの対応

バッチ推論ではCloud Storage上にあるファイルかBigQueryを入出力として扱えます。

Cloud Storage上のファイルを扱う場合はいくつかのフォーマットに対応してしてます。 それらはエンドポイントにリクエストを送る際の"instances"の中身として扱われます。

対応しているファイルは以下のものになっており、それらの渡され方を見ていきましょう。

  • JSONL
  • CSV
  • TFRecord(無圧縮, GZIP圧縮)
  • Cloud Storageのファイルリスト

JSONL形式は1行につき1つのデータをJSON形式で記述したものとなります。 そのため、"instances"の中身はそのまま渡されます。

例えば以下のような内容のファイルがあるとします。

[1, 2, 3]
[4, 5, 6]

この場合推論コンテナが受け取るのは以下のような形式になります。

{
    "instances": [
        [1, 2, 3]
        [4, 5, 6]
    ]
}

key-value形式の場合も同様です。

{"a": 1, "b": 2, "c": 3}
{"a": 4, "b": 5, "c": 6}

の場合、

{
    "instances": [
        {"a": 1, "b": 2, "c": 3}
        {"a": 4, "b": 5, "c": 6}
    ]
}

既に推論用コンテナを作成済みの時、ファイル側でリクエストに併せることも可能な形式ですね。

CSVの場合、ファイルのヘッダが必要となります。 ただし、推論コンテナに渡されるものはkey-value形式ではなく値の配列形式となります。 例えば以下のようなデータがあるとします。

a,b,c
1,2,3
4,5,6

必須であるヘッダには「a,b,c」が該当しますが、リクエストとして渡されるのは以下のようになります。

{
    "instances": [
        [1, 2, 3]
        [4, 5, 6]
    ]
}

TFRecordやファイルリストを記載したファイルの場合、 それらの中身のバイナリをBase64エンコードしたものが、"b64"をキーとして渡されます。

{
    "instances": [
        {"b64": "<バイナリファイルのBase64>"},
        {"b64": "<バイナリファイルのBase64>"}
    ]
}

最後にBigQueryをデータソースとする場合です。 こちらはCSVファイルと同様、対象となる列のkey-valueではなく値のみの配列となります。

利用するデータソースやファイルに応じて、 推論を行うコンテナのAIP_PREDICT_ROUTEに対応した関数でこれらを受け取って処理できる仕組みを実装しておく必要があります。

なお、出力に関しては"predictions"の中身がJSONL形式でそのまま出力されます。

バッチ推論のエラー確認

バッチ推論では、以下の操作をGoogle Cloud側で自動的に行ってくれます。

  • 内部でのエンドポイントの立ち上げ
  • 入力元から上記のような変換をリクエストを送信
  • レスポンスを受け取り出力先へ書き込む

これらは内部で行われているため、ユーザ側からは利用してるリソースに関する情報を見ることができません。

処理がうまくいってる分にはいいのですが、エラーが発生した際などは利用しているリソース側のログなどを見ることができないため、困ることがあります。 推論コンテナ内での処理に問題がある際には以下のような対応方法が考えられます。

  • エンドポイントを立ち上げてそこに対してリクエストを投げる
  • エラーが起きた際にはその内容を出力として返すようにする

入力に使っているファイルに問題がありそうな場合などは、以下のように後者の実装をして同じデータで再度実行してみると、出力としてエラー内容が確認できます。

@app.route(AIP_PREDICT_ROUTE, methods=['POST'])
def transformation():
    try:
        ... # 推論処理の実装
    except Exception as error:
        # エラーが発生した際はその内容をレスポンスとして返す
        response = {"error": f"{error}"}
        return flask.jsonify({'predictions': [response]})

また、前者のような一時的にエンドポイントを立てる場合には、常時課金が発生しますので、 確認後に削除し忘れないように気を付けましょう

おわりに

今回は、Vertex AIでカスタムコンテナを使ってバッチ推論を実施する方法を紹介しました。

エンドポイントを立てる場合にプラスアルファでバッチ推論に対応するための実装が必要になる部分もありますが、 課金の発生が処理中のみとなりますので、推論処理の使い方としてバッチ処理をする場合はこの方法も検討してみてください。