[Rust] Rocketでレスポンスのテキストがちょっとずつ表示されるやつを実装する[Amazon Bedrock]

2024.04.01

Introduction

ChatGPT とかで回答のテキストがタイピングされてるっぽく少しずつ表示されるアレです。
ここで Rust から Bedrock へのアクセスをしてみたので、
今度は Rocket でプロンプトを入力したら結果を Stream で表示するアプリをつくってみます。

Environments

  • MacBook Pro (13-inch, M1, 2020)
  • OS : MacOS 14.3.1
  • Rust : 1.76.0
  • aws-cli : 2.15.32

AWS アカウントは使用可能とします。

Setup

Bedrock の準備はこのへんを参考に使えるようにしておきます。

次に Cargo で Rust プロジェクトを作成して使用する crate を追加します。

% cargo new stream-web-app --bin
% cd stream-web-app

% cargo add aws-config aws-sdk-bedrock aws-sdk-bedrockruntime
% cargo add serde
% cargo add serde-json
・
・
・

使用した Cargo.toml は ↓ です。

[dependencies]
features = "0.10.0"
rocket = { version = "0.5.0", features = ["json"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.115"
tokio = { version = "1.21.2", features = ["full"] }
aws-config = { version = "1.1.8", features = ["behavior-version-latest"] }
aws-sdk-bedrock = "1.18.0"
aws-sdk-bedrockruntime = "1.18.0"
aws-smithy-runtime-api = "1.2.0"

Try

今回は html からプロンプトを入力して結果を取得します。
まずは static/index.html を作成して下記のようにボタンを押したら
Bedrock から回答をとってくるようにします。

<!DOCTYPE html>
<html>
  <head>
    <meta charset="UTF-8" />
    <title>Stream App</title>
  </head>
  <body>
    <textarea id="promptInput"></textarea>
    </br>
    <button id="streamBedrock">Stream Bedrok</button>
    </br>
    <div id="outputContainer"></div>
    <script>
      const streamBedrockButton = document.getElementById('streamBedrock');
      const outputContainer = document.getElementById('outputContainer');
      const promptInput = document.getElementById('promptInput');

      streamBedrockButton.addEventListener('click', async () => {

        const params = new URLSearchParams();
        params.append('my_prompt', promptInput.value);
        const response = await fetch(`/stream/bedrock?${params.toString()}`, {
            method: 'GET',
        });

        const reader = response.body.getReader();
        let content = '';
        while (true) {
          const { value, done } = await reader.read();
          if (done) break;

          const data = new TextDecoder().decode(value);
          content += data;
          //適当に改行
          if(data === "。") {
            content += "</br>";
          }
          outputContainer.innerHTML = content;
        }
      });
    </script>
  </body>
</html>

response.body.getReader()で ReadableStreamReader を取得して
read()でストリームデータを取得した後表示用に処理してます。

src/main.rs では Rocket で Stream レスポンスを返します。
Bedrock へアクセスする方法は以前の記事でやっているのとだいたい同じです。

レスポンスは Rocket のTextStream!マクロを使ってストリーミングレスポンスを生成しています。

use aws_sdk_bedrockruntime::{
    operation::invoke_model_with_response_stream::builders::InvokeModelWithResponseStreamFluentBuilder,
    primitives::Blob, types::ResponseStream::Chunk,
};

use rocket::{get, launch, response::stream::TextStream, routes};
use serde::{Deserialize, Serialize};
use serde_json::Value;

#[derive(Serialize, Deserialize, Debug)]
struct Message {
    role: String,
    content: Vec<Content>,
}

#[derive(Serialize, Deserialize, Debug)]
struct Content {
    r#type: String,
    text: String,
}

#[derive(Serialize, Deserialize, Debug)]
struct Payload {
    anthropic_version: String,
    max_tokens: u32,
    messages: Vec<Message>,
    temperature: f32,
    top_p: f32,
}

// クエリパラメータ `my_prompt` を受け取り、BedrockへアクセスしてStream処理
#[get("/bedrock?<my_prompt>")]
async fn stream_bedrok(my_prompt: &str) -> TextStream![String] {
    let region = aws_config::Region::new("us-east-1");
    let config = aws_config::from_env().region(region).load().await;

    let runtime = aws_sdk_bedrockruntime::Client::new(&config);
    let model: InvokeModelWithResponseStreamFluentBuilder =
        runtime.invoke_model_with_response_stream();

    // Payload作成
    let payload = build_payload(my_prompt, 2048);
    let payload_json = serde_json::to_vec(&payload).unwrap();
    let body: Blob = Blob::new(payload_json);

    // Stream用モデル実行
    let invoke_model_output = model
        .model_id("anthropic.claude-3-sonnet-20240229-v1:0")
        .body(body)
        .content_type("application/json")
        .send()
        .await;

    // Streamからテキストを抽出して出力する
    TextStream! {
        match invoke_model_output {
            Ok(output) => {
                let mut response_stream = output.body;
                loop {
                    match response_stream.recv().await {
                        // Streamデータをparse
                        Ok(Some(Chunk(payload_part))) => {
                            if let Some(blob) = &payload_part.bytes {
                                let value: Value =
                                    serde_json::from_str(&String::from_utf8_lossy(&blob.as_ref())).unwrap();
                                if value["type"] == "content_block_delta" {
                                    if let Some(delta) = value["delta"].as_object() {
                                        if let Some(text) = delta["text"].as_str() {
                                            yield text.to_string();
                                        }
                                    }
                                }
                            }
                        }
                        Err(err) => {
                            println!("Stream Error:{:?}",err);
                        }
                        // ストリームの終了
                        Ok(None) => {
                            println!("Stream End");
                            break;
                        }
                        Ok(Some(_)) => {
                            println!("What case?");
                        }
                    }
                }
            }
            Err(err) => {
                eprintln!("Bedrock Error: {:?}", err);
            }
        }
    }
}

// モデルに渡すペイロードを作成
fn build_payload(prompt: &str, max_tokens: u32) -> Payload {
    Payload {
        anthropic_version: "bedrock-2023-05-31".to_string(),
        max_tokens: max_tokens,
        messages: vec![Message {
            role: "user".to_string(),
            content: vec![Content {
                r#type: "text".to_string(),
                text: prompt.to_string(),
            }],
        }],
        temperature: 0.5,
        top_p: 0.9,
    }
}

#[launch]
fn rocket() -> _ {
    rocket::build()
        .mount("/stream", routes![stream_bedrok])
        .mount("/", rocket::fs::FileServer::from("static/"))
}

Bedrock からのデータを Stream で受け取り、
serde で parse して Chunk の文字を yield(TextStream で使える)で返しています。

cargo run で起動して、ブラウザで html にアクセスして実行した結果は ↓ のようになります。

% cargo run
🔧 Configured for debug.
   >> address: 127.0.0.1
   >> port: 8000
   >> workers: 8
   >> max blocking threads: 512
   >> ident: Rocket
   >> IP header: X-Real-IP
   >> limits: bytes = 8KiB, data-form = 2MiB, file = 1MiB, form = 32KiB, json = 1MiB, msgpack = 1MiB, string = 8KiB
   >> http/2: true
   >> keep-alive: 5s
   >> tls: disabled
   >> shutdown: ctrlc = true, force = true, signals = [SIGTERM], grace = 2s, mercy = 3s
   >> log level: normal
   >> cli colors: true
📬 Routes:
   >> (FileServer: static/) GET /<path..> [10]
   >> (stream_bedrok) GET /stream/bedrock?<my_prompt>
📡 Fairings:
   >> Shield (liftoff, response, singleton)
🛡️ Shield:
   >> X-Content-Type-Options: nosniff
   >> X-Frame-Options: SAMEORIGIN
   >> Permissions-Policy: interest-cohort=()
🚀 Rocket has launched from http://127.0.0.1:8000

stream-rocket

References