Bedrock で画像を生成する仕組みを LINE ボットに組み込んでみた

生成系 AI サービスである Amazon Bedrock を LINE ボットから呼び出せるようにしたのですが、連携のために S3 や Cloud Watch Logs といった他のサービスへの対応が必要になった顛末記です。
2023.10.12

こんにちは、高崎@アノテーション です。

はじめに

生成系 AI サービスである Amazon Bedrock(以下 Bedrock)を手持ちの LINE ボット環境に組み込んで呼び出すようにしてみました。

Bedrock のモデルは画像生成を行う Stability AI を使用し、LINE ボットのメッセージから入力テキストをパラメータに渡して画像を生成してもらい、LINE ボットの応答に画像メッセージとして送信するような仕組みを考えております。

ベース

ソースについては例によって例のごとく、前回 LINE ボットで引用するようにした以下の記事をベースに組み込んでおります。

概要

実装方針

ベースが cdk を使った IaC 構造で組んでいるので、Bedrock を使えるようにする方法や、Lambda から Bedrock をどうやって呼び出すかを検討する必要がありますが、それぞれ下記のサイトや AWS のドキュメントを参考にしました。

  • cdk に組み込むやり方

  • Lambda から呼び出すやり方

画像の扱い方

ボットから画像を送信する場合、LINE の画像メッセージの説明 によると、画像は URL を指定する必要があるので、生成した画像を URL で取得出来る何処かしらに格納して、メッセージに組み込む必要があります。

そこで今回は、S3 に画像を格納することにしました。

S3 への格納ですが、当初はサンプルなので ACL を許可してパブリック接続で、と軽く考えていたのですが、S3 のアクセスポリシーのガイドライン を確認すると、

Amazon S3 の最新のユースケースの大部分では ACL を使用する必要がなくなっています。 オブジェクトごとに個別に制御する必要がある通常ではない状況を除き、ACL は無効にしておくことをお勧めします。

…とあったので、ACL はそのままに S3 に格納してから署名付き URL を取得して LINE に送るようにして、S3 に保存するファイル名は仮として送られてくるquoteTokenを使ってこれを使うことにします。

LINE ボットから Bedrock へ呼び出す方法について

コマンド I/F を追加し「ask:」が来ると Bedrock へ呼び出すようにしました。

コマンド 説明 フォーマット 応答
list 一覧表示 list 以降は何の文字が入っても良い 成功時「通し番号 : メモ」を連続出力
失敗時 エラーレスポンス
regist メモ登録 regist: 以降の文字をメモとして登録する 成功時「〜 の内容を登録しました」
失敗時 エラーレスポンス
delete 削除 delete:通し番号 で指定 成功時「番号 の削除が完了しました」
失敗時 エラーレスポンス
ask 画像生成 ask:依頼コマンド で指定 成功時 画像メッセージを生成して送信
失敗時 エラーレスポンス

実装

前項を踏まえて、実装です。

cdk スタック

cdk スタック定義ソースは下記になります。
※前回からの変更箇所をハイライトしております。

import { Duration, RemovalPolicy, Stack, StackProps } from "aws-cdk-lib";
import { Construct } from "constructs";
import { LayerVersion, Function, Runtime, Code } from "aws-cdk-lib/aws-lambda";
import { RestApi, LambdaIntegration } from "aws-cdk-lib/aws-apigateway";
import { Table, AttributeType } from "aws-cdk-lib/aws-dynamodb";
import { Secret } from "aws-cdk-lib/aws-secretsmanager";
import { Bucket } from "aws-cdk-lib/aws-s3"
import { ScopedAws } from "aws-cdk-lib";
import { PolicyStatement, Role, Effect, ServicePrincipal } from "aws-cdk-lib/aws-iam";

export class LineBotTestStack extends Stack {
  constructor(scope: Construct, id: string, props?: StackProps) {
    super(scope, id, props);

    // example resource
    // Lambda 関数の作成
    const roleMemoBot = new Role(this, "LineMemoRole", {
      assumedBy: new ServicePrincipal("lambda.amazonaws.com"),
    });
    const lambdaLayer = LayerVersion.fromLayerVersionArn(
      this,
      "lambdaLayer",
      "arn:aws:lambda:ap-northeast-1:133490724326:layer:AWS-Parameters-and-Secrets-Lambda-Extension:4"
    );
    const lambdaMemoBot = new Function(this, "LineMemoBot", {
      runtime: Runtime.NODEJS_18_X,
      handler: "index.handler",
      code: Code.fromAsset("src/lambda"),
      layers: [lambdaLayer],
      environment: {
        PARAMETERS_SECRETS_EXTENSION_HTTP_PORT: "2773",
        PARAMETERS_SECRETS_EXTENSION_CACHE_ENABLED: "true",
        SECRET_ID: "LineAccessInformation",
        TABLE_MAXIMUM_NUMBER_OF_RECORD: "5",
      },
      role: roleMemoBot,
      timeout: Duration.seconds(120),
    });

    // DynamoDB の作成
    const dynamoMemoBot = new Table(this, "dynamoMemoTable", {
      tableName: "LineMemoBot_memo",
      partitionKey: { name: "lineuserid", type: AttributeType.STRING },
      sortKey: { name: "id", type: AttributeType.NUMBER },
    });
    // 作ったテーブル名を Lamnbda の環境変数へセット
    lambdaMemoBot.addEnvironment("TABLE_NAME", dynamoMemoBot.tableName);

    // S3 の作成
    const s3MemoBot = new Bucket(this, "s3MemoBucket", {
      removalPolicy: RemovalPolicy.DESTROY,
      autoDeleteObjects: true,
    });
    // 作ったバケットを Lamnbda の環境変数へセット
    lambdaMemoBot.addEnvironment("BUCKET_NAME", s3MemoBot.bucketName);

    // 権限付与
    // Lambda -> Secrets Manager
    const { region, accountId } = new ScopedAws(this);
    const stringSecretName = "LineAccessInformation";
    const stringSecretArn = `arn:aws:secretsmanager:${region}:${accountId}:secret:${stringSecretName}-XXXXXX`;
    const smResource = Secret.fromSecretCompleteArn(
      this,
      "SecretsManager",
      stringSecretArn
    );
    smResource.grantRead(lambdaMemoBot);
    // Lambda -> DynamoDB
    dynamoMemoBot.grantReadWriteData(lambdaMemoBot);
    // Lambda -> S3
    /*s3MemoBot.grantPutAcl(lambdaMemoBot);*/
    s3MemoBot.grantReadWrite(lambdaMemoBot);
    // Lambda -> Bedrock
    const policyBedrock = new PolicyStatement({
      effect: Effect.ALLOW,
      actions: ["bedrock:InvokeModel"],
      resources: ["*"],
    });
    // Lambda -> Cloud Watch Logs
    lambdaMemoBot.addToRolePolicy(policyBedrock);
    const policyCloudWatch = new PolicyStatement({
      effect: Effect.ALLOW,
      actions: ["logs:CreateLogGroup", "logs:CreateLogStream", "logs:PutLogEvents",],
      resources: ["*"],
    });
    lambdaMemoBot.addToRolePolicy(policyCloudWatch);
    lambdaMemoBot.addEnvironment("BEDROCK_PARAM_CFG_SCALE", "10");
    lambdaMemoBot.addEnvironment("BEDROCK_PARAM_SEED", "0");
    lambdaMemoBot.addEnvironment("BEDROCK_PARAM_STEPS", "50");

    // API Gateway の作成
    const api = new RestApi(this, "LineMemoApi", {
      restApiName: "LineMemoApi",
    });
    // proxy ありで API Gateway に渡すインテグレーションを作成
    const lambdaInteg = new LambdaIntegration(lambdaMemoBot, {
      proxy: true,
    });
    // API Gateway の POST イベントと Lambda との紐付け
    api.root.addMethod("POST", lambdaInteg);
  }
}

変更ポイントは下記になります。

  • Lambda へロールを追加(Bedrock のアクセス権限を追加するために必要)
  • Bedrock へのアクセス権限をロールへ追加
  • cdk 生成時にデフォルトで自動で追加されていた Cloud Watch Logs のアクセス権限をマニュアルで追加
  • S3 バケットの作成と Lambda からのアクセス権限の追加
  • Lambda へ環境変数を追加(後述の Bedrock へ指定するパラメータを変更するため)

Lambda ソース

Lambda は下記になります。

Lambda ソース(400 行近いのでたたんでいますが変更箇所をハイライトしております)
import { Client, validateSignature,/* WebhookRequestBody,*/ LINE_SIGNATURE_HTTP_HEADER_NAME } from "@line/bot-sdk";
// import { Message } from "@line/bot-sdk/lib/types";
import { APIGatewayProxyResult, APIGatewayProxyEvent, Context, } from "aws-lambda";
import axios from "axios";
import { DynamoDBClient } from "@aws-sdk/client-dynamodb";
import { DynamoDBDocumentClient, QueryCommand, PutCommand, DeleteCommand, } from "@aws-sdk/lib-dynamodb";
import { BedrockRuntimeClient, InvokeModelCommand } from "@aws-sdk/client-bedrock-runtime";
import { S3Client, PutObjectCommand, GetObjectCommand } from "@aws-sdk/client-s3";
import { getSignedUrl } from "@aws-sdk/s3-request-presigner";

// Secrets Manager から取得するための諸々
const cacheEnabled = process.env.PARAMETERS_SECRETS_EXTENSION_CACHE_ENABLED || "false";
const requestCache: boolean = JSON.parse(cacheEnabled.toLowerCase());
const httpPort = process.env.PARAMETERS_SECRETS_EXTENSION_HTTP_PORT || "2773";
const requestSecretId = process.env.SECRET_ID || "MySecretId";
const requestEndpoint = `http://localhost:${httpPort}/secretsmanager/get?secretId=${requestSecretId}`;
const requestOptions = {
  headers: {
    "X-Aws-Parameters-Secrets-Token": requestCache ? process.env.AWS_SESSION_TOKEN : "",
  },
};

// DynamoDB へアクセスするための諸々
const numMaxRecord: number = Number(process.env.TABLE_MAXIMUM_NUMBER_OF_RECORD) || 5;
const stringTableName: string = process.env.TABLE_NAME || "linememo";
const clientDB = DynamoDBDocumentClient.from( new DynamoDBClient({ region: process.env.AWS_REGION }));

// Bedrock へアクセスするための諸々
const clientBedrock = new BedrockRuntimeClient({ region: "us-east-1" });
const paramCfgScale: number = Number(process.env.BEDROCK_PARAM_CFG_SCALE) || 10;
const paramSeed: number = Number(process.env.BEDROCK_PARAM_SEED) || 0;
const paramSteps: number = Number(process.env.BEDROCK_PARAM_STEPS) || 50;

// S3 へアクセスするための諸々
const clientS3 = new S3Client({ region: process.env.AWS_REGION });
const s3BucketName: string = process.env.BUCKET_NAME || "MyBucketName";

// レスポンス結果(200/成功、500/失敗)を固定で設定
const resultError: APIGatewayProxyResult = {
  statusCode: 500,
  body: "Error",
};
const resultOK: APIGatewayProxyResult = {
  statusCode: 200,
  body: "OK",
};

// DynamoDB から引数のユーザに関する最後に登録した通し番号を取得する
async function getUserTableLastCount(stringUser: string): Promise<number> {
  // 逆順で1つだけ取得
  const command = new QueryCommand({
    TableName: stringTableName,
    KeyConditionExpression: "lineuserid = :userid",
    ExpressionAttributeValues: { ":userid": stringUser },
    ScanIndexForward: false,
    Limit: 1,
  });
  try {
    const response = await clientDB.send(command);
    if (response.Count !== undefined) {
      // 一つもなかった場合は Items が空配列なのでゼロを返す
      if (response.Count === 0) {
        return 0;
      }
      // 値がある場合は id に何らかの値が入っているのでそれを返す
      if (response.Items !== undefined && response.Items[0].id !== undefined) {
        console.log("Items id : ", response.Items[0].id);
        return Number(response.Items[0].id);
      }
    }
    return -1;
  } catch (error) {
    console.error("getUserTableLastCount : ", error);
  }
  return -1;
}

// DynamoDB から引数のユーザに関するデータを全件取得する
interface getUserTableProc {
  stringUser: string;
  numCount?: number;
  boolForward?: boolean;
}
async function getUserTable( argProc: getUserTableProc): Promise<Record<string, any>[] | undefined> {
  // クエリーを発行して取得する
  const command = new QueryCommand({
    TableName: stringTableName,
    KeyConditionExpression: "#lineuserid = :userid",
    ExpressionAttributeNames: { "#lineuserid": "lineuserid" },
    ExpressionAttributeValues: { ":userid": argProc.stringUser },
    Limit: argProc.numCount,
    ScanIndexForward: argProc.boolForward,
  });
  try {
    const { Items: items = [] } = await clientDB.send(command);
    console.log(JSON.stringify(items));
    return items;
  } catch (error) {
    console.error("getUserTable : ", error);
    return undefined;
  }
}

// 共通 I/F 定義
interface ifCommand {
  result: boolean;
  text: string;
};
// Bedrock からデータを取得
const getAutoImage = async (stringParam: string): Promise<ifCommand> => {
  const invokeBedrock = new InvokeModelCommand({
    modelId: "stability.stable-diffusion-xl-v0",
    body: JSON.stringify({
      text_prompts: [{ text: stringParam }],
      cfg_scale: paramCfgScale,
      seed: paramSeed,
      steps: paramSteps,
    }),
    accept: "application/json",
    contentType: "application/json",
  });
  console.log(invokeBedrock);
  try {
    const responseBedrock = await clientBedrock.send(invokeBedrock);
    const jsonBedrock = JSON.parse(Buffer.from(responseBedrock.body).toString("utf-8"));
    console.log(jsonBedrock);
    if (jsonBedrock.result !== "success") {
      return {
        result: false,
        text: "Bedrock result was not success.",
      };
    }
    return {
      result: true,
      text: jsonBedrock.artifacts[0].base64,
    };
  } catch (e) {
    // コンソールにエラーを出しておく
    console.error(e);
    if (e instanceof Error) {
      return { result: false, text: `Bedrock 要求時にエラーが発生しました : ${e.message}` };
    }
    return { result: false, text: "Bedrock 要求時に失敗しました" };
  }
}

// 取得したデータを S3 へ保存する
const saveImageToS3 = async (strImageInput: string, quoteToken: string): Promise<ifCommand> => {
  console.log(strImageInput);
  const rawText = strImageInput.replace(/^data:\w+\/\w+;base64,/, "");
  const rawData = Buffer.from(rawText, "base64");
  const inputCommand = {
    /*ACL: "public-read",*/
    Body: rawData,
    Bucket: s3BucketName,
    Key: `${quoteToken}.png`,
    ContentType: "image/png",
  };
  const commandS3 = new PutObjectCommand(inputCommand);
  try {
    const responseS3 = await clientS3.send(commandS3);
    return { result: true, text: `${quoteToken}.png` };
  } catch (e) {
    // コンソールにエラーを出しておく
    console.error(e);
    if (e instanceof Error) {
      return { result: false, text: `S3 保存時にエラーが発生しました : ${e.message}` };
    }
    return { result: false, text: "S3 への保存に失敗しました" };
  }
} 
// S3 の URL を取得
const getS3URL = async (quoteToken: string): Promise<string> => {
  const getCommand = new GetObjectCommand({
    Bucket: s3BucketName,
    Key: `${quoteToken}.png`, 
  })
  return getSignedUrl(clientS3, getCommand);
}

// ask コマンドのケース
const askCommandUseCase = async (requestText: string, quoteToken: string): Promise<any> => {
  const stringCommand = requestText.replace("ask:", "");
  console.log("ask image : ", stringCommand);
  // Bedrock からデータを取得
  const { result: getAutoImageResult, text: getAutoImageText } = await getAutoImage(stringCommand);
  if (getAutoImageResult === false) {
    return { type: "text", text: getAutoImageText, quoteToken: quoteToken, };
  }
  // 取得したデータを S3 へ保存
  const { result: saveImageToS3Result, text: saveImageToS3Text } = await saveImageToS3(getAutoImageText, quoteToken); 
  if (saveImageToS3Result === false) {
    return { type: "text", text: saveImageToS3Text, quoteToken: quoteToken };
  }
  // S3 の URL を取得
  const stringS3URL = await getS3URL(quoteToken);
  console.log(stringS3URL);
  // 応答メッセージを返却
  return {
    type: "image",
    originalContentUrl: stringS3URL,
    previewImageUrl: stringS3URL,
  };
}

export const handler = async ( eventLambda: APIGatewayProxyEvent, contextLambda: Context): Promise<APIGatewayProxyResult> => {
  console.log(JSON.stringify(eventLambda));
  // Secrets Manager から値を取得
  const responseSM = await axios.get(requestEndpoint, requestOptions);
  const jsonSecret = JSON.parse(responseSM.data["SecretString"]);
  const clientLine = new Client({
    channelAccessToken: jsonSecret.ACCESS_TOKEN!,
    channelSecret: jsonSecret.CHANNEL_SECRET,
  });

  const stringSignature = eventLambda.headers[LINE_SIGNATURE_HTTP_HEADER_NAME];
  // Line の署名認証
  if (!validateSignature( eventLambda.body!, clientLine.config.channelSecret!, stringSignature!)) {
    // 署名検証がエラーの場合はログを出してエラー終了
    console.log("署名認証エラー", stringSignature!);
    return resultError;
  }
  // 文面の解析
  const bodyRequest = JSON.parse(eventLambda.body!);
  if (typeof bodyRequest.events[0] === "undefined") {
    // LINE Developer による Webhook の検証は events が空配列の body で来るのでその場合は 200 を返す
    console.log("Webhook inspection");
    return resultOK;
  }
  if (bodyRequest.events[0].type !== "message" || bodyRequest.events[0].message.type !== "text") {
    // text ではない場合は終了する
    console.log("本文がテキストではない", bodyRequest);
    return resultError;
  } else {
    // 要求メッセージを取得
    const requestText: string = bodyRequest.events[0].message.text;
    // 要求元 Line UserID を取得
    const requestUserId: string = bodyRequest.events[0].source.userId || "";
    // 応答メッセージが動的になるので初期化
    var messageReply: any[] = [];
    // 引用トークンを取得
    const quoteToken = bodyRequest.events[0].message.quoteToken;

    try {
      if (requestText.startsWith("regist:")) {
        const stringMemo = requestText.replace("regist:", "");
        // 保存されているデータで id の最大値を取得
        const numLastId: number = await getUserTableLastCount(requestUserId);
        if (numLastId === -1) {
          // 失敗したら登録せず throw して終了
          throw new Error("最大値取得に失敗");
        }
        // PutItem 発行
        await clientDB.send(
          new PutCommand({
            TableName: stringTableName,
            Item: {
              lineuserid: requestUserId,
              id: numLastId + 1,
              RegistTime: new Date().getTime(),
              memo: stringMemo,
            },
          })
        );
        // UserId のデータを全件取得
        const items = (await getUserTable({
          stringUser: requestUserId
        })) || [];
        // 5件以上あるものは削除(基本的には1件)
        items.splice(-numMaxRecord);
        await Promise.all(
          items.map((element) => {
            const commandDelete = new DeleteCommand({
              TableName: stringTableName,
              Key: {
                lineuserid: element["lineuserid"],
                id: element["id"],
              },
            });
            clientDB.send(commandDelete);
          })
        );
        // 応答メッセージを作成
        messageReply.push({
          type: "text",
          text: stringMemo + " の登録が完了しました",
          quoteToken: quoteToken,
        });
      } else if (requestText.startsWith("list")) {
        // UserId のデータを全件取得
        // 新しい順なので降順で取得する
        const items = (await getUserTable({
          stringUser: requestUserId,
          numCount: numMaxRecord,
          boolForward: false,
        })) || [];
        // 応答メッセージへセット
        items.forEach((responseItem: any, index: number) => {
          if (responseItem.memo !== undefined) {
            messageReply.push({
              type: "text",
              text: responseItem.id + " : " + responseItem.memo,
              quoteToken: (index === 0) ? quoteToken : "",
            });
          }
        });
        if (messageReply.length === 0) {
          // 一件も無い場合はレスポンスが空なので応答メッセージを別途設定
          messageReply.push({
            type: "text",
            text: "一覧が存在しません",
            quoteToken: quoteToken,
          });
        }
      } else if (requestText.startsWith("delete:")) {
        const stringIndex = requestText.replace("delete:", "");
        console.log("lineuserid : ", requestUserId, "id : ", stringIndex);
        // DeleteItem 発行
        const responseDelete = await clientDB.send(
          new DeleteCommand({
            TableName: stringTableName,
            Key: {
              lineuserid: requestUserId,
              id: Number(stringIndex),
            },
          })
        );
        console.log(JSON.stringify(responseDelete));
        // 応答メッセージをセット
        messageReply.push({
          type: "text",
          text: stringIndex + " の削除処理が完了しました",
          quoteToken: quoteToken,
        });
      } else if (requestText.startsWith("ask:")) {
        // ask コマンド用関数を呼び出して応答を messageReply へ push
        const replyAsk = await askCommandUseCase(requestText, quoteToken);
        console.log(replyAsk);
        messageReply.push(replyAsk);
      } else {
        // オウム返しする場合、1個の配列で応答メッセージをセット
        messageReply.push({
          type: "text",
          text: bodyRequest.events[0].message.text,
          quoteToken: quoteToken,
        });
      }
    } catch (e) {
      // コンソールにエラーを出しておく
      console.error(e);
      var stringReply = "エラーが発生しました";
      if (e instanceof Error) {
        stringReply = e.message;
      }
      // 応答メッセージをセット
      messageReply.push({
        type: "text",
        text: stringReply,
        quoteToken: quoteToken,
      });
    }
    // 応答メッセージ送信
    await clientLine.replyMessage(
      bodyRequest.events[0].replyToken,
      messageReply
    );
    // OK 返信をセット
    return resultOK;
  }
};

ポイントは下記です。

  • S3、Bedrock への呼び出しインスタンスを定義
  • ask コマンドのユースケースとして以下を実装
    • Bedrock へのメッセージを LINE から来た Input を元に送信
    • Bedrock からの戻り値を S3 へ格納
    • S3 の署名付き URL を取得要求
    • 応答メッセージに画像メッセージとパラメータに URL を指定

いざ実行

その前に

Bedrock の Stability AI を有効にする必要があります。

マネージメントコンソールの Bedrock においてEditを押下して、下記のモデルを有効にしてください。

実行結果

試しに色々と入れてみました。

学習が必要そうですが、日本語も入れられていい感じで見ることが出来ました。

おわりに

今回は自身が持っている LINE ボット環境に Bedrock の画像生成を LINE から要求して取得する機構を盛り込んでみました。

個人的には画像メッセージを base64 エンコードしたテキストでも送信出来れば S3 等の媒体を介さずに良かったのですが、色々と勉強になりました。

手持ちの環境の Lambda ソースが 400 行近くになり、色々とリファクタリングしたいので機会があれば記事にしたいと思います。

アノテーション株式会社について

アノテーション株式会社は、クラスメソッド社のグループ企業として「オペレーション・エクセレンス」を担える企業を目指してチャレンジを続けています。
「らしく働く、らしく生きる」のスローガンを掲げ、様々な背景をもつ多様なメンバーが自由度の高い働き方を通してお客様へサービスを提供し続けてきました。
現在当社では一緒に会社を盛り上げていただけるメンバーを募集中です。
少しでもご興味あれば、アノテーション株式会社WEBサイト をご覧ください。