シングルGPUで動作するTransformer相当のRNNモデル RWKV-Raven-14Bを試してみた

逆襲のRNN ファインチューニングもできるよ!
2023.04.10

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

こんちには。

データアナリティクス事業本部 インテグレーション部 機械学習チームの中村です。

今回はシングルGPUで動作するTransformer相当のRNNモデル、という噂のRWKVについて試してみたいと思います。

RWKVとは

TransformerベースのLLMと同等の性能を持つ、並列化可能なRNNモデルであり、Attentionフリー(Attention構造を持たない)なモデルです。

ライセンス形態がApache License 2.0かつ、シングルGPUでも動作する点が凄いところとなっています。

Hugging Face側にモデルがいくつか公開されており、rwkv-4が付くものが最近よく話題で使用されているものです。

rwkv-4はその規模に応じて、169m~14bまでのサイズが公開されています。

またファインチューニングの仕組みが準備されているため、自社内でQA集などのデータがあればそれらを学習として使うことも可能だと思います。

ファインチューニングについては以下の記事も参考にされてください。

Ravenとは

RWKVをベースにAlpaca、CodeAlpaca、Guanaco、GPT4All、ShareGPTなどのデータセットでfine-tuningしたモデルです。

以下のHugging FaceのスペースにRaven-RWKV-7Bがホスティングされており、更新も続けられています。

ChatRWKVとは

RWKVやその派生のRavenを使うことができるChat機能のレポジトリとなっています。

こちらを使用する場合、gitでChatRWKVをクローンし、RWKV-4やRavenをダウンロードしてきて使うというイメージで実験していきます。

今回は、ChatRWKVを使用しませんが使われたい方は以下を参考にされてください。

RWKVのモデルについて

ダウンロードするモデルはRaven含め、以下のようにHugging Faceのレポジトリにあります。

モデル選択の他に、浮動小数精度などのstrategyを調整することが可能で、更に低リソースな状況での稼働も想定しています。

ご参考 : rwkv.cpp

rwkv.cppはより低リソースでRWKVを動かすためのレポジトリです。

今回は使用しませんが、かなり性能が落ちるケースも観測されているようですので検証が必要と考えています。

試してみた

では実際に試してみたいと思います。やり方はほぼ以下の記事を踏襲しています。(npakaさんいつもありがとうございます)

実行環境

Google Colaboratoryを使います。Proが必要となります。

ハードウェアアクセラレータはGPU、GPUのタイプはA-100、ラインタイム仕様は標準で実施します。

  • GPU : NVIDIA A100-SXM4-40GB
  • RAM : 85GB
  • Cuda compilation tools : release 11.8, V11.8.89

モデルサイズが27GBのため、RAMがそれ以上に必要なためGPUタイプもA-100にしています。もう少し細かくカスタマイズできるならば、最適な構成を検討できると思います。

とりあえず1つのマシンに載る程度の規模で動作させらるというところがポイントです。

主なバージョン情報は以下です。

!python --version
Python 3.9.16

パッケージのバージョンは以下です。(こちらは準備を一通り実施した後に確認できます)

!pip freeze | grep -e "rwkv" -e "tokenizers"
rwkv==0.7.3
tokenizers==0.13.3

準備

!pip install rwkv
import os
os.environ['RWKV_JIT_ON'] = '1'
os.environ["RWKV_CUDA_ON"] = '0'

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

モデルとtokenizerの取得

日本語データが多くて「一番いいのを頼む」ということでRWKV-4-Raven-14B-v8EngAndMoreを使ってみます。

!wget https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-14B-v8-EngAndMore-20230408-ctx4096.pth

26.35GBのファイルとなっており、3分程度でダウンロードできました。

tokenizerも以下で取得しておきます。

!wget https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/20B_tokenizer.json

ModelとPipelineのインスタンス化

ダウンロードしたモデルなどをロードし、PIPELINEをインスタンス化します。

model = RWKV(
    model="./RWKV-4-Raven-14B-v8-EngAndMore-20230408-ctx4096.pth",
    strategy="cuda fp16"
)

pipeline = PIPELINE(model, "./20B_tokenizer.json")

モデルが約27GBあるためロードには少し時間がかかり、40秒程度必要でした。

generate時にpipelineに渡すパラメータを以下のように定義します。

args = PIPELINE_ARGS(
    temperature = 1.0,
    top_p = 0.7,
    top_k = 100,
    alpha_frequency = 0.25,
    alpha_presence = 0.25,
    token_ban = [],
    token_stop = [0],
    chunk_len = 256)

プロンプトのテンプレート

プロンプトのテンプレートも参考記事を踏襲させていただきます。

def generate_prompt(instruction, input=None):
    if input:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

# Instruction:
{instruction}

# Input:
{input}

# Response:
"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

# Instruction:
{instruction}

# Response:
"""

引数としてあたえるものとしては、instructioninputの2つとなっています。

instructionがいわゆる実行すべきタスクを指定するところであり、inputはタスクのコンテキストを必要に応じて指定します。

要約を実行して欲しい場合は、inputに要約される対象の文章を入れるような感じです。

こちらはAlpacaデータセットのフォーマットを踏襲しているようです。

詳細は以下も併せてご覧ください。

タスクをお願いしてみた

まずはコンテキストがないタスクをお願いしてみます。

prompt = generate_prompt("YOLOXとYOLOv5の違いについて教えてください。")

result = pipeline.generate(prompt, token_count=200, args=args)
print(result)

出力は以下となりました。

YOLOXとYOLOv5は、同じような目的を持つ異なるモデルであり、一般的にYOLOXがより小さなモデルであることが多いとされています。YOLOXは、より精度の高い特徴量を提供することを目的としています。YOLOXは、複数のクラスタリングアルゴリズムを使用してユーザーが行うトレーニングに基づいて、トレーニングに必要な正解率を向上させることができます。

きちんと文脈や文法的には、意図を理解してタスクを実行できています。

少しYOLOXの解釈が合っていなさそうな印象ですので、情報の信憑性を求めるのは難しそうです。

次は要約をお願いしてみます。以下の記事から文章を一部コピーして入力として使います。

body = """
Amazon Transcribeに再入門する(2022年12月版)
特集
クラスメソッド 機械学習チーム アドベントカレンダー 2022
#機械学習#Amazon Transcribe
nokomoro3
nokomoro3

2022.12.11
Facebook
1
Hatena
1
Twitter
3
はじめに
こんちには。

データアナリティクス事業本部 機械学習チームの中村です。

本記事は「クラスメソッド 機械学習チーム アドベントカレンダー 2022」の11日目です。

AWSのフルマネージドな機械学習サービスは、PersonalizeやForecastなどに触れる機会が多かったですが、 アドベントカレンダーをきっかけに、その他のAWSのフルマネージドな機械学習サービスに入門していこうと思います。

この流れを継続するかどうか予定は未定ですが、本記事ではAmazon Transcribeに入門していきます。

Amazon Transcribeとは
Amazon Transcribeは音声を入力し、書き起こし(Transcribe)するフルマネージドな機械学習サービスです。

Pricing(料金)
使用する前にPricingを確認していきましょう。(重要)

Pricingを確認することでおおよその機能概要も把握できます。



基本的には、1か月あたりの文字起こしされた音声データの秒数に基づいた従量課金となります(いくぶんか無料枠もあるようですが、期間と時間に制限がある)。 少なくともTokyoリージョンでは、バッチとストリーミングに費用差はなく、1分あたり$0.024ドルからとなっています。 1時間の通話では、60 x $0.024 = $1.44という目安ですね。

上記が標準的な部分ですが、以下を追加することにより費用の追加が発生します。

PII Redaction機能
削除する必要がある社会保障番号や生年月日の情報などの個人を特定できる情報(PII)を除去する機能
標準のおよそ1/10が追加費用としてかかる様子
CLM
カスタム言語モデルを構築し、これらのドメイン固有の用語を認識する機能
例えば、大学の講義動画を書き起こす場合に、科学用語を認識させる、などがユースケース
標準のおよそ1/4が追加費用としてかかる様子
また、これらに加えて特定の用途のTranscribeがあり、通話分析向けにTranscribe Call Analytics、医療向けにTranscribe Medicalがあり、それらは無印のTranscribeよりも価格が割高となっています。

言語毎のサポート状況
通常のTranscribeおよびCall Analyticsについては以下に対応表があります。



以降は、2022年12月10日現在の状況で記載します。

現在の対応状況を見てみましょう。 「en-US」がフルスペックと考えればOKです。これと「ja-JP」の対応状況を比較してみます。

Language    Language Code   Data input  Transcribing numbers    Acronyms    Custom language models  Redaction   Call Analytics
English, US en-US   batch, streaming    batch, streaming    batch, streaming    batch, streaming    batch, streaming    post-call, real-time
Japanese    ja-JP   batch, streaming    no  no  batch, streaming    no  post-call
まだ日本語対応していない部分がありますが、「Custom language models」はつい先日のアップデートで待望の日本語対応しています! 今後他の部分の機能も対応が拡がることを期待しています。



また、Transcribe Medicalについて上記の表にありませんが、以下に「Amazon Transcribe Medical is available in US English (en-US).」と書いてある通り、「en-US」のみの対応となっているようです。



マネコンのリアルタイム処理について
マネジメントコンソールから、マイク入力のリアルタイム書き起こしを試すことができます。



その下部にある設定を見ていきましょう。

Language settings


書き起こしする対象の言語を設定することができます。 「Specific language」だと単一の言語を設定し、「Automatic language identification」は言語を自動で判別します。 自動判別には、発話の最初の3秒間で特定することが可能なようです。

自動判別というものの、複数の指定された言語から自動判別するという形のようで、2言語を最低でも選択する必要があります。 また1言語につき、1つのバリエーションしか指定できないことに注意が必要です(en-USとen-AUを指定することはできない)。

Audio settings


「Speaker partitioning」という機能が有効化でき、この機能は話者の判別(Speaker diarization)を行うことができます。 最大10人のスピーカーの音声を分離することができるようです。

詳細は以下を参照ください。



Content removal settings


「Vocabulary filtering」は、指定した単語をフィルタリングする機能で、事前に「Vocabulary filter」を作成する必要があります。

「PII Identification & redaction」はPricingのところで述べた通り、削除する必要がある社会保障番号や生年月日の情報などの個人を特定できる情報(PII)を同定・除去する機能です。 有効にすると以下のように様々な、Financial情報やPersonal情報をチェックしてそれぞれ有効化することができます。



本機能は追加の料金が発生します。

Customizations


「Custom vocabulary」はユースケースに特化した単語やフレーズの認識精度を向上させるための機能で、「Custom vocabulary」を事前に作成する必要があります。

「Partial results stabilization」は最終結果のみを表示し書き起こし精度を優先するか、部分的な(Partial)結果を表示し高速にレスポンスさせるか、どちらを優先的にするかを調整します。これはリアルタイム処理特有の話です。 実際に自動音声認識(ASR)システムは一般的に、現在の書き起こし結果を後段の音声入力の結果からフィードバックして修正するため、結果が後で変わる可能性があります。

その結果が代わる可能性があるものを途中結果のまま出力する場合は、High Levelを選択して高速に書き起こし、精度を重視したい場合はLow Levelを選択して書き起こします。



最後に「Custom language model」ですが、これもPricingのところで触れた通りでドメイン固有の用語を認識する機能を提供します。 事前に「Custom language model」の作成が必要で、この機能は追加の料金が発生します。
"""

prompt = generate_prompt("文章を要約してください。", input=body)

result = pipeline.generate(prompt, token_count=400, args=args)
print(result)

出力は以下となりました。

  1. 過去の音声情報の分析によって、何が可能か?
  2. 「Call Analytics」はどのような機能を提供しているか?
  3. Transcribeはどのような方法で認識することができるか?
  4. ユーザーが自分でカスタマイズすることはできるか?
  5. どのような事前準備が必要なのか?
  6. Post-call, real-time streamingでは何ができるか?
  7. 「Custom language model」にはどのような制限があるか?
  8. 「Pricing」には、どのような選択肢があるか?
  9. Amazon Transcribeと違い、どのような選択肢があるか?

箇条書きにしてくれてポイントとなる観点が書いてありますが、もう少し中身の要約が欲しいところです。

ここら辺は意図した動作になるようプロンプトの試行錯誤などが今後必要になるかもしれません。

最後に、少し語尾をカスタマイズするなどの指示をしてみましょう。

prompt = generate_prompt(f"ChatGPTとは何ですか?語尾ににゃんを付けて答えてください。")

result = pipeline.generate(prompt, token_count=200, args=args)
print(result)

出力は以下となりました。

ChatGPTは、人工知能によるコンピューター・サービスです。語尾に「にゃん」を付けると、ユーザーの質問に回答することで対話が可能になります。

意図がうまく伝わっていないようです。特定の設定を与えるのは少し難しいのかなという印象を受けました。

指示も微妙にAlpacaデータセットに沿ってない気もしますので、ここも今後は試行錯誤してみたいと思います。

まとめ

いかがでしたでしょうか。完全にGPT-3.5やChatGPT相当とはまだ言えませんが、文法やタスクの理解度的には結構高い性能を出している印象を受けました。

RWKVはファインチューニングにも対応しているようなので、自社内でQA集などのデータがあればそれらを学習として使うことも可能だと思います。

それにより、特定の用途で性能を目指すことも必要なのかなと思いました。

本記事がRWKVを活用される方のご参考になれば幸いです。

参考記事