Amazon BedrockにAWS SDK for Rust からアクセスする

2024.03.28

Introduction

先日、自動でコード修正するAgentをする記事を書きました。
ここではTypescriptを使ってBedrockにアクセスしていたのですが、
今回はRustでAmazon Bedrockにアクセスしてみます。

Environments

  • MacBook Pro (13-inch, M1, 2020)
  • OS : MacOS 14.3.1
  • Rust : 1.76.0
  • aws-cli : 2.15.32

※AWS アカウントは使用可能とします

Setup

Bedrockの準備はこのへんを参考に使えるようにしておきます。
次にCargoで適当なRustプロジェクトを作成して必要なcrateを追加しておきましょう。

% cargo new bedrock-app
% cd bedrock-app

% cargo add aws-config aws-sdk-bedrock aws-sdk-bedrockruntime
% cargo add serde
% cargo add serde-json

Cargo.tomlはこんな感じです。

[dependencies]
aws-config = { version = "1.1.8", features = ["behavior-version-latest"] }
aws-sdk-bedrock = "1.18.0"
aws-sdk-bedrockruntime = "1.18.0"
aws-smithy-runtime-api = "1.2.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.114"
tokio = { version = "1", features = ["full"] }

Try

ではBedrockへRustからアクセスしてみます。
main.rs を編集していきましょう。

必要なモジュールをimportしてmain関数でaws_configを作成します。
Claude3モデルを使いたい場合、現状ではus-east-1リージョンだけなので
それを指定します。
定義している構造体はbedrockに渡すための情報です。

use aws_config::meta::region::RegionProviderChain;
use aws_sdk_bedrock::error::SdkError;
use aws_sdk_bedrockruntime::operation::invoke_model::builders::InvokeModelFluentBuilder;
use aws_sdk_bedrockruntime::operation::invoke_model::{InvokeModelError, InvokeModelOutput};
use use aws_sdk_bedrockruntime::primitives::Blob;
use std::borrow::Cow;
use serde_json::Value;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug)]
struct Message {
    role: String,
    content: Vec<Content>,
}

#[derive(Serialize, Deserialize, Debug)]
struct Content {
    r#type: String,
    text: String,
}

#[derive(Serialize, Deserialize, Debug)]
struct Payload {
    anthropic_version: String,
    max_tokens: u32,
    messages: Vec<Message>,
    temperature: f32,
    top_p: f32,
}


#[tokio::main]
async fn main() {

    let region = Region::new("us-east-1");
    let config = aws_config::from_env().region(region).load().await;
    invoke_bedrock(&config, "Rust言語について100文字以内で簡単に教えて下さい。").await;
}

invoke_bedrock の第2引数が prompt です。
この関数ではクライアントを作成し、invoke_model で InvokeModelFluentBuilder を取得して
実行に必要な情報を設定します。
model_id には使いたいモデルの ID を設定します。
ドキュメントによると body と model_id さえ指定すれば動くみたいですが、
content_type を明示的に設定しないと動かなかったので設定します。
send 関数を実行することで実際に Bedrock に prompt を post します。

async fn invoke_bedrock(config: &aws_config::SdkConfig, prompt: &str) {

    let runtime = aws_sdk_bedrockruntime::Client::new(&config);
    let builder: InvokeModelFluentBuilder = runtime.invoke_model();

    //Payload情報を作成してBlog型のBodyを作成
    let payload = build_payload(prompt, 512);
    let payload_json = serde_json::to_vec(&payload).unwrap();
    let body: Blob = Blob::new(payload_json);

    let output = builder
        .model_id("anthropic.claude-3-sonnet-20240229-v1:0")
        .body(body)
        .content_type("application/json")
        .send()
        .await;

    handle_output(output);
}

fn build_payload(prompt: &str, max_tokens: u32) -> Payload {
    Payload {
        anthropic_version: "bedrock-2023-05-31".to_string(),
        max_tokens: max_tokens,
        messages: vec![Message {
            role: "user".to_string(),
            content: vec![Content {
                r#type: "text".to_string(),
                text: prompt.to_string(),
            }],
        }],
        temperature: 0.5,
        top_p: 0.9,
    }
}

あとは結果を表示するだけです。

fn handle_output(
    invoke_model_output: Result<
        InvokeModelOutput,
        SdkError<InvokeModelError, ::aws_smithy_runtime_api::client::orchestrator::HttpResponse>,
    >,
) {
    match invoke_model_output {
        Ok(output) => {
            let response_body = String::from_utf8_lossy(&output.body.as_ref());
            println!("Response: {}", response_body);
        }
        Err(err) => {
            eprintln!("Bedrock Error: {:?}", err);
        }
    }
}

実行すると下記のように表示されます。

% cargo run

Response: {"id":"msg_01XXXXXXXX","type":"message","role":"assistant",
"content":[{"type":"text","text":"Rustは、システムプログラミング言語で、メモリ安全性、並列性、
パフォーマンスに優れています。所有権の概念により、ランタイムのメモリ安全性を保証します。関数型とオブジェクト指向の特徴を併せ持ち、
コンパイル時に多くのエラーを検出できます。Mozillaによって開発され、
システムソフトウェア、Webブラウザエンジン、オペレーティングシステムなどに利用されています。"}],
"model":"claude-3-sonnet-28k-20240229","stop_reason":"end_turn",
"stop_sequence":null,"usage":{"input_tokens":29,"output_tokens":158}}

Stream

Stream でレスポンスを受け取りたい場合、invoke_model_with_response_stream 関数を使って builder を取得します。
あとはそのまま。

async fn invoke_bedrock_stream(config: &aws_config::SdkConfig, prompt: &str) {
    let runtime = aws_sdk_bedrockruntime::Client::new(&config);
    //Stream用
    let builder: InvokeModelWithResponseStreamFluentBuilder =
        runtime.invoke_model_with_response_stream();

    let payload = build_payload(prompt, 512);
    let payload_json = serde_json::to_vec(&payload).unwrap();
    let body: Blob = Blob::new(payload_json);

    let output = builder
        .model_id("anthropic.claude-3-sonnet-20240229-v1:0")
        .body(body)
        .content_type("application/json")
        .send()
        .await;

    handle_output_stream(output).await;
}

レスポンスの処理方法がさきほどと違います。
InvokeModelWithResponseStreamOutput の body をループしながら
recv()で順次 Stream の Chunk をうけとっていきます。

async fn handle_output_stream(
    invoke_model_output: Result<
        InvokeModelWithResponseStreamOutput,
        SdkError<InvokeModelWithResponseStreamError,
            ::aws_smithy_runtime_api::client::orchestrator::HttpResponse,
        >,
    >,
) {
    match invoke_model_output {
        Ok(output) => {
            let mut response_stream = output.body;
            loop {
                match response_stream.recv().await {
                    Ok(Some(aws_sdk_bedrockruntime::types::ResponseStream::Chunk(
                        payload_part,
                    ))) => {
                        if let Some(blob) = &payload_part.bytes {
                            let data: Cow<'_, str> = String::from_utf8_lossy(&blob.as_ref());
                            let value: Value = serde_json::from_str(&data).unwrap();
                            if value["type"] == "content_block_delta" {
                                if let Some(delta) = value["delta"].as_object() {
                                    if let Some(text) = delta["text"].as_str() {
                                        println!("{}", text);
                                    }
                                }
                            }
                        }
                    }
                    Err(err) => {
                        println!("Stream Error");
                    }
                    Ok(None) => {
                        println!("Stream End");
                        break;
                    }
                    Ok(Some(_)) => {
                        println!("other case");
                    }
                }
            }
        }
        Err(err) => {
            eprintln!("Bedrock Error: {:?}", err);
        }
    }
}

ちなみに、生成AIに実装方法を聞くと、しつこくrecvでなくnext(そんな関数はない)を呼べといってくる。

invoke_bedrock_stream を呼ぶように修正して実行すと、
Chunk ごとに文字が表示されていきます。

References