Whisperなどの前段処理に使えるかも!? PythonのWebRTC VADを使って 音声分割を検討してみた

2023.11.26

こんちには。

データアナリティクス事業本部 インテグレーション部 機械学習チームの中村です。

今回は音声区間を検出する、WebRTC VADをPythonから使用できるpy-webrtcvadを試してみたのでご紹介したいと思います。

はじめに

近年、Whisperに代表される文字起こし処理などの様々なタスクで高精度なモデルが出てきており、API化やフレームワーク経由で呼び出すことにより、簡単に使い始めることができるようになってきています。

その一方でこれらの多くは、Transformerなどのある程度規模の大きな機械学習モデルで構成されているため、ストリーミングデータではなくある程度まとまった音声をバッチ処理する前提であったり、バッチ処理をするにしてもデータの長さに関する制限があったりするため、実際にアプリケーションに組み込む際には課題となることがあります。

そのため通常、ストリーミング処理をしたり、長い音声を処理するためには音声を分割することを検討する必要があります。

しかし音声を単純に長さなどで分割すると、発話の途中で途切れてしまうことが考えられますので、その後段の文字起こしなどの本来やりたい処理が劣化する要因となります。そのため音声区間を検出することは現在でも非常に重要です。

これらの音声区間の検出は古くから存在し、一般的にVADと呼ばれ、ITU-T勧告の音声圧縮処理の一部(無音区間の帯域削減)などとして記載されています。

私もこれまでのブログでは、音声の分割を処理する際にPythonから容易に実行できるpydubなどのライブラリのsplit_on_silenceを使用していましたが、もう少し他の手法と比較したく別のライブラリを探しており、py-webrtcvadというライブラリがあることを知りました。

py-webrtcvadはライセンス的にも使いやすく、WebRTCに使用されるVADがベースとなっているため、ストリーミングデータのリアルタイム処理などに向いていると考え、今回少し試してみました。

(ここは用途がストリーミングなどのリアルタイムでなくても良い場合は、VAD自体をDeep Learningなどの機械学習モデルで実装したものも選択肢に入ってくると考えられます)

py-webrtcvadについて

py-webrtcvadは、WebRTC Voice Activity Detector (VAD)のPythonインターフェースとなっています。以下がGitHubのレポジトリです。

WebRTC VADは、GoogleがWebRTCプロジェクトのために開発したVADで、短い音声データ(フレーム)単位に音声(voiced)か非音声(unvoiced)に分類します。

その他の仕様は以下の通りです。

  • Agressive Modeを設定可能で0,1,2,3の4段階で指定可能
    • 0は非音声のフィルタリングについて最も抑制的
    • 3は非音声のフィルタリングについて最も積極的
  • サンプリング周波数は8000、16000、32000、48000 Hzに対応
  • 16bit モノラル PCMデータのみに対応
  • フレームは10、20、30 ミリ秒に対応

py-webrtcvad自体は非常にシンプルなのですが短い単位のVADのみを実行するため、実際にはある程度過去のVAD結果の履歴を考慮して、フレームをまとめるような処理が別途必要となります。その実装例としてexample.pyが準備されています。

こちらは履歴を考慮しつつ音声区間を検出し、音声区間のみを抽出するような例となっています。今回はこちらを流用しつつ、音声区間だけではなく非音声区間も結果として返すように変更を加えていきたいと思います。

使ってみる

環境構築と動作環境

pip等でインストールを実施します。(ここは環境に応じて合わせてください)

!pip install webrtcvad

動作環境は以下です。

  • Python 3.11.5
  • webrtcvad==2.0.10

サンプルコードの確認と変更

example.pyというサンプルにいくつかモジュールがあるので順に見ていきます。

read_wave

read_waveは、wavファイルをbytesデータとして読み込むためのメソッドです。

def read_wave(path):
    """Reads a .wav file.

    Takes the path, and returns (PCM audio data, sample rate).
    """
    with contextlib.closing(wave.open(path, 'rb')) as wf:
        num_channels = wf.getnchannels()
        assert num_channels == 1
        sample_width = wf.getsampwidth()
        assert sample_width == 2
        sample_rate = wf.getframerate()
        assert sample_rate in (8000, 16000, 32000, 48000)
        pcm_data = wf.readframes(wf.getnframes())
        return pcm_data, sample_rate

ここはこのまま使って行きます。

write_wave

write_waveは、逆にbytesデータをwavファイルとして書き込むためのメソッドです。

def write_wave(path, audio, sample_rate):
    """Writes a .wav file.

    Takes path, PCM audio data, and sample rate.
    """
    with contextlib.closing(wave.open(path, 'wb')) as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(sample_rate)
        wf.writeframes(audio)

ここはこのまま使って行きます。

frame_generator

frame_generatorは、bytesデータをあるミリ秒単位(フレーム単位)に分割するメソッドです。

generatorとなるため実際に分割結果を得る際にはlist()などで囲って、実体化させる必要があります。

分割単位は、別途定義されているFrameクラスとして得られます。

class Frame(object):
    """Represents a "frame" of audio data."""
    def __init__(self, bytes, timestamp, duration):
        self.bytes = bytes
        self.timestamp = timestamp
        self.duration = duration

def frame_generator(frame_duration_ms, audio, sample_rate):
    """Generates audio frames from PCM audio data.

    Takes the desired frame duration in milliseconds, the PCM data, and
    the sample rate.

    Yields Frames of the requested duration.
    """
    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
    offset = 0
    timestamp = 0.0
    duration = (float(n) / sample_rate) / 2.0
    while offset + n < len(audio):
        yield Frame(audio[offset:offset + n], timestamp, duration)
        timestamp += duration
        offset += n

    # 追記 : 末尾のデータも欲しいので
    yield Frame(audio[offset:offset + n], timestamp, duration)

このコードは追記としている部分だけ修正して使用します。

vad_collector

vad_collectorは、実際にフレーム単位の履歴を考慮した処理を行い、メインの部分となります。

ここは元のコードから変更します。元々の処理はざっくりいうと以下のようになっています。

  • ガード区間を実装
    • フレーム毎の判定結果の継続性で最終的な音声セグメントを判断
    • 過去10フレームのうち9割が音声だったら音声セグメント開始
    • 過去10フレームのうち9割が非音声だったら音声セグメント終了
    • 過去10フレームだけ見ればよいのでリングバッファとしてそれ以前の情報は破棄
  • 非音声の結果は破棄し、音声区間の情報だけ返却

これに対して以下のように変更をします。

  • 非音声の結果は破棄せず、双方の結果を返却
  • そのためリングバッファにはせずセグメントの終了が未確定のフレームはすべて記憶するよう修正
  • 音声セグメントの開始、終了の閾値を変更できるよう引数に追加
    • フレームの音声判定がvoice_trigger_on_thres以上の割合なら音声セグメント開始
    • フレームの音声判定がvoice_trigger_off_thres以下の割合なら音声セグメント終了
def vad_collector(sample_rate: int, frame_duration_ms: int,
    padding_duration_ms: int, vad: webrtcvad.Vad, frames: list[Frame],
    voice_trigger_on_thres: float=0.9, voice_trigger_off_thres: float=0.1) -> list[dict]:
    """音声非音声セグメント処理

    Args:
        sample_rate (int): 単位時間あたりのサンプル数[Hz]
        frame_duration_ms (int): フレーム長
        padding_duration_ms (int): ガード長
        vad (webrtcvad.Vad): _description_
        frames (list[Frame]): フレーム分割された音声データ
        voice_trigger_on_thres (float, optional): 音声セグメント開始と判断する閾値. Defaults to 0.9.
        voice_trigger_off_thres (float, optional): 音声セグメント終了と判断する閾値. Defaults to 0.1.

    Returns:
        list[dict]: セグメント結果
    """
    # ガードするフレーム数
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)

    # バッファ(リングバッファではなくする)
    # ring_buffer = collections.deque(maxlen=num_padding_frames)
    frame_buffer = []

    # いま音声かどうかのトリガのステータス
    triggered = False

    voiced_frames = []
    vu_segments = []

    for frame in frames:
        is_speech = vad.is_speech(frame.bytes, sample_rate)
        frame_buffer.append((frame, is_speech))

        # 非音声セグメントの場合
        if not triggered:

            # 過去フレームのうち音声判定数を取得
            # 過去を見る数はnum_padding_frames個
            num_voiced = len([f for f, speech in frame_buffer[-num_padding_frames:] if speech])

            # 9割以上が音声の場合は音声にトリガする(立ち上がり)
            if num_voiced > voice_trigger_on_thres * num_padding_frames:
                triggered = True

                # num_padding_framesより前は非音声セグメントとする
                audio_data = b''.join([f.bytes for f, _ in frame_buffer[:-num_padding_frames]])
                vu_segments.append({"vad": 0, "audio_size": len(audio_data), "audio_data": audio_data})

                # num_padding_frames以降は音声セグメント終了時にまとめるため一旦保持
                for f, _ in frame_buffer[-num_padding_frames:]:
                    voiced_frames.append(f)
                frame_buffer = []

        # 音声セグメントの場合
        else:
            # フレームを保持
            voiced_frames.append(frame)

            # 過去フレームのうち非音声判定数を取得
            # 過去を見る数はnum_padding_frames個
            num_unvoiced = len([f for f, speech in frame_buffer[-num_padding_frames:] if not speech])

            # 9割以上が非音声の場合はトリガを落とす(立ち下がり)
            if num_unvoiced > (1 - voice_trigger_off_thres) * num_padding_frames:
                triggered = False

                # 音声セグメントをまとめる
                audio_data = b''.join([f.bytes for f in voiced_frames])
                vu_segments.append({"vad": 1, "audio_size": len(audio_data), "audio_data": audio_data})
                voiced_frames = []

                frame_buffer = []

    # 終了時に音声セグメントか非音声セグメントかどうかで処理を分ける
    if triggered:
        audio_data = b''.join([f.bytes for f in voiced_frames])
        vu_segments.append({"vad": 1, "audio_size": len(audio_data), "audio_data": audio_data})
    else:
        audio_data = b''.join([f.bytes for f, _ in frame_buffer])
        vu_segments.append({"vad": 0, "audio_size": len(audio_data), "audio_data": audio_data})

    return vu_segments

結果を見てみる

一旦はデフォルトのパラメータで結果を見ていきます。

今回はサンプルデータとして、サンプリングレート16kHz、モノラル、PCM 16bit、長さは20分のaudio.wavというデータを準備しておきます。

# wav読込
audio_data, sample_rate = read_wave("./audio.wav")

# VADクラス
vad = webrtcvad.Vad(0)

# フレーム分割
frames = frame_generator(30, audio_data, sample_rate)
frames = list(frames)

# セグメント結果
vu_segments = vad_collector(sample_rate, 30, 300, vad, frames)

結果はpolarsのデータフレームにまとめつつ、各セグメントのwavファイルを出力してみます。

# wavファイル格納先作成
segments_dir = pathlib.Path("./segments/")
segments_dir.mkdir(parents=True, exist_ok=True)

# レコード作成しつつwavファイル出力
for_df = []
for i, segment in enumerate(vu_segments):
    path = segments_dir.joinpath(f"segment-{i:03d}-vad{segment['vad']}.wav")
    write_wave(str(path), segment['audio_data'], sample_rate)

    for_df.append({
        "filename": path.name,
        "vad": segment["vad"],
        "duration_sec": segment["audio_size"]/2.0/sample_rate,
    })
df = pl.DataFrame(for_df)

秒数が小さい順に見てみます。

print(df.filter(pl.col("vad")==1).sort(pl.col("duration_sec"), descending=True)[-10:])

# shape: (10, 3)
# ┌──────────────────────┬─────┬──────────────┐
# │ filename             ┆ vad ┆ duration_sec │
# │ ---                  ┆ --- ┆ ---          │
# │ str                  ┆ i64 ┆ f64          │
# ╞══════════════════════╪═════╪══════════════╡
# │ segment-101-vad1.wav ┆ 1   ┆ 2.67         │
# │ segment-147-vad1.wav ┆ 1   ┆ 2.07         │
# │ segment-139-vad1.wav ┆ 1   ┆ 2.04         │
# │ segment-067-vad1.wav ┆ 1   ┆ 1.8          │
# │ …                    ┆ …   ┆ …            │
# │ segment-125-vad1.wav ┆ 1   ┆ 1.32         │
# │ segment-103-vad1.wav ┆ 1   ┆ 1.14         │
# │ segment-095-vad1.wav ┆ 1   ┆ 0.84         │
# │ segment-115-vad1.wav ┆ 1   ┆ 0.72         │
# └──────────────────────┴─────┴──────────────┘

元も短いsegment-115-vad1.wavを試聴してみましたが、不自然な切れ目にはなっていませんでした。小さい方はおおむね問題なさそうです。

次に秒数が多い順に見てみます。

print(df.sort(pl.col("duration_sec"), descending=True)[:10])

# shape: (10, 3)
# ┌──────────────────────┬─────┬──────────────┐
# │ filename             ┆ vad ┆ duration_sec │
# │ ---                  ┆ --- ┆ ---          │
# │ str                  ┆ i64 ┆ f64          │
# ╞══════════════════════╪═════╪══════════════╡
# │ segment-017-vad1.wav ┆ 1   ┆ 61.32        │
# │ segment-117-vad1.wav ┆ 1   ┆ 43.41        │
# │ segment-023-vad1.wav ┆ 1   ┆ 34.14        │
# │ segment-081-vad1.wav ┆ 1   ┆ 32.43        │
# │ …                    ┆ …   ┆ …            │
# │ segment-003-vad1.wav ┆ 1   ┆ 26.28        │
# │ segment-167-vad1.wav ┆ 1   ┆ 25.2         │
# │ segment-187-vad1.wav ┆ 1   ┆ 22.08        │
# │ segment-071-vad1.wav ┆ 1   ┆ 20.4         │
# └──────────────────────┴─────┴──────────────┘

最長が60秒程度となっており、これでも用途によっては良いのですが、実際にリアルタイム処理する際に60秒の遅延が発生する可能性があるということになります。

今回はこちらを短くする検討をしました。

VADのモードを変更する

VADモードは以下のようにVadクラス生成時に引数で変更できます。

# VADクラス
vad = webrtcvad.Vad(3)

最も音声判定が厳しい3に設定してみました。差分のない部分のコードは省略してセグメントの結果だけを大きい順に見てみます。

print(df.sort(pl.col("duration_sec"), descending=True)[:10])

# shape: (10, 3)
# ┌──────────────────────┬─────┬──────────────┐
# │ filename             ┆ vad ┆ duration_sec │
# │ ---                  ┆ --- ┆ ---          │
# │ str                  ┆ i64 ┆ f64          │
# ╞══════════════════════╪═════╪══════════════╡
# │ segment-073-vad1.wav ┆ 1   ┆ 26.85        │
# │ segment-031-vad1.wav ┆ 1   ┆ 26.28        │
# │ segment-297-vad1.wav ┆ 1   ┆ 25.08        │
# │ segment-029-vad1.wav ┆ 1   ┆ 21.84        │
# │ …                    ┆ …   ┆ …            │
# │ segment-011-vad1.wav ┆ 1   ┆ 15.63        │
# │ segment-153-vad1.wav ┆ 1   ┆ 14.91        │
# │ segment-335-vad1.wav ┆ 1   ┆ 14.52        │
# │ segment-017-vad1.wav ┆ 1   ┆ 13.74        │
# └──────────────────────┴─────┴──────────────┘

結構改善が見られ、30秒以下には収まりそうな雰囲気です。もう少し短くしたいところですね。

ガード区間の閾値を調整する

現状はガード区間は以下の設計になっています。(再掲)

  • フレーム毎の判定結果の継続性で最終的な音声セグメントを判断
  • 過去10フレームのうち9割が音声だったら音声セグメント開始
  • 過去10フレームのうち9割が非音声だったら音声セグメント終了

これは定性的な言い方をすると、立ち上がりは重たく、立ち下がりも重たいような動作となっています。

立ち上がりが軽いと音声ではないゴミが増えますので、こちらは重たいままでもよいのですが、立ち下がりは軽い方がちょっとした無音で分割することができます。

ですのでvoice_trigger_off_thres=0.1からvoice_trigger_off_thres=0.8として立ち下がりを軽くしてみます。

# セグメント結果
vu_segments = vad_collector(sample_rate, 30, 300, vad, frames, voice_trigger_off_thres=0.8)

差分のない部分のコードは省略してセグメントの結果だけを大きい順に見てみます。

print(df.sort(pl.col("duration_sec"), descending=True)[:10])

# shape: (10, 3)
# ┌──────────────────────┬─────┬──────────────┐
# │ filename             ┆ vad ┆ duration_sec │
# │ ---                  ┆ --- ┆ ---          │
# │ str                  ┆ i64 ┆ f64          │
# ╞══════════════════════╪═════╪══════════════╡
# │ segment-019-vad1.wav ┆ 1   ┆ 12.54        │
# │ segment-013-vad1.wav ┆ 1   ┆ 8.16         │
# │ segment-537-vad1.wav ┆ 1   ┆ 7.32         │
# │ segment-099-vad1.wav ┆ 1   ┆ 7.05         │
# │ …                    ┆ …   ┆ …            │
# │ segment-011-vad1.wav ┆ 1   ┆ 6.6          │
# │ segment-030-vad0.wav ┆ 0   ┆ 6.57         │
# │ segment-489-vad1.wav ┆ 1   ┆ 6.39         │
# │ segment-207-vad1.wav ┆ 1   ┆ 6.24         │
# └──────────────────────┴─────┴──────────────┘

15秒くらいまで短くすることができ、全体的にも短くすることができました。

短いセグメントはマージする

これで長いセグメントに対する対策はある程度できましたが、これにより短いセグメントも増えてしまいました。

ですので後処理として一定以下のセグメントは出てこないようにマージしてみたいと思います。

ここでは音声、非音声無関係にマージします。

# 初期化
vu_segments_merged = []
count = 0
while count < len(vu_segments):
    s = vu_segments[count].copy()

    # 3秒以下なら次のセグメントとマージ
    while s["audio_size"] < 3 * 2 * sample_rate:
        # 次のセグメントがない場合は強制終了
        if count == len(vu_segments) - 1:
            break

        # マージ処理
        s["audio_size"] = s["audio_size"] + vu_segments[count+1]["audio_size"]
        s["audio_data"] = s["audio_data"] + vu_segments[count+1]["audio_data"]
        count += 1

    # マージされたセグメントを格納
    vu_segments_merged.append(s)
    count += 1

再度wavファイル出力とdf作成を実行します。

# wavファイル格納先作成
segments_dir = pathlib.Path("./segments_merged/")
segments_dir.mkdir(parents=True, exist_ok=True)

# レコード作成しつつwavファイル出力
for_df = []
for i, segment in enumerate(vu_segments_merged):
    path = segments_dir.joinpath(f"segment-{i:03d}.wav")
    write_wave(str(path), segment['audio_data'], sample_rate)

    for_df.append({
        "filename": path.name,
        "duration_sec": segment["audio_size"]/2.0/sample_rate,
    })
df = pl.DataFrame(for_df)

秒数が小さい順に見てみます。

print(df.sort(pl.col("duration_sec"), descending=True)[-10:])

# shape: (10, 2)
# ┌─────────────────┬──────────────┐
# │ filename        ┆ duration_sec │
# │ ---             ┆ ---          │
# │ str             ┆ f64          │
# ╞═════════════════╪══════════════╡
# │ segment-056.wav ┆ 3.06         │
# │ segment-119.wav ┆ 3.06         │
# │ segment-138.wav ┆ 3.06         │
# │ segment-156.wav ┆ 3.06         │
# │ …               ┆ …            │
# │ segment-250.wav ┆ 3.03         │
# │ segment-063.wav ┆ 3.0          │
# │ segment-103.wav ┆ 3.0          │
# │ segment-245.wav ┆ 3.0          │
# └─────────────────┴──────────────┘

きちんと3秒以上にマージされています。

次に秒数が多い順に見てみます。

print(df.sort(pl.col("duration_sec"), descending=True)[:10])

# shape: (10, 2)
# ┌─────────────────┬──────────────┐
# │ filename        ┆ duration_sec │
# │ ---             ┆ ---          │
# │ str             ┆ f64          │
# ╞═════════════════╪══════════════╡
# │ segment-008.wav ┆ 12.99        │
# │ segment-058.wav ┆ 9.63         │
# │ segment-013.wav ┆ 9.36         │
# │ segment-005.wav ┆ 8.85         │
# │ …               ┆ …            │
# │ segment-185.wav ┆ 8.43         │
# │ segment-027.wav ┆ 8.22         │
# │ segment-030.wav ┆ 8.07         │
# │ segment-043.wav ┆ 8.01         │
# └─────────────────┴──────────────┘

マージにより大幅に増えてはいないことが確認できました。

まとめ

いかがでしたでしょうか。pydubsplit_on_silenceでは中途半端な箇所での分割も発生していましたが、py-webrtcvadを使った手法では思ったような分割を実装することができました。

こちらをWhisperなどの前段に使用すれば、ストリーミングで文字起こし処理も実現できそうです。そういった記事も後日作成しようと思います。

本記事が皆さまの参考になれば幸いです。