LlamaIndexを完全に理解するチュートリアル その3:CallbackManagerで内部動作の把握やデバッグを可能にする

内部動作の把握やデバッグは実務では超重要ですよね。第3回はそれを可能にするCallbackManagerを見ていきます。
2023.05.27

こんちには。

データアナリティクス事業本部 インテグレーション部 機械学習チームの中村です。

「LlamaIndexを完全に理解するチュートリアル その3」では、CallbackManagerを使って行きたいともいます。

本記事で使用する用語は以下のその1で説明していますので、そちらも参照ください。

LlamaIndexを完全に理解するチュートリアル
その1:処理の概念や流れを理解する基礎編(v0.7.9対応)
その2:テキスト分割のカスタマイズ
その3:CallbackManagerで内部動作の把握やデバッグを可能にする

・本記事の内容はその1のv0.7.9版の記事を投稿後、v0.7.9で動作するように修正しています

本記事の内容

CallbackManagerは前回概要を説明していますが再掲すると、各処理フェーズにおけるstart, endにおけるコールバックをhandlerとして設定することができます。

処理フェーズは、CBEventTypeとして定義されており、以下がその一覧です。

  • CBEventType.CHUNKING : テキスト分割処理の前後
  • CBEventType.NODE_PARSING : NodeParserの前後
  • CBEventType.EMBEDDING : 埋め込みベクトル作成処理の前後
  • CBEventType.LLM : LLM呼び出しの前後
  • CBEventType.QUERY : クエリの開始と終了
  • CBEventType.RETRIEVE : ノード抽出の前後
  • CBEventType.SYNTHESIZE : レスポンス合成の前後
  • CBEventType.TREE : サマリー処理の前後

各処理をフェーズ毎に追うことが可能なので、内部動作の把握や、意図通り動いているか、デバッグなどに有用な機能です。

今回はこちらの使い方を詳しく見ていきます。

環境準備

その1と同様の方法で準備します。

使用したバージョン情報は以下となります。

  • Python : 3.10.11
  • langchain : 0.0.234
  • llama-index : 0.7.9
  • openai : 0.27.8

サンプルコード

ベースのサンプルは以下とします。

from llama_index import SimpleDirectoryReader
from llama_index import ListIndex

documents = SimpleDirectoryReader(input_dir="./data").load_data()

list_index = ListIndex.from_documents(documents)

query_engine = list_index.as_query_engine()

response = query_engine.query("機械学習に関するアップデートについて300字前後で要約してください。")

for i in response.response.split("。"):
    print(i + "。")

こちらにCallbackManagerを設定していきます。

CallbackManagerを設定

CallbackManagerに設定なhandlerとして、事前に準備されているLlamaDebugHandlerがあります。

こちらを使用すると書く処理のロギングを行うことができます。

from llama_index import SimpleDirectoryReader
from llama_index import ListIndex
from llama_index import ServiceContext
from llama_index.callbacks import CallbackManager, LlamaDebugHandler

documents = SimpleDirectoryReader(input_dir="./data").load_data()

llama_debug_handler = LlamaDebugHandler()
callback_manager = CallbackManager([llama_debug_handler])
service_context = ServiceContext.from_defaults(callback_manager=callback_manager)

list_index = ListIndex.from_documents(documents
    , service_context=service_context)

query_engine = list_index.as_query_engine()

response = query_engine.query("機械学習に関するアップデートについて300字前後で要約してください。")
**********
Trace: index_construction
    |_node_parsing ->  0.035481 seconds
      |_chunking ->  0.018485 seconds
      |_chunking ->  0.015996 seconds
**********
**********
Trace: query
    |_query ->  33.841973 seconds
      |_retrieve ->  0.001001 seconds
      |_synthesize ->  33.840972 seconds
        |_llm ->  12.805816 seconds
        |_llm ->  7.939219 seconds
        |_llm ->  13.002057 seconds
**********

実行するとこのように、標準出力にもログがでるようになります。

LlamaDebugHandlerからの詳細な結果取得

LlamaDebugHandlerの関数を使って、詳細なログを得ることが可能となっています。

CBEventType.RETRIEVEの場合

まずはCBEventType.RETRIEVEのケースを例に見ていきましょう。

get_event_time_infoでは各CBEventTypeのトータル処理時間や平均処理時間、回数を得ることが可能です。

以下は、ノード選択(CBEventType.RETRIEVE)の一例です。

from llama_index.callbacks import CBEventType
llama_debug_handler.get_event_time_info(CBEventType.RETRIEVE)
EventStats(total_secs=0.008444, average_secs=0.008444, total_count=1)

get_event_pairsは各CBEventTypeにおける[start, end]のペアのリストが取得可能です。

[
    [start1, end1],
    [start2, end2],
    [start3, end3],
]

ペアのデータは具体的には以下のようになっています。(一部長いため整形と省略しています)

from llama_index.callbacks import CBEventType
llama_debug_handler.get_event_pairs(CBEventType.RETRIEVE)[0]
[
    CBEvent(event_type=<CBEventType.RETRIEVE: 'retrieve'>, payload=None, time='07/18/2023, 08:33:24.061148', id_='7269763f-3e14-4dd0-a158-a9866d085088'),
    CBEvent(event_type=<CBEventType.RETRIEVE: 'retrieve'>, payload={
        <EventPayload.NODES: 'nodes'>: [
            NodeWithScore(node=TextNode(id_='546a7134-2cc8-4bce-9965-b0a2c2ae7e6c', ...), score=None),
            NodeWithScore(node=TextNode(id_='dc5e52cb-f533-4632-8c27-0556f4dfcbca', ...), score=None),
            NodeWithScore(node=TextNode(id_='342e2fff-a3d2-418a-8208-b60a82a4ed90', ...), score=None),
            NodeWithScore(node=TextNode(id_='90195ee8-c4e0-4c09-9651-b344ba97c9c9', ...), score=None),
            NodeWithScore(node=TextNode(id_='528faeda-bef3-443f-9a34-db053fa61b6f', ...), score=None),
            NodeWithScore(node=TextNode(id_='7b1994bc-dbda-418b-858c-2645aed65a68', ...), score=None),
            NodeWithScore(node=TextNode(id_='bcbcbfb3-c886-4a73-befc-017601af618d', ...), score=None),
            NodeWithScore(node=TextNode(id_='4f4c2fd9-055a-427c-bb48-7c29bcc16d0b', ...), score=None)
        ]
    }, time='07/18/2023, 08:33:24.062149', id_='7269763f-3e14-4dd0-a158-a9866d085088')
]

各CBEventTypeでpayloadの内容は異なっており、ここにstartの場合は入力、endの場合は出力のpayloadが含まれます。

例えばCBEventType.RETRIEVEの場合、payloadからどのNodeが選ばれたのかを知ることが可能となっています。

from llama_index.callbacks import CBEventType
node_list = llama_debug_handler.get_event_pairs(CBEventType.RETRIEVE)[0][1].payload["nodes"]
# node_list = response.source_nodes # レスポンスでも同様のことは可能
for node in node_list:
    print(f"doc_id ={node.node.id_}")
doc_id =546a7134-2cc8-4bce-9965-b0a2c2ae7e6c
doc_id =dc5e52cb-f533-4632-8c27-0556f4dfcbca
doc_id =342e2fff-a3d2-418a-8208-b60a82a4ed90
doc_id =90195ee8-c4e0-4c09-9651-b344ba97c9c9
doc_id =528faeda-bef3-443f-9a34-db053fa61b6f
doc_id =7b1994bc-dbda-418b-858c-2645aed65a68
doc_id =bcbcbfb3-c886-4a73-befc-017601af618d
doc_id =4f4c2fd9-055a-427c-bb48-7c29bcc16d0b

データ構造はpayloadをよく見ながらアクセスする必要があります。

CBEventType.LLMの場合

次にLLMへの問い合わせをみるため、CBEventType.LLMで確認してみます。

まずは、get_event_time_infoです。

from llama_index.callbacks import CBEventType
llama_debug_handler.get_event_time_info(CBEventType.LLM)
EventStats(total_secs=31.959478000000004, average_secs=10.653159333333335, total_count=3)

3回LLMへの問い合わせが投げられていることが分かります。

get_event_pairsで実際の入出力を見てみます。

from llama_index.callbacks import CBEventType
llama_debug_handler.get_event_pairs(CBEventType.LLM)[0]
[
    CBEvent(event_type=<CBEventType.LLM: 'llm'>, payload={
        'context_str': 'はい、みなさんこんばんは。 ...',
        <EventPayload.TEMPLATE: 'template'>: <llama_index.prompts.base.Prompt object at 0x00000189BA1B4BE0>
    }, time='07/18/2023, 09:10:22.102886', id_='8ec6b272-74ae-42c8-910a-98954ddf4948'),
    CBEvent(event_type=<CBEventType.LLM: 'llm'>, payload={
        <EventPayload.RESPONSE: 'response'>: '\nAWS Re-Invent 2022では、機械学習に関する多くのアップデートが行われました。...',
        <EventPayload.PROMPT: 'formatted_prompt'>: 'Context information is below.\n---------------------\nはい、みなさんこんばんは。 ...',
        'formatted_prompt_tokens_count': 3778,
        'prediction_tokens_count': 286,
        'total_tokens_used': 4064
    }, time='07/18/2023, 09:10:29.964799', id_='8ec6b272-74ae-42c8-910a-98954ddf4948')
]

このように、LLMへ投げる実際のプロンプトやレスポンスを確認することが可能です。

デバッグや内部動作を把握するのに役立てることができそうです。

自作のhandlerを定義する

BaseCallbackHandlerのサブクラスとして以下のように実装することができます。

from typing import Dict, List, Any, Optional

from llama_index.callbacks.base import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType

class CustomCallbackHandler(BaseCallbackHandler):

    def __init__(
        self,
        event_starts_to_ignore: Optional[List[CBEventType]] = None,
        event_ends_to_ignore: Optional[List[CBEventType]] = None,
        print_trace_on_end: bool = True,
    ) -> None:
        event_starts_to_ignore = event_starts_to_ignore if event_starts_to_ignore else []
        event_ends_to_ignore = event_ends_to_ignore if event_ends_to_ignore else []
        super().__init__(
            event_starts_to_ignore=event_starts_to_ignore,
            event_ends_to_ignore=event_ends_to_ignore,
        )

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any
    ) -> str:
        print(f"event_type = {event_type} (start)")
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any
    ) -> None:
        print(f"event_type = {event_type} (end)")

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        print("start_trace")

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        print("end_trace")

簡単なログを今回は入れてみました。on_event_starton_event_endでpayloadを参照できるので、様々なことができそうです。

CallbackManagerに設定してい見ます。

CallbackManagerのhandlerは複数設定が可能です。

LlamaDebugHandlerと自作のハンドラ両方をCallbackManagerに設定してみます。

from llama_index import SimpleDirectoryReader
from llama_index import ListIndex
from llama_index import ServiceContext
from llama_index.callbacks import CallbackManager, LlamaDebugHandler

documents = SimpleDirectoryReader(input_dir="./data").load_data()

llama_debug_handler = LlamaDebugHandler()
custom_callback_handler = CustomCallbackHandler()
callback_manager = CallbackManager([
    llama_debug_handler
    , custom_callback_handler
])
service_context = ServiceContext.from_defaults(callback_manager=callback_manager)

list_index = ListIndex.from_documents(documents
    , service_context=service_context)

query_engine = list_index.as_query_engine()

response = query_engine.query("機械学習に関するアップデートについて300字前後で要約してください。")
start_trace
event_type = node_parsing (start)
event_type = chunking (start)
event_type = chunking (end)
event_type = chunking (start)
event_type = chunking (end)
event_type = node_parsing (end)
**********
Trace: index_construction
    |_node_parsing ->  0.033939 seconds
      |_chunking ->  0.016792 seconds
      |_chunking ->  0.016147 seconds
**********
end_trace
start_trace
event_type = query (start)
event_type = retrieve (start)
event_type = retrieve (end)
event_type = synthesize (start)
event_type = llm (start)
event_type = llm (end)
event_type = llm (start)
event_type = llm (end)
event_type = llm (start)
event_type = llm (end)
event_type = synthesize (end)
event_type = query (end)
**********
Trace: query
    |_query ->  32.05093 seconds
      |_retrieve ->  0.002001 seconds
      |_synthesize ->  32.048929 seconds
        |_llm ->  7.861913 seconds
        |_llm ->  11.722238 seconds
        |_llm ->  12.375327 seconds
**********
end_trace

このようなログが出力されました。

start_trace, end_traceは記録の開始、終了で呼ばれるようです。

(今回はListIndex.from_documentslist_index.as_query_engineの2回でtraceが発生しています)

まとめ

いかがでしたでしょうか。

今回はCallbackManagerについて見ていきました。こちらの機能は動作解析やデバッグに役立つと思いますので是非活用をしてみてください。

本記事が、今後LlamaIndexをお使いになられる方の参考になれば幸いです。