音声認識モデル”Whisper”をストリーミング処理対応させる方法

音声認識モデル”Whisper”をストリーミング処理対応させる方法

Clock Icon2022.10.09

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

こんちには。

データアナリティクス事業本部 機械学習チームの中村です。

OpenAIがリリースしたWhisperについて、前回はtranscribeの内容を紐解きました。

Whisperが提供しているtranscribeのAPIは、バッチ処理のみに対応した構成となっており、リアルタイムに認識を試すのが難しくなっています。

そこで今回は、前回紐解いた結果を使ってストリーミング処理に対応させてみようと思います。

設計の概要

以下に設計の概要を図で示します。

前回ご紹介した通り、transcribeの中身は30秒単位で処理をしています(この単位を本記事ではフレームと呼びます)。

そして次の処理フレームは、前のフレームのタイムスタンプトークンの末尾から30秒となります。

そのためフレームをずらす長さ(これを本記事ではシフト長と呼びます)は動的になります。

そのため、設計方針は以下のようにしました。

  • ストリーミング処理用のクラスを作成し、前回の位置をメンバ変数として記憶する。
  • データを記憶するためのバッファを準備して、30秒分たまるまで待つ。
  • 30秒分たまった場合、認識処理を実行し、最後のタイムスタンプトークンから末尾までをバッファに残す。
  • 以降、30秒分がバッファにたまるまで待つの繰り返し。

送信されるストリーミングのフォーマットは、オーディオ信号のみの符号化なし、ヘッダー無しのバイナリデータとします。 (符号化圧縮した方が帯域は削減できますが、今回は簡単のため)

ヘッダー無しですが、フォーマットはsinged 16bit Little Endian形式、サンプリングレートは16,000Hzのもののみを受け付けます。

また言語判定処理は簡単のため行わず、あらかじめ日本語が入力される前提で実装していきます。

完成物

ストリーミング処理クラスの全体は長くなるので、先に完成物を示しておきます。

以下のメニューを開いてご覧ください。

WhiperStreamingクラス (クリックして展開)
from typing import List, Optional, Tuple, Union, TYPE_CHECKING
import warnings

import torch
import numpy as np
import tqdm

from whisper.audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim, log_mel_spectrogram
from whisper.decoding import DecodingOptions, DecodingResult
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from whisper.utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
from whisper import load_model
from whisper.model import Whisper

class WhisperStreaming():

    def __init__(self):

        # 引数から移植
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = load_model("large", device=self.device)

        self.verbose = True
        self.temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
        self.compression_ratio_threshold: Optional[float] = 2.4
        self.logprob_threshold: Optional[float] = -1.0
        self.no_speech_threshold: Optional[float] = 0.6
        self.condition_on_previous_text: bool = True

        # 今回はlanguageを日本語に固定、taskもtranscribeに固定
        self.decode_options = {
            "language": "ja",
            "task": "transcribe",
            "fp16": True,
        }

        # transcribeのループ外のものを移植
        self.dtype = torch.float16 if self.decode_options.get("fp16", True) else torch.float32
        if self.model.device == torch.device("cpu"):
            if torch.cuda.is_available():
                warnings.warn("Performing inference on CPU when CUDA is available")
            if self.dtype == torch.float16:
                warnings.warn("FP16 is not supported on CPU; using FP32 instead")
                self.dtype = torch.float32

        if self.dtype == torch.float32:
            self.decode_options["fp16"] = False

        language = self.decode_options["language"]
        task = self.decode_options.get("task", "transcribe")
        self.tokenizer = get_tokenizer(self.model.is_multilingual, language=language, task=task)

        self.seek = 0
        self.input_stride = exact_div(
            N_FRAMES, self.model.dims.n_audio_ctx
        )  # mel frames per output token: 2
        self.time_precision = (
            self.input_stride * HOP_LENGTH / SAMPLE_RATE
        )  # time per output token: 0.02 (seconds)
        self.all_tokens = []
        self.all_segments = []
        self.prompt_reset_since = 0

        # バッファ記憶用のメンバ変数
        self.buffer = b''

    def set_data(self, audio: bytes):

        # 前回の残り分と連結する
        self.buffer = self.buffer + audio

        # 30秒以下になるまでループする
        while len(self.buffer) >= SAMPLE_RATE*CHUNK_LENGTH*2:

            # 30秒分バッファから取得
            frame = self.buffer[:SAMPLE_RATE*CHUNK_LENGTH*2]

            # byteデータを2byteデータに詰めなおして最大値で割って-1~1に正規化する
            frame = np.frombuffer(frame, np.int16).flatten().astype(np.float32) / 32768.0

            # 処理前の場所を記憶しておく
            previous_seek = self.seek

            # 30秒データを処理する
            self.set_frame(frame)

            # 処理済みとなった時刻がself.seekに記載されるため
            # 処理済みをバッファから消す更新を行う
            self.buffer = self.buffer[(self.seek - previous_seek)*HOP_LENGTH*2:]

        return

    def set_frame(self, audio: bytes):

        # 音響特徴量計算
        mel = log_mel_spectrogram(audio)
        mel = mel.unsqueeze(0)

        # 現時刻のオフセットをseekから計算
        timestamp_offset = float(self.seek * HOP_LENGTH / SAMPLE_RATE)
        segment = pad_or_trim(mel, N_FRAMES).to(self.device).to(self.dtype)
        segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE

        self.decode_options["prompt"] = self.all_tokens[self.prompt_reset_since:]
        result = self.decode_with_fallback(segment)[0]
        tokens = torch.tensor(result.tokens)

        # 有音無音判定
        if self.no_speech_threshold is not None:
            # no voice activity check
            should_skip = result.no_speech_prob > self.no_speech_threshold
            if self.logprob_threshold is not None and result.avg_logprob > self.logprob_threshold:
                # don't skip if the logprob is high enough, despite the no_speech_prob
                should_skip = False

            if should_skip:
                self.seek += segment.shape[-1]  # fast-forward to the next segment boundary
                return

        # トークナイザでデコード
        timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin)

        # タイムスタンプトークンによる区間処理
        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
        if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
            last_slice = 0
            for current_slice in consecutive:
                sliced_tokens = tokens[last_slice:current_slice]
                start_timestamp_position = (
                    sliced_tokens[0].item() - self.tokenizer.timestamp_begin
                )
                end_timestamp_position = (
                    sliced_tokens[-1].item() - self.tokenizer.timestamp_begin
                )
                self.add_segment(
                    start=timestamp_offset + start_timestamp_position * self.time_precision,
                    end=timestamp_offset + end_timestamp_position * self.time_precision,
                    text_tokens=sliced_tokens[1:-1],
                    result=result,
                )
                last_slice = current_slice
            last_timestamp_position = (
                tokens[last_slice - 1].item() - self.tokenizer.timestamp_begin
            )
            self.seek += last_timestamp_position * self.input_stride
            self.all_tokens.extend(tokens[: last_slice + 1].tolist())

        # 区間が無い場合の処理
        else:
            duration = segment_duration
            timestamps = tokens[timestamp_tokens.nonzero().flatten()]
            if len(timestamps) > 0:
                # no consecutive timestamps but it has a timestamp; use the last one.
                # single timestamp at the end means no speech after the last timestamp.
                last_timestamp_position = timestamps[-1].item() - self.tokenizer.timestamp_begin
                duration = last_timestamp_position * self.time_precision

            self.add_segment(
                start=timestamp_offset,
                end=timestamp_offset + duration,
                text_tokens=tokens,
                result=result,
            )

            self.seek += segment.shape[-1]
            self.all_tokens.extend(tokens.tolist())

        if not self.condition_on_previous_text or result.temperature > 0.5:
            # do not feed the prompt tokens if a high temperature was used
            self.prompt_reset_since = len(self.all_tokens)

        return

    def decode_with_fallback(self, segment: torch.Tensor) -> List[DecodingResult]:
        temperatures = [self.temperature] if isinstance(self.temperature, (int, float)) else self.temperature
        kwargs = {**self.decode_options}
        t = temperatures[0]
        if t == 0:
            best_of = kwargs.pop("best_of", None)
        else:
            best_of = kwargs.get("best_of", None)

        options = DecodingOptions(**kwargs, temperature=t)
        results = self.model.decode(segment, options)

        kwargs.pop("beam_size", None)  # no beam search for t > 0
        kwargs.pop("patience", None)  # no patience for t > 0
        kwargs["best_of"] = best_of  # enable best_of for t > 0
        for t in temperatures[1:]:
            needs_fallback = [
                self.compression_ratio_threshold is not None
                and result.compression_ratio > self.compression_ratio_threshold
                or self.logprob_threshold is not None
                and result.avg_logprob < self.logprob_threshold
                for result in results
            ]
            if any(needs_fallback):
                options = DecodingOptions(**kwargs, temperature=t)
                retries = self.model.decode(segment[needs_fallback], options)
                for retry_index, original_index in enumerate(np.nonzero(needs_fallback)[0]):
                    results[original_index] = retries[retry_index]

        return results

    def add_segment(self,
        *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
    ):
        text = self.tokenizer.decode([token for token in text_tokens if token < self.tokenizer.eot])
        if len(text.strip()) == 0:  # skip empty text output
            return

        self.all_segments.append(
            {
                "id": len(self.all_segments),
                "seek": self.seek,
                "start": start,
                "end": end,
                "text": text,
                "tokens": result.tokens,
                "temperature": result.temperature,
                "avg_logprob": result.avg_logprob,
                "compression_ratio": result.compression_ratio,
                "no_speech_prob": result.no_speech_prob,
            }
        )
        if self.verbose:
            print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
        
        return

    def finalized(self):
        frame = self.buffer

        # byteデータを2byteデータに詰めなおして最大値で割って-1~1に正規化する
        frame = np.frombuffer(frame, np.int16).flatten().astype(np.float32) / 32768.0

        # 内部で0詰めして30秒データとしてから処理
        self.set_frame(frame)

    def get_result(self):
        return dict(
            text=self.tokenizer.decode(self.all_tokens), 
            segments=self.all_segments
        )

実行環境

今回もGoogle Colab環境を使用します。ハードウェアなどの主な情報は以下の通りです。

  • GPU: Tesla T4
  • CUDA: 11.1
  • メモリ: 13GB(標準メモリタイプ)

主なライブラリのバージョンは以下となります。

  • transformers: 4.22.2
  • whisper: 1.0

使用するデータ

前回同様、ソースは社内の勉強会で自身が発話したデータを用います。

そのうち、30秒を切り出した音声データをaudio_sample.wavとして使用します。

切り出し方法については、以下の以前の記事等を参照ください。

準備

実装の前段階としてセットアップ等を実施します。

Whisperのセットアップ

whisperのインストールをします。

!pip install git+https://github.com/openai/whisper.git

インポートします。

import whisper

モデルはlargeをロードしておきます。

model = whisper.load_model("large")

まずは期待する結果を確認

まずは通常のtranscribeでバッチ処理をして結果を確認します。

  • コード
model = whisper.load_model("large")
  • 出力

[00:00.000 --> 00:07.000] 画面共有
[00:07.000 --> 00:26.000] 画面見えてますでしょうか
[00:26.000 --> 00:30.000] こっちか
[00:30.000 --> 00:31.000] 見えてます
[00:31.000 --> 00:37.000] ありがとうございます
[00:37.000 --> 00:41.000] NNアーキテクチャ勉強会ということで
[00:41.000 --> 00:43.000] 物体検出編
[00:43.000 --> 00:48.000] 前回3月の16日だったかな
[00:48.000 --> 00:51.000] そこでやった続きということで
[00:51.000 --> 00:58.000] 今回は物体検出の方やっていきます
[00:58.000 --> 01:03.000] スレッドとかあるのかな
[01:03.000 --> 01:05.000] ないかな
[01:05.000 --> 01:07.000] 作ります
[01:07.000 --> 01:15.000] ありがとうございます
[01:15.000 --> 01:17.000] はじめにはじめにはすごい余談なんで
[01:17.000 --> 01:21.000] 飛ばしましょうかね
[01:21.000 --> 01:23.000] いろいろしてますけど
[01:23.000 --> 01:26.000] 16時間ものを食べないっていうのを
[01:26.000 --> 01:29.000] 最近やってますっていう話です
[01:29.000 --> 01:35.000] 興味あったら見ていてください
[01:35.000 --> 01:38.000] ちょっと目次の方は
[01:38.000 --> 01:39.000] 前回から日が空いてるので
[01:39.000 --> 01:44.000] 前回のおさらいの方をまずやっていきます
[01:44.000 --> 01:49.000] で、物体検出の仕組みを説明して
[01:49.000 --> 01:53.000] まあそうですね仕組みを説明してその中で
[01:53.000 --> 01:54.000] ヨロっていう
[01:54.000 --> 01:59.000] 物体検出の有名なアーキテクチャがあるんですけれども
[01:59.000 --> 02:02.000] そちらの方がどういう
[02:02.000 --> 02:04.000] どういうシリーズがあるか
[02:04.000 --> 02:06.000] すごい種類がたくさんあるので
[02:06.000 --> 02:09.000] そこらへんを説明しながら
[02:09.000 --> 02:13.000] そこまでに説明してる仕組みと
[02:13.000 --> 02:16.000] どういうふうな位置づけかってのを
[02:16.000 --> 02:19.000] 説明しようと思ってます
[02:19.000 --> 02:22.000] あと実際物体検出を使うときのフレームワークとかの話も
[02:22.000 --> 02:33.000] 最後の方に見れようかなと思っています
[02:33.000 --> 02:36.000] じゃあ前回のおさらいです
[02:36.000 --> 02:38.000] まずニューラルネットワークのところを
[02:38.000 --> 02:43.000] 1枚のスライドにまとめてますけれども
[02:43.000 --> 02:47.000] 線形結合をたくさんするやつが
[02:47.000 --> 02:50.000] 多層に重なったものということで
[02:50.000 --> 02:52.000] そして前回は
[02:52.000 --> 02:55.000] 天気、気温と湿度で天気を予測する
[02:55.000 --> 02:57.000] みたいな例を挙げたんですけど
[02:57.000 --> 03:09.000] そういう線形の

この結果を出力できるようなストリーミング処理クラスを作成します。

ストリーミング処理クラス実装

基本的にはtranscribeから処理を移植することで実現できます。

おおむね以下のようなAPI構成となります。

class WhisperStreaming():

    def __init__(self, verbose=False):

        # 引数や設定値は今回はここで固定で入れる
        # transcribeのwhileループ外の処理をここにいれる
        # バッファリング用の変数もここで定義

    def set_data(self, audio: bytes):

        # 受け取ったデータを30秒単位に整形する
        # 整形したものはset_frameで処理する

    def set_frame(self, audio: bytes):

        # 30秒単位の処理
        # transcribeのwhileループ内の処理をここに詰める
        # その他、transcribeではバッチ処理で会ったメルスペクトログラム計算も
        # ここで30秒単位で実施する

    def decode_with_fallback(self, segment: torch.Tensor) -> List[DecodingResult]:

        # 内部処理用
        # そのままtranscribeから移植

    def add_segment(self,
        *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
    ):

        # 内部処理用
        # そのままtranscribeから移植

    def finalized(self):

        # 末尾のデータを処理させるためにコールする
        # 30秒以下でもそのままset_frameに渡して0詰めして処理する

    def get_result(self):

        # 全ての処理結果を取得する

以降で一つ一つ説明していきます。

コンストラクタ

まずはコンストラクタです。transcribeの引数や設定値、ループ外からの移植が主となります。

    def __init__(self):

        """
        引数から移植
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = load_model("large", device=self.device)

        self.verbose = True
        self.temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
        self.compression_ratio_threshold: Optional[float] = 2.4
        self.logprob_threshold: Optional[float] = -1.0
        self.no_speech_threshold: Optional[float] = 0.6
        self.condition_on_previous_text: bool = True

        # 今回はlanguageを日本語に固定、taskもtranscribeに固定
        self.decode_options = {
            "language": "ja",
            "task": "transcribe",
            "fp16": True,
        }

        """
        transcribeのループ外のものを移植
        """
        self.dtype = torch.float16 if self.decode_options.get("fp16", True) else torch.float32
        if self.model.device == torch.device("cpu"):
            if torch.cuda.is_available():
                warnings.warn("Performing inference on CPU when CUDA is available")
            if self.dtype == torch.float16:
                warnings.warn("FP16 is not supported on CPU; using FP32 instead")
                self.dtype = torch.float32

        if self.dtype == torch.float32:
            self.decode_options["fp16"] = False

        language = self.decode_options["language"]
        task = self.decode_options.get("task", "transcribe")
        self.tokenizer = get_tokenizer(self.model.is_multilingual, language=language, task=task)

        self.seek = 0
        self.input_stride = exact_div(
            N_FRAMES, self.model.dims.n_audio_ctx
        )  # mel frames per output token: 2
        self.time_precision = (
            self.input_stride * HOP_LENGTH / SAMPLE_RATE
        )  # time per output token: 0.02 (seconds)
        self.all_tokens = []
        self.all_segments = []
        self.prompt_reset_since = 0

        """
        バッファ記憶用のメンバ変数
        """
        self.buffer = b''

なお、オリジナルで指定できる引数や設定値は今回簡単のため固定で入れています。必要に応じて引数化するなどのカスタマイズが可能です。

set_data

ここが実際にストリーミングのデータを処理する受け口になります。 主に30秒の整形処理を行います。

    def set_data(self, audio: bytes):

        # 前回の残り分と連結する
        self.buffer = self.buffer + audio

        # 30秒以下になるまでループする
        while len(self.buffer) >= SAMPLE_RATE*CHUNK_LENGTH*2:

            # 30秒分バッファから取得
            frame = self.buffer[:SAMPLE_RATE*CHUNK_LENGTH*2]

            # byteデータを2byteデータに詰めなおして最大値で割って-1~1に正規化する
            frame = np.frombuffer(frame, np.int16).flatten().astype(np.float32) / 32768.0

            # 処理前の場所を記憶しておく
            previous_seek = self.seek

            # 30秒データを処理する
            self.set_frame(frame)

            # 処理済みとなった時刻がself.seekに記載されるため
            # 処理済みをバッファから消す更新を行う
            self.buffer = self.buffer[(self.seek - previous_seek)*HOP_LENGTH*2:]

        return

ここでキモとなるのは、self.seekの動きです。

self.seekはオリジナルのtranscribeにも存在しますが、音響特徴量(メル周波数スペクトログラム)領域での時間ポイントを指すため、 時間データ領域に戻すため、HOP_LENGTH*2を書けてバッファ更新に使用します。

set_frame

次に30秒単位の処理用のset_frameを見ていきます。

    def set_frame(self, audio: bytes):

        # 音響特徴量計算
        mel = log_mel_spectrogram(audio)
        mel = mel.unsqueeze(0)

        # 現時刻のオフセットをseekから計算
        timestamp_offset = float(self.seek * HOP_LENGTH / SAMPLE_RATE)
        segment = pad_or_trim(mel, N_FRAMES).to(self.device).to(self.dtype)
        segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE

        self.decode_options["prompt"] = self.all_tokens[self.prompt_reset_since:]
        result = self.decode_with_fallback(segment)[0]
        tokens = torch.tensor(result.tokens)

        # 有音無音判定
        if self.no_speech_threshold is not None:
            # no voice activity check
            should_skip = result.no_speech_prob > self.no_speech_threshold
            if self.logprob_threshold is not None and result.avg_logprob > self.logprob_threshold:
                # don't skip if the logprob is high enough, despite the no_speech_prob
                should_skip = False

            if should_skip:
                self.seek += segment.shape[-1]  # fast-forward to the next segment boundary
                return

        # トークナイザでデコード
        timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin)

        # タイムスタンプトークンによる区間処理
        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
        if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
            last_slice = 0
            for current_slice in consecutive:
                sliced_tokens = tokens[last_slice:current_slice]
                start_timestamp_position = (
                    sliced_tokens[0].item() - self.tokenizer.timestamp_begin
                )
                end_timestamp_position = (
                    sliced_tokens[-1].item() - self.tokenizer.timestamp_begin
                )
                self.add_segment(
                    start=timestamp_offset + start_timestamp_position * self.time_precision,
                    end=timestamp_offset + end_timestamp_position * self.time_precision,
                    text_tokens=sliced_tokens[1:-1],
                    result=result,
                )
                last_slice = current_slice
            last_timestamp_position = (
                tokens[last_slice - 1].item() - self.tokenizer.timestamp_begin
            )
            self.seek += last_timestamp_position * self.input_stride
            self.all_tokens.extend(tokens[: last_slice + 1].tolist())

        # 区間が無い場合の処理
        else:
            duration = segment_duration
            timestamps = tokens[timestamp_tokens.nonzero().flatten()]
            if len(timestamps) > 0:
                # no consecutive timestamps but it has a timestamp; use the last one.
                # single timestamp at the end means no speech after the last timestamp.
                last_timestamp_position = timestamps[-1].item() - self.tokenizer.timestamp_begin
                duration = last_timestamp_position * self.time_precision

            self.add_segment(
                start=timestamp_offset,
                end=timestamp_offset + duration,
                text_tokens=tokens,
                result=result,
            )

            self.seek += segment.shape[-1]
            self.all_tokens.extend(tokens.tolist())

        if not self.condition_on_previous_text or result.temperature > 0.5:
            # do not feed the prompt tokens if a high temperature was used
            self.prompt_reset_since = len(self.all_tokens)

        return

ほぼtranscribeのwhileループ内処理をそのまま移植しています。

もちろんコンストラクタに移動した変数は、メンバ変数としてアクセスするためにselfを付けて参照します。

内部関数

decode_with_fallbackもほぼそのままの移植です。

    def decode_with_fallback(self, segment: torch.Tensor) -> List[DecodingResult]:
        temperatures = [self.temperature] if isinstance(self.temperature, (int, float)) else self.temperature
        kwargs = {**self.decode_options}
        t = temperatures[0]
        if t == 0:
            best_of = kwargs.pop("best_of", None)
        else:
            best_of = kwargs.get("best_of", None)

        options = DecodingOptions(**kwargs, temperature=t)
        results = self.model.decode(segment, options)

        kwargs.pop("beam_size", None)  # no beam search for t > 0
        kwargs.pop("patience", None)  # no patience for t > 0
        kwargs["best_of"] = best_of  # enable best_of for t > 0
        for t in temperatures[1:]:
            needs_fallback = [
                self.compression_ratio_threshold is not None
                and result.compression_ratio > self.compression_ratio_threshold
                or self.logprob_threshold is not None
                and result.avg_logprob < self.logprob_threshold
                for result in results
            ]
            if any(needs_fallback):
                options = DecodingOptions(**kwargs, temperature=t)
                retries = self.model.decode(segment[needs_fallback], options)
                for retry_index, original_index in enumerate(np.nonzero(needs_fallback)[0]):
                    results[original_index] = retries[retry_index]

        return results

add_segmentもほぼそのままの移植です。

    def add_segment(self,
        *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
    ):
        text = self.tokenizer.decode([token for token in text_tokens if token < self.tokenizer.eot])
        if len(text.strip()) == 0:  # skip empty text output
            return

        self.all_segments.append(
            {
                "id": len(self.all_segments),
                "seek": self.seek,
                "start": start,
                "end": end,
                "text": text,
                "tokens": result.tokens,
                "temperature": result.temperature,
                "avg_logprob": result.avg_logprob,
                "compression_ratio": result.compression_ratio,
                "no_speech_prob": result.no_speech_prob,
            }
        )
        if self.verbose:
            print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
        
        return

finalized

すべてのデータをset_dataで与えたのち、あまった30秒以下のデータに対して処理を行うためのAPIです。

そのままデータをset_frameに与えれば、内部で0詰めをして処理をします。

    def finalized(self):
        frame = self.buffer

        # byteデータを2byteデータに詰めなおして最大値で割って-1~1に正規化する
        frame = np.frombuffer(frame, np.int16).flatten().astype(np.float32) / 32768.0

        # 内部で0詰めして30秒データとしてから処理
        self.set_frame(frame)

get_result

最後に認識した結果をすべて取得するAPIを実装しておきます。

    def get_result(self):
        return dict(
            text=self.tokenizer.decode(self.all_tokens), 
            segments=self.all_segments
        )

動作確認

動作確認用に、音声データをbytesデータとして読み込んでおきます。

  • コード
import ffmpeg

out, _ = (
    ffmpeg.input("output.wav", threads=0)
    .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=16000)
    .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
)
print(type(out))
print(len(out))
  • 出力
<class 'bytes'>
5760000

次にインスタンス化して、データを5秒分ずつset_dataしていきます。

  • コード
whisper_streaming = WhisperStreaming(verbose=True)

sender_buffer = out

while len(sender_buffer) > 0:
    set_data = sender_buffer[:SAMPLE_RATE*5*2]
    whisper_streaming.set_data(set_data)
    sender_buffer = sender_buffer[SAMPLE_RATE*5*2:]
whisper_streaming.finalized()
  • 出力

[00:00.000 --> 00:24.520] 画面共有
[00:24.520 --> 00:26.120] 画面見えてますでしょうか
[00:26.120 --> 00:30.120] こっちか
[00:30.120 --> 00:31.120] 見えてます
[00:31.120 --> 00:33.120] ありがとうございます
[00:33.120 --> 00:41.120] NNアーキテクチャ勉強会ということで
[00:41.120 --> 00:43.120] 物体検出編
[00:43.120 --> 00:48.120] 前回3月の16日だったかな
[00:48.120 --> 00:51.120] そこでやった続きということで
[00:51.120 --> 00:58.120] 次は物体検出の方やっていきます
[00:58.120 --> 01:03.120] スレッドとかあるのかな
[01:03.120 --> 01:05.120] ないかな
[01:05.120 --> 01:07.120] 作ります
[01:07.120 --> 01:15.120] ありがとうございます
[01:15.120 --> 01:16.120] はじめに
[01:16.120 --> 01:18.120] はじめにがすごい余談なんで
[01:18.120 --> 01:21.120] 飛ばしましょうかね
[01:21.120 --> 01:23.120] いろいろしてますけど
[01:23.120 --> 01:26.120] 16時間ものを食べないっていうのを
[01:26.120 --> 01:29.120] 最近やってますっていう話です
[01:29.120 --> 01:35.120] 興味あったら見ていてください
[01:35.120 --> 01:38.120] ちょっと目次の方は
[01:38.120 --> 01:39.120] 前回から日が空いてるので
[01:39.120 --> 01:44.120] 前回のおさらいの方をまずやっていきます
[01:44.120 --> 01:49.120] で、物体検出の仕組みを説明して
[01:49.120 --> 01:51.120] まあそうですね仕組みを説明して
[01:51.120 --> 01:53.120] その中で
[01:53.120 --> 01:54.120] ヨロっていう
[01:54.120 --> 01:59.120] 物体検出の有名なアーキテクチャがあるんですけれども
[01:59.120 --> 02:02.120] そちらの方がどういう
[02:02.120 --> 02:03.120] どういうシリーズがあるか
[02:03.120 --> 02:06.120] すごい種類がたくさんあるので
[02:06.120 --> 02:09.120] そこらへんを説明しながら
[02:09.120 --> 02:13.120] そこまでに説明してる仕組みと
[02:13.120 --> 02:15.120] どういうふうな位置づけかってのを
[02:15.120 --> 02:19.120] 説明しようと思ってます
[02:19.120 --> 02:22.120] あと実際物体検出を使うときのフレームワークとかの話も
[02:22.120 --> 02:33.120] 最後の方に見れようかなと思っています
[02:33.120 --> 02:36.120] じゃあ前回のおさらいです
[02:36.120 --> 02:38.120] まずニューラルネットワークのところを
[02:38.120 --> 02:43.120] 1枚のスライドにまとめてますけれども
[02:43.120 --> 02:47.120] 線形結合をたくさんするやつが
[02:47.120 --> 02:50.120] 多層に重なったものということで
[02:50.120 --> 02:52.120] そして前回は
[02:52.120 --> 02:55.120] 天気、気温と湿度で天気を予測する
[02:55.120 --> 02:57.120] みたいな例を挙げたんですけど

ほぼ同じ結果がログとして得られています。

少々違う部分があるのは、音響特徴量をバッチ処理しているか、30秒単位で処理しているかの影響のためと考えられます。

結果は以下のようにまとめて取得することもできます。

  • コード
whisper_streaming.get_result()["text"]
  • 出力

画面共有画面見えてますでしょうかこっちか見えてますありがとうございますNNアーキテクチャ勉強会ということで物体検出編前回3月の16日だったかなそこでやった続きということで次は物体検出の方やっていきますスレッドとかあるのかなないかな作りますありがとうございますはじめにはじめにがすごい余談なんで飛ばしましょうかねいろいろしてますけど16時間ものを食べないっていうのを最近やってますっていう話です興味あったら見ていてくださいちょっと目次の方は前回から日が空いてるので前回のおさらいの方をまずやっていきますで、物体検出の仕組みを説明してまあそうですね仕組みを説明してその中でヨロっていう物体検出の有名なアーキテクチャがあるんですけれどもそちらの方がどういうどういうシリーズがあるかすごい種類がたくさんあるのでそこらへんを説明しながらそこまでに説明してる仕組みとどういうふうな位置づけかってのを説明しようと思ってますあと実際物体検出を使うときのフレームワークとかの話も最後の方に見れようかなと思っていますじゃあ前回のおさらいですまずニューラルネットワークのところを1枚のスライドにまとめてますけれども線形結合をたくさんするやつが多層に重なったものということでそして前回は天気、気温と湿度で天気を予測するみたいな例を挙げたんですけど

まとめ

いかがでしたでしょうか。

Whisperはバッチ推論のみに対応していましたが、ストリーミング対応できるようになったことで、結果をリアルタイムで返すような構成も可能となりました。 これで更に適用の幅が広がったと思います。

本記事がwhisperを活用してみようと思われる方の参考になれば幸いです。

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.