こんちには。
データアナリティクス事業本部 インテグレーション部 機械学習チームの中村です。
「LlamaIndexを完全に理解するチュートリアル その3」では、CallbackManagerを使って行きたいともいます。
本記事で使用する用語は以下のその1で説明していますので、そちらも参照ください。
LlamaIndexを完全に理解するチュートリアル |
---|
その1:処理の概念や流れを理解する基礎編(v0.7.9対応) |
その2:テキスト分割のカスタマイズ |
その3:CallbackManagerで内部動作の把握やデバッグを可能にする |
・本記事の内容はその1のv0.7.9版の記事を投稿後、v0.7.9で動作するように修正しています
本記事の内容
CallbackManagerは前回概要を説明していますが再掲すると、各処理フェーズにおけるstart, endにおけるコールバックをhandlerとして設定することができます。
処理フェーズは、CBEventTypeとして定義されており、以下がその一覧です。
- CBEventType.CHUNKING : テキスト分割処理の前後
- CBEventType.NODE_PARSING : NodeParserの前後
- CBEventType.EMBEDDING : 埋め込みベクトル作成処理の前後
- CBEventType.LLM : LLM呼び出しの前後
- CBEventType.QUERY : クエリの開始と終了
- CBEventType.RETRIEVE : ノード抽出の前後
- CBEventType.SYNTHESIZE : レスポンス合成の前後
- CBEventType.TREE : サマリー処理の前後
各処理をフェーズ毎に追うことが可能なので、内部動作の把握や、意図通り動いているか、デバッグなどに有用な機能です。
今回はこちらの使い方を詳しく見ていきます。
環境準備
その1と同様の方法で準備します。
使用したバージョン情報は以下となります。
- Python : 3.10.11
- langchain : 0.0.234
- llama-index : 0.7.9
- openai : 0.27.8
サンプルコード
ベースのサンプルは以下とします。
from llama_index import SimpleDirectoryReader
from llama_index import ListIndex
documents = SimpleDirectoryReader(input_dir="./data").load_data()
list_index = ListIndex.from_documents(documents)
query_engine = list_index.as_query_engine()
response = query_engine.query("機械学習に関するアップデートについて300字前後で要約してください。")
for i in response.response.split("。"):
print(i + "。")
こちらにCallbackManagerを設定していきます。
CallbackManagerを設定
CallbackManagerに設定なhandlerとして、事前に準備されているLlamaDebugHandlerがあります。
こちらを使用すると書く処理のロギングを行うことができます。
from llama_index import SimpleDirectoryReader
from llama_index import ListIndex
from llama_index import ServiceContext
from llama_index.callbacks import CallbackManager, LlamaDebugHandler
documents = SimpleDirectoryReader(input_dir="./data").load_data()
llama_debug_handler = LlamaDebugHandler()
callback_manager = CallbackManager([llama_debug_handler])
service_context = ServiceContext.from_defaults(callback_manager=callback_manager)
list_index = ListIndex.from_documents(documents
, service_context=service_context)
query_engine = list_index.as_query_engine()
response = query_engine.query("機械学習に関するアップデートについて300字前後で要約してください。")
**********
Trace: index_construction
|_node_parsing -> 0.035481 seconds
|_chunking -> 0.018485 seconds
|_chunking -> 0.015996 seconds
**********
**********
Trace: query
|_query -> 33.841973 seconds
|_retrieve -> 0.001001 seconds
|_synthesize -> 33.840972 seconds
|_llm -> 12.805816 seconds
|_llm -> 7.939219 seconds
|_llm -> 13.002057 seconds
**********
実行するとこのように、標準出力にもログがでるようになります。
LlamaDebugHandlerからの詳細な結果取得
LlamaDebugHandlerの関数を使って、詳細なログを得ることが可能となっています。
CBEventType.RETRIEVEの場合
まずはCBEventType.RETRIEVEのケースを例に見ていきましょう。
get_event_time_info
では各CBEventTypeのトータル処理時間や平均処理時間、回数を得ることが可能です。
以下は、ノード選択(CBEventType.RETRIEVE)の一例です。
from llama_index.callbacks import CBEventType
llama_debug_handler.get_event_time_info(CBEventType.RETRIEVE)
EventStats(total_secs=0.008444, average_secs=0.008444, total_count=1)
get_event_pairs
は各CBEventTypeにおける[start, end]のペアのリストが取得可能です。
[
[start1, end1],
[start2, end2],
[start3, end3],
]
ペアのデータは具体的には以下のようになっています。(一部長いため整形と省略しています)
from llama_index.callbacks import CBEventType
llama_debug_handler.get_event_pairs(CBEventType.RETRIEVE)[0]
[
CBEvent(event_type=<CBEventType.RETRIEVE: 'retrieve'>, payload=None, time='07/18/2023, 08:33:24.061148', id_='7269763f-3e14-4dd0-a158-a9866d085088'),
CBEvent(event_type=<CBEventType.RETRIEVE: 'retrieve'>, payload={
<EventPayload.NODES: 'nodes'>: [
NodeWithScore(node=TextNode(id_='546a7134-2cc8-4bce-9965-b0a2c2ae7e6c', ...), score=None),
NodeWithScore(node=TextNode(id_='dc5e52cb-f533-4632-8c27-0556f4dfcbca', ...), score=None),
NodeWithScore(node=TextNode(id_='342e2fff-a3d2-418a-8208-b60a82a4ed90', ...), score=None),
NodeWithScore(node=TextNode(id_='90195ee8-c4e0-4c09-9651-b344ba97c9c9', ...), score=None),
NodeWithScore(node=TextNode(id_='528faeda-bef3-443f-9a34-db053fa61b6f', ...), score=None),
NodeWithScore(node=TextNode(id_='7b1994bc-dbda-418b-858c-2645aed65a68', ...), score=None),
NodeWithScore(node=TextNode(id_='bcbcbfb3-c886-4a73-befc-017601af618d', ...), score=None),
NodeWithScore(node=TextNode(id_='4f4c2fd9-055a-427c-bb48-7c29bcc16d0b', ...), score=None)
]
}, time='07/18/2023, 08:33:24.062149', id_='7269763f-3e14-4dd0-a158-a9866d085088')
]
各CBEventTypeでpayloadの内容は異なっており、ここにstartの場合は入力、endの場合は出力のpayloadが含まれます。
例えばCBEventType.RETRIEVEの場合、payloadからどのNodeが選ばれたのかを知ることが可能となっています。
from llama_index.callbacks import CBEventType
node_list = llama_debug_handler.get_event_pairs(CBEventType.RETRIEVE)[0][1].payload["nodes"]
# node_list = response.source_nodes # レスポンスでも同様のことは可能
for node in node_list:
print(f"doc_id ={node.node.id_}")
doc_id =546a7134-2cc8-4bce-9965-b0a2c2ae7e6c
doc_id =dc5e52cb-f533-4632-8c27-0556f4dfcbca
doc_id =342e2fff-a3d2-418a-8208-b60a82a4ed90
doc_id =90195ee8-c4e0-4c09-9651-b344ba97c9c9
doc_id =528faeda-bef3-443f-9a34-db053fa61b6f
doc_id =7b1994bc-dbda-418b-858c-2645aed65a68
doc_id =bcbcbfb3-c886-4a73-befc-017601af618d
doc_id =4f4c2fd9-055a-427c-bb48-7c29bcc16d0b
データ構造はpayloadをよく見ながらアクセスする必要があります。
CBEventType.LLMの場合
次にLLMへの問い合わせをみるため、CBEventType.LLMで確認してみます。
まずは、get_event_time_info
です。
from llama_index.callbacks import CBEventType
llama_debug_handler.get_event_time_info(CBEventType.LLM)
EventStats(total_secs=31.959478000000004, average_secs=10.653159333333335, total_count=3)
3回LLMへの問い合わせが投げられていることが分かります。
get_event_pairs
で実際の入出力を見てみます。
from llama_index.callbacks import CBEventType
llama_debug_handler.get_event_pairs(CBEventType.LLM)[0]
[
CBEvent(event_type=<CBEventType.LLM: 'llm'>, payload={
'context_str': 'はい、みなさんこんばんは。 ...',
<EventPayload.TEMPLATE: 'template'>: <llama_index.prompts.base.Prompt object at 0x00000189BA1B4BE0>
}, time='07/18/2023, 09:10:22.102886', id_='8ec6b272-74ae-42c8-910a-98954ddf4948'),
CBEvent(event_type=<CBEventType.LLM: 'llm'>, payload={
<EventPayload.RESPONSE: 'response'>: '\nAWS Re-Invent 2022では、機械学習に関する多くのアップデートが行われました。...',
<EventPayload.PROMPT: 'formatted_prompt'>: 'Context information is below.\n---------------------\nはい、みなさんこんばんは。 ...',
'formatted_prompt_tokens_count': 3778,
'prediction_tokens_count': 286,
'total_tokens_used': 4064
}, time='07/18/2023, 09:10:29.964799', id_='8ec6b272-74ae-42c8-910a-98954ddf4948')
]
このように、LLMへ投げる実際のプロンプトやレスポンスを確認することが可能です。
デバッグや内部動作を把握するのに役立てることができそうです。
自作のhandlerを定義する
BaseCallbackHandlerのサブクラスとして以下のように実装することができます。
from typing import Dict, List, Any, Optional
from llama_index.callbacks.base import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType
class CustomCallbackHandler(BaseCallbackHandler):
def __init__(
self,
event_starts_to_ignore: Optional[List[CBEventType]] = None,
event_ends_to_ignore: Optional[List[CBEventType]] = None,
print_trace_on_end: bool = True,
) -> None:
event_starts_to_ignore = event_starts_to_ignore if event_starts_to_ignore else []
event_ends_to_ignore = event_ends_to_ignore if event_ends_to_ignore else []
super().__init__(
event_starts_to_ignore=event_starts_to_ignore,
event_ends_to_ignore=event_ends_to_ignore,
)
def on_event_start(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any
) -> str:
print(f"event_type = {event_type} (start)")
return event_id
def on_event_end(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any
) -> None:
print(f"event_type = {event_type} (end)")
def start_trace(self, trace_id: Optional[str] = None) -> None:
print("start_trace")
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
print("end_trace")
簡単なログを今回は入れてみました。on_event_start
とon_event_end
でpayloadを参照できるので、様々なことができそうです。
CallbackManagerに設定してい見ます。
CallbackManagerのhandlerは複数設定が可能です。
LlamaDebugHandlerと自作のハンドラ両方をCallbackManagerに設定してみます。
from llama_index import SimpleDirectoryReader
from llama_index import ListIndex
from llama_index import ServiceContext
from llama_index.callbacks import CallbackManager, LlamaDebugHandler
documents = SimpleDirectoryReader(input_dir="./data").load_data()
llama_debug_handler = LlamaDebugHandler()
custom_callback_handler = CustomCallbackHandler()
callback_manager = CallbackManager([
llama_debug_handler
, custom_callback_handler
])
service_context = ServiceContext.from_defaults(callback_manager=callback_manager)
list_index = ListIndex.from_documents(documents
, service_context=service_context)
query_engine = list_index.as_query_engine()
response = query_engine.query("機械学習に関するアップデートについて300字前後で要約してください。")
start_trace
event_type = node_parsing (start)
event_type = chunking (start)
event_type = chunking (end)
event_type = chunking (start)
event_type = chunking (end)
event_type = node_parsing (end)
**********
Trace: index_construction
|_node_parsing -> 0.033939 seconds
|_chunking -> 0.016792 seconds
|_chunking -> 0.016147 seconds
**********
end_trace
start_trace
event_type = query (start)
event_type = retrieve (start)
event_type = retrieve (end)
event_type = synthesize (start)
event_type = llm (start)
event_type = llm (end)
event_type = llm (start)
event_type = llm (end)
event_type = llm (start)
event_type = llm (end)
event_type = synthesize (end)
event_type = query (end)
**********
Trace: query
|_query -> 32.05093 seconds
|_retrieve -> 0.002001 seconds
|_synthesize -> 32.048929 seconds
|_llm -> 7.861913 seconds
|_llm -> 11.722238 seconds
|_llm -> 12.375327 seconds
**********
end_trace
このようなログが出力されました。
start_trace, end_traceは記録の開始、終了で呼ばれるようです。
(今回はListIndex.from_documents
とlist_index.as_query_engine
の2回でtraceが発生しています)
まとめ
いかがでしたでしょうか。
今回はCallbackManagerについて見ていきました。こちらの機能は動作解析やデバッグに役立つと思いますので是非活用をしてみてください。
本記事が、今後LlamaIndexをお使いになられる方の参考になれば幸いです。