[中級編]LLMへ至る道~LSTMってやつがいるらしい~[10日目]

2023.12.10

みなさんこんにちは!クルトンです!

前日のブログは、RNNについてご紹介してきました。文脈を考慮できるモデルが登場した訳ですが、課題点があるという内容でした。

本日紹介するモデルはLong Short-Term Memory(以下ではLSTMと呼称)というものです。RNNでは長期間の依存関係に関する情報を保てないという課題があり、そちらを克服したモデルになります。

それでは、早速確認していきましょう!

LSTMについてイメージで掴む!

キーワードはRNN、1つのメモリセル、3つのゲートです。まずは下の図をご覧ください。

road-to-llm-advent-calendar-2023-10-01

順方向のLSTMです。RNNと同じく逆方向もありますが、順方向での処理内容をイメージできるようになればOKです。

RNNの図と見比べていただくと違いが分かりやすいかもしれません。LSTMでは出力と出力の間に、長期記憶を保つための工夫が使われています。まずは3つのゲートと1つのメモリセルなど、出力と出力の間にあるものについて、簡単に整理してみます。

まずは3つのゲートの共通項についてです。

  • 3つのゲートについて
    • どの情報を流すかどうかを選別
      • どの情報を流すか判断する時の「対象」が異なる
    • どのゲートもシグモイド関数を使っており、どの情報を流すべきかどうかは0から1の間の数値で判断する
  • 入力ゲート(Input Gate)
    • 現在の時点で見ている情報で、どれを次の処理へ流すかどうかを考える
  • 忘却ゲート(Forget Gate)
    • 過去の時点で見ている情報で、どれを次の処理へ流すかどうかを考える
    • 言い換えれば忘れて良い情報がどれかを選別している
  • 出力ゲート(Output Gate)
    • 未来の時点で見ている情報で、どれを次の処理へ流すかどうかを考える
    • 後の処理でも使う、覚えておいた方が良い情報を選別する
  • メモリセル
    • 長期記憶の役割で、今までの情報の入力とゲートの処理結果を保持
  • tanh(ハイパボリックタンジェント)
    • 入力された値を-1から1の範囲の値に変換する活性化関数
    • 3日目の活性化関数に関するブログで紹介したシグモイド関数が0から1なので、それとは値の範囲が異なるものだという理解でひとまずOK

次にメモリセルcの計算式についてです。

メモリセルc i = 入力ゲート × RNNでの出力y i + 忘却ゲート × メモリセルc i-1

上記の式では「LSTMによって一つ前の時刻で出力されたもの」と「入力された値(上図だとxの部分)」を使ってRNNで出力したものをyとしています。 yを使って、現時点で重要そうな情報がどれかを判定しているのが入力ゲートです。

次に、忘却ゲートとメモリセルを使っているところでは、過去に入力された情報で必要そうな情報がどれかを選別 しています。

つまり、現在入力されている値(新しい情報)と過去の記憶を元に、これまでの情報で重要なものがどれかというのを選別した履歴を持っているのがメモリセルです。 重要かどうかの選別には内部にシグモイド関数を持っている入力ゲートと忘却ゲートを使っています。

次にLSTMの出力についての式です。メモリセルの情報から出力ゲートを使って、次の出力時にどの情報を保持しておくと良いか選別しています。

LSTMの出力h i = 出力ゲート × tanh ( c i )

ここまでの内容をまとめると、長期保存したい情報をメモリセルで保存し、メモリセルの情報を使って3つのゲートではどの情報を覚えておくべきか、情報を伝えたり伝えないようにしたりと動的に調整をしています。

LSTMではRNNの直近の入力だけじゃない、これまでの入力を考慮して、出力をするというステップを踏んでいるという事です。

終わりに

本日のブログでは、LSTMについてご紹介してきました。イメージを元に、文脈考慮をしているモデルなんだなぁと理解していただければ幸いです。

ここまででWord2Vecから始まり、文脈を考慮したモデルについてご紹介してきました。

明日はLSTMを複数使ったモデルであるELMoについてご紹介いたします!

本日はここまで。それでは、明日もよければご覧ください!