機械学習 基礎の基礎 – 誤差逆伝搬法 –
Introduction
たまに思い出したように書く機械学習基礎の記事です。
以前の記事では勾配降下法について説明しました。
今回は、ニューラルネットワークの学習に欠かせない「誤差逆伝播法(バックプロパゲーション)」について解説します。
Back Propagation?
誤差逆伝播法は、ニューラルネットワークの学習において、出力と目標値の誤差を入力側に向かって逆方向に伝播させ、各層の重み(パラメータ)を効率的に調整する手法です。
以前説明した勾配降下法では、パラメータを更新するために「勾配(偏微分)」が必要でした。誤差逆伝播法は、この勾配を効率的に計算するための手法です。
処理フローは以下のような感じです。
- まず順方向に計算を行い、出力を得る(順伝播)
- 出力と目標値の誤差を計算する
- その誤差を逆方向に伝播させながら、各パラメータの影響度(偏微分)を計算する(逆伝播)
- 計算した偏微分を使って勾配降下法でパラメータを更新する
DQ3 Damage Calculation Formula
たまたま見かけたので、
FC版のDQ3ダメージ計算を例に解説していきます。
以下がダメージ計算式です。
ダメージ = { 攻撃力 - (敵の守備力 / 2) } × { 99 + <0〜54の乱数> } / 256
この計算式は以下の要素で構成されています。
- 攻撃力:攻撃側のパラメータ
- 敵の守備力:防御側のパラメータ
- 乱数:0〜54の範囲でダメージに変動を与える値
ここでの乱数は実際のゲーム内でダメージに変動を与えるものです。
本記事では固定値として扱います。
Calculation Examples
まず、順伝播(通常の計算)を行ってみましょう。
パラメータは以下で設定します。
- 攻撃力 = 150
- 敵の守備力 = 100
- 乱数 = 30(ここでは固定値)
順伝播の計算過程
Step 1: 守備力による減算
基本ダメージ = 150 - (100 / 2) = 150 - 50 = 100
Step 2: 倍率の計算
倍率 = (99 + 30) / 256 = 129 / 256 ≈ 0.5039
Step 3: 最終ダメージ
ダメージ = 100 × 0.5039 ≈ 50.39
なので、この条件だと「たたかう」を選択したときに
敵に50ダメージを与えることができます。
Application of Back Propagation
ここで、目標とするダメージが60だったとしましょう。
実際のダメージとの差(誤差)を使って、
攻撃力と守備力をどう調整すべきかを計算します。
1. 誤差の計算
誤差 = 実際のダメージ - 目標ダメージ
= 50.39 - 60
= -9.61
この誤差が負の値なので、ダメージを増やす必要があります。
2. 各パラメータの影響度を計算(偏微分)
次に、攻撃力と守備力がダメージにどれだけ影響するかを偏微分で計算します。
攻撃力に対する偏微分
攻撃力を1増やしたときのダメージの変化率:
∂ダメージ/∂攻撃力 = (99 + 乱数) / 256
= 129 / 256
≈ 0.5039
つまり、攻撃力が1増えると、ダメージは約0.5039増加します。
守備力に対する偏微分
守備力を1増やしたときのダメージの変化率:
∂ダメージ/∂守備力 = -0.5 × (99 + 乱数) / 256
= -0.5 × 129 / 256
≈ -0.2520
守備力が1増えると、ダメージは約0.2520減少します(マイナスなので減少)。
3. パラメータの更新
誤差と偏微分を使って、各パラメータをどれだけ調整すべきかを計算します:
攻撃力の更新量 = 学習率 × (-誤差) × ∂ダメージ/∂攻撃力
= 0.1 × 9.61 × 0.5039
≈ 0.484
新しい攻撃力 = 150 + 0.484 = 150.484
守備力の更新量 = 学習率 × (-誤差) × ∂ダメージ/∂守備力
= 0.1 × 9.61 × (-0.2520)
≈ -0.242
新しい守備力 = 100 - 0.242 = 99.758
※学習率(ここでは0.1)は、一度にどれだけパラメータを更新するかを決める係数
ちなみに、ダメージ計算を計算グラフで表現するとこんな感じです。
[攻撃力] ─┐
├─→ [基本ダメージ] ─┐
[守備力] ─┘ ├─→ [最終ダメージ]
│
[乱数] ────→ [倍率] ─────────┘
順伝播では左から右へ計算が進みますが、誤差逆伝播では右から左へ誤差と偏微分の情報が伝わっていきます。
Chain Rule
誤差逆伝播法の核心は「連鎖律(Chain Rule)」にあります。
連鎖律とは、合成関数の微分を計算する法則で、複雑な計算を段階的に分解して微分を求めることができます。
連鎖律の必要性
ダメージ計算を見てみると以下のような多段階の計算をしています。
- 守備力から基本ダメージを計算
- 基本ダメージと倍率から最終ダメージを計算
このようになっているため、守備力が最終ダメージに
どう影響するかを直接計算するのは面倒です。
しかし、連鎖律を使えば各段階の微分を掛け合わせることで計算できます。
具体的な計算例
ダメージ計算の流れを改めて整理すると、以下のような多段階の計算になっています。
敵の守備力(100) → 基本ダメージ(100) → 最終ダメージ(50.39)
この計算の流れを見ていきましょう。
Step 1: 基本ダメージの計算
まず、攻撃力から守備力の影響を差し引いて基本ダメージを求めます。
敵の守備力の1/2が防御効果として機能します。
基本ダメージ = 攻撃力 - (守備力 / 2)
= 150 - (100 / 2)
= 100
この段階で、守備力100は実質的に50の防御効果を発揮し、
攻撃力150から50を引いた100が基本ダメージとなります。
Step 2: 最終ダメージの計算
次に、基本ダメージに倍率を掛けて最終ダメージを計算します。
倍率は乱数要素を含む式で決まりますが、今回は固定値として扱います。
最終ダメージ = 基本ダメージ × 倍率
= 100 × 0.5039
= 50.39
倍率0.5039を掛けることで、基本ダメージ100が約半分の50.39になります。
この2段階の計算により、守備力が最終ダメージに
間接的に影響を与えていることがわかります。
連鎖律による偏微分の計算
先ほど見たように、守備力から最終ダメージまでは2段階の計算をしています。
連鎖律を使うと、この2段階の偏微分を掛け合わせることで、
守備力が最終ダメージに与える影響を求めることができます。
∂最終ダメージ/∂守備力 = ∂最終ダメージ/∂基本ダメージ × ∂基本ダメージ/∂守備力
この式は「最終ダメージの守備力に対する変化率」を、
「最終ダメージの基本ダメージに対する変化率」と
「基本ダメージの守備力に対する変化率」の積で表しています。
各部分の計算
各部分を個別に計算していきましょう。
- ∂最終ダメージ/∂基本ダメージ
これは、基本ダメージが最終ダメージに与える影響を示しています。
最終ダメージは基本ダメージに倍率を掛けた値なので、
基本ダメージが1増えると、最終ダメージは倍率分だけ増えます。
最終ダメージ = 基本ダメージ × 倍率
∂最終ダメージ/∂基本ダメージ = 倍率 = 0.5039
つまり、基本ダメージが1増えると最終ダメージは0.5039増えます。
- ∂基本ダメージ/∂守備力
次は、守備力が基本ダメージに与える影響です。
基本ダメージの計算式を見ると、守備力の半分が引かれているので、
守備力が1増えると基本ダメージは0.5減ります。
基本ダメージ = 攻撃力 - (守備力 / 2)
∂基本ダメージ/∂守備力 = -1/2 = -0.5
守備力が増えると基本ダメージが減るのでマイナスになります。
- **連鎖律の適用 **
ここで連鎖律を適用して最終的な影響の計算をします。
「守備力→基本ダメージ」の影響(-0.5)と
「基本ダメージ→最終ダメージ」の影響(0.5039)を掛け合わせます。
∂最終ダメージ/∂守備力 = 0.5039 × (-0.5) = -0.2520
結果は-0.2520となり、守備力が1増えると最終ダメージは約0.25減ることがわかります。
これは先ほど直接計算した結果と一致します。
攻撃力の場合も同様
同じ方法で、攻撃力が最終ダメージに与える影響も計算できます。
∂最終ダメージ/∂攻撃力 = ∂最終ダメージ/∂基本ダメージ × ∂基本ダメージ/∂攻撃力
= 0.5039 × 1
= 0.5039
攻撃力の場合、基本ダメージへの影響は1(攻撃力が1増えれば基本ダメージも1増える)なので、
最終的な影響は倍率と同じ0.5039になります。
つまり、攻撃力が1増えると最終ダメージは約0.5増えます。
連鎖律は有用
この例では計算が単純ですが、ニューラルネットワークのように
何層にも重なった計算では、連鎖律なしに偏微分を求めることはかなりコストが高くなります。
連鎖律により、どんなに深い層であっても、各層の偏微分を順番に掛けていくだけで、最終的な偏微分を求めることができます。
誤差逆伝播の実装イメージ
実際のプログラムでは、以下のような流れで実装されます。
# 順伝播
def forward(攻撃力, 守備力, 乱数):
基本ダメージ = 攻撃力 - (守備力 / 2)
倍率 = (99 + 乱数) / 256
ダメージ = 基本ダメージ * 倍率
return ダメージ
# 逆伝播
def backward(誤差, 攻撃力, 守備力, 乱数):
# 各パラメータの偏微分を計算
d攻撃力 = (99 + 乱数) / 256
d守備力 = -0.5 * (99 + 乱数) / 256
# 更新量を計算
攻撃力_更新量 = 学習率 * (-誤差) * d攻撃力
守備力_更新量 = 学習率 * (-誤差) * d守備力
return 攻撃力_更新量, 守備力_更新量
backwardでは連鎖律を使って偏微分を事前に計算済みです。
ニューラルネットワークへの応用
ちなみに、このダメージ計算は、ニューラルネットワークの1層分の計算と似ているらしいです。(Claudeさんが教えてくれました)
- 入力:攻撃力、守備力(ニューラルネットワークでは入力データ)
- 重み:計算式の係数(ニューラルネットワークでは学習可能な重み)
- 出力:ダメージ(ニューラルネットワークでは次の層への入力)
ニューラルネットワークでは、このような計算を何層にも重ねて、
より複雑な関数を表現します。
そして各層で同じように誤差逆伝播を適用することで、
全ての重みを効率的に更新しています。
Summary
今回はDQ3のダメージ計算を例にして
誤差逆伝播法を解説しました。
ポイントは以下です。
- 順伝播を理解 :まず通常の計算の流れを把握する
- 偏微分を理解:各パラメータが結果にどう影響するか
- 連鎖律を活用:複雑な計算も段階的に分解できる
- 更新の方向を理解:誤差を減らすためにパラメータをどう動かすか
誤差逆伝播法は一見難しく見えますが、
「出力の誤差を計算し、その誤差に各パラメータがどれだけ寄与しているかを偏微分で求め、
パラメータを誤差が減る方向に更新する」
というシンプルな処理です。
ニューラルネットワークでは、この同じ原理を何層にも重ねて適用することで、
複雑なパターンを学習していきます。
誤差逆伝播法は、ニューラルネットワーク・ディープラーニングを理解する上で
重要な基礎技術なので、しっかり理解したいところです。