[新機能]Snowflakeでテキストデータに対するRAGを簡単に実装できる「Cortex Search」を用いてRAGチャットボットを構築してみた

[新機能]Snowflakeでテキストデータに対するRAGを簡単に実装できる「Cortex Search」を用いてRAGチャットボットを構築してみた

Clock Icon2024.07.25

さがらです。

2024年7月25日に、Summitでも話題になった「Cortex Search」がパブリックプレビューとなりました!

https://docs.snowflake.com/en/release-notes/2024/other/2024-07-25-cortex-search-preview

Cortex Searchは、大規模言語モデル (LLM) を活用した検索拡張生成 (RAG) を簡単に構築でき、Embedding、インフラのメンテナンス、パラメータの調整、インデックスの更新などを気にすることなく、構築・運用がすることができる機能です。

本機能に関する詳細は下記の公式ドキュメントをご覧ください。(2024年7月25日時点では、英語のテキストに特化していることだけご注意ください。また、リージョンも米国とヨーロッパの一部のリージョンのみに限られています。)

https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-search/cortex-search-overview#known-limitations

今回、サクッと試すことを目的に下記のリポジトリの内容に沿って、Cortex SearchとStreamlit-in-Snowflakeを用いたRAGチャットボットを構築してみます。(AWS US West 2 (Oregon)で試しています。)

https://github.com/Snowflake-Labs/cortex-search/tree/main/examples/streamlit-chat

Cortex Search Serviceの構築まで

まず、下記のクエリを実行して必要なデータベースやウェアハウスを構築します。

-- Create the database, schema, stage, and warehouse for the demo.
CREATE DATABASE demo_cortex_search;
CREATE SCHEMA fomc;
CREATE STAGE minutes 
	DIRECTORY = ( ENABLE = true ) 
	ENCRYPTION = ( TYPE = 'SNOWFLAKE_SSE' );
CREATE WAREHOUSE demo_cortex_search_wh;

この上で、作成したステージに下記リンク先のGoogleドライブから取得できるPDFをアップロードします。内容としては、米国の中央銀行制度であるFederal Reserve Systemの最高意思決定機関であるFederal Reserve Board(連邦準備制度理事会)のFederal Open Market Committee(連邦公開市場委員会)の議事録です。

https://drive.google.com/drive/folders/1_erdfr7ZR49Ub2Sw-oGMJLs3KvJSap0d

2024-07-25_21h04_33.png

ステージへのアップロードは、Snowsightから行うとドラッグ&ドロップで行えるのでおすすめです。

2024-07-25_21h06_07.png

2024-07-25_21h06_38.png

次に、下記のSQLを実行します。PDFを解析して抽出したテキストをチャンクするためのユーザー定義テーブル関数(UDTF)を作成します。

CREATE OR REPLACE FUNCTION pypdf_extract_and_chunk(file_url VARCHAR, chunk_size INTEGER, overlap INTEGER)
RETURNS TABLE (chunk VARCHAR)
LANGUAGE PYTHON
RUNTIME_VERSION = '3.9'
HANDLER = 'pdf_text_chunker'
PACKAGES = ('snowflake-snowpark-python','PyPDF2', 'langchain')
AS
$$
from snowflake.snowpark.types import StringType, StructField, StructType
from langchain.text_splitter import RecursiveCharacterTextSplitter
from snowflake.snowpark.files import SnowflakeFile
import PyPDF2, io
import logging
import pandas as pd

class pdf_text_chunker:

    def read_pdf(self, file_url: str) -> str:

        logger = logging.getLogger("udf_logger")
        logger.info(f"Opening file {file_url}")

        with SnowflakeFile.open(file_url, 'rb') as f:
            buffer = io.BytesIO(f.readall())

        reader = PyPDF2.PdfReader(buffer)   
        text = ""
        for page in reader.pages:
            try:
                text += page.extract_text().replace('\n', ' ').replace('\0', ' ')
            except:
                text = "Unable to Extract"
                logger.warn(f"Unable to extract from file {file_url}, page {page}")

        return text

    def process(self,file_url: str, chunk_size: int, chunk_overlap: int):

        text = self.read_pdf(file_url)

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size = chunk_size,
            chunk_overlap  = chunk_overlap,
            length_function = len
        )

        chunks = text_splitter.split_text(text)
        df = pd.DataFrame(chunks, columns=['CHUNK'])

        yield from df.itertuples(index=False, name=None)
$$;

次に、下記のSQLを実行します。先程定義したUDTFを用いてステージにアップロードしたPDFを解析し、チャンクされた文書をテーブルにロードします。

CREATE OR REPLACE TABLE parsed_doc_chunks ( 
    relative_path VARCHAR, -- Relative path to the PDF file
    chunk VARCHAR
) AS (
    SELECT
        relative_path,
        chunks.chunk as chunk
    FROM
        directory(@DEMO_CORTEX_SEARCH.FOMC.MINUTES)
        , TABLE(pypdf_extract_and_chunk(
            build_scoped_file_url(@MINUTES, relative_path),
            2000,
            500
        )) as chunks
);

実際にSELECT文を実行してみると、下図のようにデータをロード出来ていることがわかります。

SELECT * FROM parsed_doc_chunks;

2024-07-25_21h14_25.png

次に、下記のSQLを実行します。このSQLにより、先ほどロードしたテーブルのデータを用いてCortex Search Serviceが構築されます。

CREATE OR REPLACE CORTEX SEARCH SERVICE fomc_minutes_search_service
ON minutes
ATTRIBUTES relative_path
WAREHOUSE = demo_cortex_search_wh
TARGET_LAG = '1 hour'
AS (
    SELECT
        LEFT(RIGHT(relative_path, 12), 8) as meeting_date,
        CONCAT('Meeting date: ', meeting_date, ' \nMinutes: ', chunk) as minutes,
        relative_path
    FROM parsed_doc_chunks
);

最後に、構築したCortex Search Serviceを利用できるように、PUBLICロールに対して権限を付与します。

GRANT USAGE ON CORTEX SEARCH SERVICE fomc_minutes_search_service TO ROLE public;
GRANT USAGE ON DATABASE demo_cortex_search to role public;
GRANT USAGE ON SCHEMA demo_cortex_search.fomc to role public;
GRANT READ ON STAGE demo_cortex_search.fomc.minutes to role public;

Streamlit in Snowflakeでチャットボットを構築

次に、Streamlit-in-Snowflakeを用いてチャットボットを構築します。

まず、Snowsight上で先ほどCortex Search Serviceを構築したスキーマと同じスキーマに、Streamlitのアプリを構築します。

2024-07-25_21h22_31.png

次にPackagesから、下記の2つのpackageを追加します。

  • snowflake==0.8.0
  • snowflake-ml-python==1.5.1

2024-07-25_21h30_23.png

この上で、コードを下記の内容にまるごと書き換えます。

import streamlit as st
from snowflake.core import Root # requires snowflake>=0.8.0
from snowflake.cortex import Complete
from snowflake.snowpark.context import get_active_session

MODELS = [
    "mistral-large",
    "snowflake-arctic",
    "llama3-70b",
    "llama3-8b",
]

def init_messages():
    """
    Initialize the session state for chat messages. If the session state indicates that the
    conversation should be cleared or if the "messages" key is not in the session state,
    initialize it as an empty list.
    """
    if st.session_state.clear_conversation or "messages" not in st.session_state:
        st.session_state.messages = []

def init_service_metadata():
    """
    Initialize the session state for cortex search service metadata. Query the available
    cortex search services from the Snowflake session and store their names and search
    columns in the session state.
    """
    if "service_metadata" not in st.session_state:
        services = session.sql("SHOW CORTEX SEARCH SERVICES;").collect()
        service_metadata = []
        if services:
            # TODO: remove loop once changes land to add the column metadata in SHOW
            for s in services:
                svc_name = s["name"]
                svc_search_col = session.sql(
                    f"DESC CORTEX SEARCH SERVICE {svc_name};"
                ).collect()[0]["search_column"]
                service_metadata.append(
                    {"name": svc_name, "search_column": svc_search_col}
                )

        st.session_state.service_metadata = service_metadata

def init_config_options():
    """
    Initialize the configuration options in the Streamlit sidebar. Allow the user to select
    a cortex search service, clear the conversation, toggle debug mode, and toggle the use of
    chat history. Also provide advanced options to select a model, the number of context chunks,
    and the number of chat messages to use in the chat history.
    """
    st.sidebar.selectbox(
        "Select cortex search service:",
        [s["name"] for s in st.session_state.service_metadata],
        key="selected_cortex_search_service",
    )

    st.sidebar.button("Clear conversation", key="clear_conversation")
    st.sidebar.toggle("Debug", key="debug", value=False)
    st.sidebar.toggle("Use chat history", key="use_chat_history", value=True)

    with st.sidebar.expander("Advanced options"):
        st.selectbox("Select model:", MODELS, key="model_name")
        st.number_input(
            "Select number of context chunks",
            value=5,
            key="num_retrieved_chunks",
            min_value=1,
            max_value=10,
        )
        st.number_input(
            "Select number of messages to use in chat history",
            value=5,
            key="num_chat_messages",
            min_value=1,
            max_value=10,
        )

    st.sidebar.expander("Session State").write(st.session_state)

def query_cortex_search_service(query):
    """
    Query the selected cortex search service with the given query and retrieve context documents.
    Display the retrieved context documents in the sidebar if debug mode is enabled. Return the
    context documents as a string.

    Args:
        query (str): The query to search the cortex search service with.

    Returns:
        str: The concatenated string of context documents.
    """
    db, schema = session.get_current_database(), session.get_current_schema()

    cortex_search_service = (
        root.databases[db]
        .schemas[schema]
        .cortex_search_services[st.session_state.selected_cortex_search_service]
    )

    context_documents = cortex_search_service.search(
        query, columns=[], limit=st.session_state.num_retrieved_chunks
    )
    results = context_documents.results

    service_metadata = st.session_state.service_metadata
    search_col = [s["search_column"] for s in service_metadata
                    if s["name"] == st.session_state.selected_cortex_search_service][0]

    context_str = ""
    for i, r in enumerate(results):
        context_str += f"Context document {i+1}: {r[search_col]} \n" + "\n"

    if st.session_state.debug:
        st.sidebar.text_area("Context documents", context_str, height=500)

    return context_str

def get_chat_history():
    """
    Retrieve the chat history from the session state limited to the number of messages specified
    by the user in the sidebar options.

    Returns:
        list: The list of chat messages from the session state.
    """
    start_index = max(
        0, len(st.session_state.messages) - st.session_state.num_chat_messages
    )
    return st.session_state.messages[start_index : len(st.session_state.messages) - 1]

def complete(model, prompt):
    """
    Generate a completion for the given prompt using the specified model.

    Args:
        model (str): The name of the model to use for completion.
        prompt (str): The prompt to generate a completion for.

    Returns:
        str: The generated completion.
    """
    return Complete(model, prompt).replace("$", "\$")

def make_chat_history_summary(chat_history, question):
    """
    Generate a summary of the chat history combined with the current question to extend the query
    context. Use the language model to generate this summary.

    Args:
        chat_history (str): The chat history to include in the summary.
        question (str): The current user question to extend with the chat history.

    Returns:
        str: The generated summary of the chat history and question.
    """
    prompt = f"""
        [INST]
        Based on the chat history below and the question, generate a query that extend the question
        with the chat history provided. The query should be in natural language. 
        Answer with only the query. Do not add any explanation.

        <chat_history>
        {chat_history}
        </chat_history>
        <question>
        {question}
        </question>
        [/INST]
    """

    summary = complete(st.session_state.model_name, prompt)

    if st.session_state.debug:
        st.sidebar.text_area(
            "Chat history summary", summary.replace("$", "\$"), height=150
        )

    return summary

def create_prompt(user_question):
    """
    Create a prompt for the language model by combining the user question with context retrieved
    from the cortex search service and chat history (if enabled). Format the prompt according to
    the expected input format of the model.

    Args:
        user_question (str): The user's question to generate a prompt for.

    Returns:
        str: The generated prompt for the language model.
    """
    if st.session_state.use_chat_history:
        chat_history = get_chat_history()
        if chat_history != []:
            question_summary = make_chat_history_summary(chat_history, user_question)
            prompt_context = query_cortex_search_service(question_summary)
        else:
            prompt_context = query_cortex_search_service(user_question)
    else:
        prompt_context = query_cortex_search_service(user_question)
        chat_history = ""

    prompt = f"""
            [INST]
            You are a helpful AI chat assistant with RAG capabilities. When a user asks you a question,
            you will also be given context provided between <context> and </context> tags. Use that context
            with the user's chat history provided in the between <chat_history> and </chat_history> tags
            to provide a summary that addresses the user's question. Ensure the answer is coherent, concise,
            and directly relevant to the user's question.

            If the user asks a generic question which cannot be answered with the given context or chat_history,
            just say "I don't know the answer to that question.

            Don't saying things like "according to the provided context".

            <chat_history>
            {chat_history}
            </chat_history>
            <context>          
            {prompt_context}
            </context>
            <question>  
            {user_question}
            </question>
            [/INST]
            Answer:
           """
    return prompt

def main():
    st.title(f":speech_balloon: Chatbot with Snowflake Cortex")

    init_service_metadata()
    init_config_options()
    init_messages()

    icons = {"assistant": "❄️", "user": "👤"}

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"], avatar=icons[message["role"]]):
            st.markdown(message["content"])

    disable_chat = (
        "service_metadata" not in st.session_state
        or len(st.session_state.service_metadata) == 0
    )
    if question := st.chat_input("Ask a question...", disabled=disable_chat):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": question})
        # Display user message in chat message container
        with st.chat_message("user", avatar=icons["user"]):
            st.markdown(question.replace("$", "\$"))

        # Display assistant response in chat message container
        with st.chat_message("assistant", avatar=icons["assistant"]):
            message_placeholder = st.empty()
            question = question.replace("'", "")
            with st.spinner("Thinking..."):
                generated_response = complete(
                    st.session_state.model_name, create_prompt(question)
                )
                message_placeholder.markdown(generated_response)

        st.session_state.messages.append(
            {"role": "assistant", "content": generated_response}
        )

if __name__ == "__main__":
    session = get_active_session()
    root = Root(session)
    main()

この上で、RUNを押すと下図のようなチャットボットアプリケーションが起動します。

2024-07-25_21h33_35.png

画面左のAdvanced Optionsでは、回答の生成に使用するLLMのモデル、各回答で使用するcontext chunksの数、新しい応答の生成に使用するチャット履歴メッセージの深さを選択できます。

2024-07-25_21h34_55.png

2024-07-25_21h35_12.png

画面左のSession Stateでは、チャットのメッセージや現在選択されているサービスなどのセッション状態が表示されます。

2024-07-25_21h36_29.png

実際に問い合わせてみた

では、実際に構築したRAGチャットボット

まず、「how was gpd growth in q4 23?」と聞いてみます。すると、「gpd」という誤字があっても議事録に記載されていたGDPに関する情報を元に結果を返してくれました。

2024-07-25_21h41_34.png

次に、「how was unemployment in the same quarter?」と聞いてみます。すると、「same quarter」を「fourth quarter 2023」と正しく認識して回答を返してくれました。

2024-07-25_21h43_05.png

次に、「how has the fed's view of the market change over the course of 2024?」と聞いてみます。すると、複数の議事録に渡っての情報が必要な質問でしたが、問題なく回答してくれました。

2024-07-25_21h44_25.png

最後に、「What was janet yellen's opinion about 2024 q1?」と聞いてみます。すると、この質問に対する回答をするためのデータがないため、回答ができないと返ってきました。

2024-07-25_21h46_03.png

最後に

Snowflakeでテキストデータに対するRAGを簡単に実装できる「Cortex Search」がパブリックプレビューとなったので、こちらのサンプルリポジトリに沿って用いてRAGチャットボットを構築してみました。

このブログを書きながら検証したのですが、1時間ほどで構築してブログも書いて公開できるスピードでRAGチャットボットを構築することができました!
まだパブリックプレビューという状況ですが、今後は日本のリージョンでも使えたり、日本語の対応にも期待したいところです!

この記事をシェアする

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.