[初級編]LLMへ至る道~活性化関数ってなにをしているの?~[3日目]

2023.12.03

みなさんこんにちは!クルトンです。

前回のブログは損失関数についてでした。2乗和誤差や交差エントロピー誤差について、簡単にですがご紹介していきました。

今回は活性化関数について説明していきます!活性化関数ってなんだ?ってところなのですが、機械学習のモデルでよく出てくる用語ですので、しっかり説明していきます!

では早速、活性化関数について見ていきましょう。

活性化関数ってなんだ?その説明とイメージ

活性化関数とは、一言で言うと、機械学習の表現力を高めるための関数となります。 活性化関数でどのように表現力を高めるのかについては、実際の関数について見ていただければ理解できるかと思います。

ですが先に 「表現力を高める」 という部分について説明をしていきます。

表現力を高めるってなんだろう

表現力が高いというと、みなさんはどのような事を想像するでしょうか。 クリエイティブな人を思い浮かべたりする人もいるかもしれないですね。

今回ご説明する表現力というのは、コンピュータでの出力の結果に関するものです。

例えば、y = ax + b (a≠0)という一次関数があります。この関数が表すグラフが想像できるでしょうか?答えは簡単で、aという傾きを持つまっすぐな線になります。

road-to-llm-advent-calendar-2023-03-01

ここで以下の問題を考えてみましょう。

下図のように、赤と青の点が描かれた散布図が存在する。
出来るだけ、青のグループと赤のグループで分離するように線を引きたい。
どのようにして引けば良いか?

road-to-llm-advent-calendar-2023-03-02

人が描いて良いのであれば、波打つ線を引くだけで終わりそうですが、コンピュータではそうはいきません。適切な関数を用意して線を引く必要があります。ただし一次関数のようなまっすぐな線を用意しても上手く線を引けそうにもないです。

ここで上記のような問題で活躍するのが活性化関数です。

road-to-llm-advent-calendar-2023-03-03

活性化関数とは上記の問題で言うところの「適切な波線を引く」という表現力を手にいれるために使われているものです。より正確には、複雑な関数を作るために使われる関数です。 ここでは「機械が柔軟な表現力(理解力)を手にいれるために使われているものなんだなぁ」という事を押さえておいてもらえればOKです!

活性化関数はさまざまなものがあり、それぞれを適切に使い分ける必要がありますが、代表的なものを確認していきます。

活性化関数の具体例

以下のような活性化関数があります。

  • シグモイド関数
  • ソフトマックス関数(softmax function)
  • ReLU(正規化線形ユニット、Rectified Linear Unit)
  • GELU(ガウス誤差線形ユニット、Gaussian Error Linear Unit)

以降ではそれぞれの活性化関数について、数式から描かれるグラフをもとに「どういう値の時にどういう値を返すものなのか」というイメージを掴んでみましょう。グラフはPythonコードで出力します。(全てGoogle ColabのランタイムCPUで実行していきます。)

では順番に確認していきましょう。

シグモイド関数

まず1つ目はシグモイド関数と呼ばれるものです。以下の数式で表されるものです。

σ(x) = 1 1 + exp ( - x )

ここでexpと書かれているのはネイピア数eで、exp(-x)ではeの-x乗という意味になります。数式の定義からPythonコードを使って以下のようなコードを書きました。

import numpy as np
import matplotlib.pyplot as plt

# 数式の定義通りのシグモイド関数を定義
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

x_values = np.linspace(-10, 10, 200) # -10から10の間で等間隔に200個の値を作る
print(f"xの値:{x_values}")

y_values = sigmoid(x_values)

# グラフ描画
plt.plot(x_values, y_values, label='Sigmoid Function')
plt.title("Sigmoid Function Graph")
plt.xlabel('x value')
plt.ylabel('sigmoid(x)')
plt.legend()
plt.grid(True)
plt.show()

sigmoid関数に入れた値は以下のものです。

xの値:[-10.          -9.89949749  -9.79899497  -9.69849246  -9.59798995
  -9.49748744  -9.39698492  -9.29648241  -9.1959799   -9.09547739
  -8.99497487  -8.89447236  -8.79396985  -8.69346734  -8.59296482
  -8.49246231  -8.3919598   -8.29145729  -8.19095477  -8.09045226
  -7.98994975  -7.88944724  -7.78894472  -7.68844221  -7.5879397
  -7.48743719  -7.38693467  -7.28643216  -7.18592965  -7.08542714
  -6.98492462  -6.88442211  -6.7839196   -6.68341709  -6.58291457
  -6.48241206  -6.38190955  -6.28140704  -6.18090452  -6.08040201
  -5.9798995   -5.87939698  -5.77889447  -5.67839196  -5.57788945
  -5.47738693  -5.37688442  -5.27638191  -5.1758794   -5.07537688
  -4.97487437  -4.87437186  -4.77386935  -4.67336683  -4.57286432
  -4.47236181  -4.3718593   -4.27135678  -4.17085427  -4.07035176
  -3.96984925  -3.86934673  -3.76884422  -3.66834171  -3.5678392
  -3.46733668  -3.36683417  -3.26633166  -3.16582915  -3.06532663
  -2.96482412  -2.86432161  -2.7638191   -2.66331658  -2.56281407
  -2.46231156  -2.36180905  -2.26130653  -2.16080402  -2.06030151
  -1.95979899  -1.85929648  -1.75879397  -1.65829146  -1.55778894
  -1.45728643  -1.35678392  -1.25628141  -1.15577889  -1.05527638
  -0.95477387  -0.85427136  -0.75376884  -0.65326633  -0.55276382
  -0.45226131  -0.35175879  -0.25125628  -0.15075377  -0.05025126
   0.05025126   0.15075377   0.25125628   0.35175879   0.45226131
   0.55276382   0.65326633   0.75376884   0.85427136   0.95477387
   1.05527638   1.15577889   1.25628141   1.35678392   1.45728643
   1.55778894   1.65829146   1.75879397   1.85929648   1.95979899
   2.06030151   2.16080402   2.26130653   2.36180905   2.46231156
   2.56281407   2.66331658   2.7638191    2.86432161   2.96482412
   3.06532663   3.16582915   3.26633166   3.36683417   3.46733668
   3.5678392    3.66834171   3.76884422   3.86934673   3.96984925
   4.07035176   4.17085427   4.27135678   4.3718593    4.47236181
   4.57286432   4.67336683   4.77386935   4.87437186   4.97487437
   5.07537688   5.1758794    5.27638191   5.37688442   5.47738693
   5.57788945   5.67839196   5.77889447   5.87939698   5.9798995
   6.08040201   6.18090452   6.28140704   6.38190955   6.48241206
   6.58291457   6.68341709   6.7839196    6.88442211   6.98492462
   7.08542714   7.18592965   7.28643216   7.38693467   7.48743719
   7.5879397    7.68844221   7.78894472   7.88944724   7.98994975
   8.09045226   8.19095477   8.29145729   8.3919598    8.49246231
   8.59296482   8.69346734   8.79396985   8.89447236   8.99497487
   9.09547739   9.1959799    9.29648241   9.39698492   9.49748744
   9.59798995   9.69849246   9.79899497   9.89949749  10.        ]

上記の値をsigmoid関数に入れて描画したグラフは以下のようになります。

road-to-llm-advent-calendar-2023-03-04

滑らかな曲線が描かれている事が確認できます! 値を複数作った時にシグモイド関数を一度経由すると滑らかな曲線が作られるのが確認できました。

ソフトマックス関数(softmax function)

ソフトマックス関数の定義式は以下のようになります。

softmax ( x ) i = exp ( x i ) k = 1 n exp ( x k )

ソフトマックス関数ではΣが分母で使われており、Σで計算している要素の1つを分子に持ってきています。

これが何を意味するか分かりますか?とても重要な意味を持ちます。

正解は 「確率を表す」 です。

確率とは、要素の1つの値を、全ての要素の値を足したもので割ると求められますよね。例で言うと「くじ10本のうち当たりくじが2本あり、1度くじを引いた時に当たりが出る確率を求める」のような問題は当たりくじが出る2に対して全てのくじ本数である10を割る事で、0.2という確率が求められます。

ソフトマックス関数では入力された値を確率へ変換するという重要な機能があります。

確率なのでシグモイド関数で変換された後の全ての確率の値を足し算すると1になります。 なお、今後ご紹介する機械学習モデルでもソフトマックス関数は登場します。

ではグラフを確認していきましょう。

こちらもシグモイド関数に続いてexpと書かれている数式で、ネイピア数eを使ったグラフになります。

以下のPythonコードで確認しました。

import numpy as np
import matplotlib.pyplot as plt

def softmax(x):
    return np.exp(x) / np.sum(np.exp(x))

x_values = np.linspace(-10, 10, 200) # -10から10の間で等間隔に200個の値を作る
y_values = softmax(x_values)

# グラフ描画
plt.plot(x_values, y_values, label='Softmax Function')
plt.title('Softmax Function Graph')
plt.xlabel('x value')
plt.ylabel('softmax(x)')
plt.legend()
plt.grid(True)
plt.show()

入力に使われているxの値はシグモイド関数と同じですので確認は省略します。 実行すると以下のようなグラフが出力されます。

road-to-llm-advent-calendar-2023-03-05

シグモイド関数とは大きく異なりますが、同じ点があります。 それはグラフが滑らかな曲線である事です。

ソフトマックス関数については、以下の点を押さえておくと良いと思われます。

  • 出力される値を繋げると滑らかな曲線になる
  • ソフトマックス関数を使うと確率へ変換できる

ReLU(正規化線形ユニット、Rectified Linear Unit)

次はReLUについてです。定義式は以下のようになります。

relu ( x ) = max ( 0 , x )

かなり簡単な式ですね。max(0, x)という事なので出力は以下のどちらかになります。

  • 0より大きい値ならばそのまま出力する
  • 0以下の数値ならば0を出力する

以下のPythonコードでグラフを確認してみます。

import numpy as np
import matplotlib.pyplot as plt

def relu(x):
    return np.maximum(0, x)

x_values = np.linspace(-10, 10, 200) # -10から10の間で等間隔に200個の値を作る
y_values = relu(x_values)

# グラフ描画
plt.plot(x_values, y_values, label='ReLU Function')
plt.title('ReLU Function')
plt.xlabel('x')
plt.ylabel('ReLU(x)')
plt.legend()
plt.grid(True)
plt.show()

出力したグラフは以下のようになります。

road-to-llm-advent-calendar-2023-03-06

0より小さい値の場合は0を出力し、0より大きい値の場合はそのまま出力できていそうですね。

ReLUの特徴として、以下のようなものがあります。

  • 計算がシグモイド関数に比べると簡単なので処理が高速
  • 0より小さい値は一律0としてしまうので、大きい値を際立たせる

GELU(ガウス誤差線形ユニット、Gaussian Error Linear Unit)

次はGELUについてグラフを確認していきます。実はGELUはReLU関数のような見た目の関数であり、ReLUの発展系の活性化関数です。この活性化関数も今後ご紹介するモデルで使われているのでここで確認しておきましょう。

では早速ですが以下の数式をご覧ください。

gelu ( x ) = x Φ ( x )

見た目は随分とシンプルですね。xとΦ(x)を掛け算しているのがGELUです。 「Φ(x)とはなんだ?」となりそうですが、標準正規分布の累積分布関数です。

標準正規分布?累積分布関数?

標準正規分布とは、統計学において特別なグラフを表しています。どういう事かというと、標準正規分布は確率を求める際に標準として使われている分布なのです。以下のPythonコードでグラフを確認できます。

import numpy as np
import matplotlib.pyplot as plt

# 標準正規分布の確率密度関数(PDF)を計算
def standard_normal_distribution(x):
    return np.exp(-x**2 / 2) / np.sqrt(2 * np.pi)

x_values = np.linspace(-10, 10, 200) # -10から10の間で等間隔に200個の値を作る
y_values = standard_normal_distribution(x_values)

# グラフ描画
plt.plot(x_values, y_values, label='Standard Normal Distribution')
plt.title('Standard Normal Distribution Graph')
plt.xlabel('x value')
plt.ylabel('Probability Density')
plt.legend()
plt.grid(True)
plt.show()

road-to-llm-advent-calendar-2023-03-07

グラフの縦軸は確率密度です。確率密度は簡単にいうと、値が大きいほど起こりやすい事を表しています(つまりy軸の値が大きいほど確率が高いです)。このグラフは簡単にいうと、横軸の値(x value)が0に近ければ近いほど確率が高い、というものになります。

次に累積分布関数とは、グラフ上のある点までの面積がその点までの確率の合計を表したもののことです。したがって、全ての点を含む面積を求めると値は1になります。(全ての確率を足し合わせた事と同じためです。)

なお、面積を求めるには積分が必要なのですが、ここではイメージを掴みたいのでグラフにしてどういうものか簡単に確認してみましょう。

以下のPythonコードを実行すると、標準正規分布の累積分布関数のグラフが出力されます。

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

# 標準正規分布の累積分布関数を計算
def standard_normal_cdf(x):
    return norm.cdf(x)

x_values = np.linspace(-10, 10, 200) # -10から10の間で等間隔に200個の値を作る
y_values = standard_normal_cdf(x_values)

# グラフ描画
plt.plot(x_values, y_values, label='Standard Normal CDF')
plt.title('Standard Normal Cumulative Distribution Function Graph')
plt.xlabel('x value')
plt.ylabel('Cumulative Probability')
plt.legend()
plt.grid(True)
plt.show()

出力されるグラフは以下のようになります。滑らかな曲線ですね。

road-to-llm-advent-calendar-2023-03-08

この累積分布関数で出力される値を掛け算するとGELUの値が求まるという事です。

GELUのグラフを出力してみる

PythonコードでGELUのグラフを出力してみます。なお、GELUの数式が上で説明した数式と見た目が異なるのは、以下のGELUに関する論文をもとにしています。(上で説明した数式とほぼ同じ値を返してくれるものです。)

では、以下のPythonコードを使ってグラフを出力してみましょう。

import numpy as np
import matplotlib.pyplot as plt

def gelu(x):
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

x_values = np.linspace(-10, 10, 200) # -10から10の間で等間隔に200個の値を作る
y_values = gelu(x_values)


# グラフ描画
plt.plot(x_values, y_values, label='GELU Function')
plt.title('GELU Function Graph')
plt.xlabel('x value')
plt.ylabel('GELU(x)')
plt.legend()
plt.grid(True)
plt.show()

road-to-llm-advent-calendar-2023-03-09

見た目は確かにReLUのような形ですね。どう違うのかついでに確認してしまいましょう。

GELUと一緒にReLUのグラフも以下のPythonコードを使って出力してみました。

import numpy as np
import matplotlib.pyplot as plt

def gelu(x):
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def relu(x):
    return np.maximum(0, x)

x_values = np.linspace(-10, 10, 200) # -10から10の間で等間隔に200個の値を作る
y_values_gelu = gelu(x_values)
y_values_relu = relu(x_values)


# グラフ描画
plt.plot(x_values, y_values_gelu, label='GELU Function')
plt.plot(x_values, y_values_relu, label='ReLU Function')
plt.title('GELU Function Graph')
plt.xlabel('x value')
plt.ylabel('GELU(x) or ReLU(x)')
plt.legend()
plt.grid(True)
plt.show()

road-to-llm-advent-calendar-2023-03-10

ほとんど同じなのか、重なっていますね。ReLUに近い形なのでReLUの特徴も引き継いでいそうです。

終わりに

ここまで読んでいただきありがとうございました!

今回は活性化関数について確認をしていきました。複数の具体例もご紹介しましたが、あくまでも「学習させているモデルが柔軟な理解力を手に入れるために使っているものなんだなぁ」と理解していただければOKです。

明日はニューラルネットという機械学習モデルの中で使われている「構造」についてのブログを公開します。前日と本日紹介した、損失関数と活性化関数が登場してきます。

本日はここまで。よければ明日のブログもチェックお願いします!

参考文献