Confidence Weightedを実装してみる

こんにちは、小澤です。

当エントリは「Machine Learning Advent Calendar 2017」の23日目のエントリです。

今回は、オンライン学習の手法であるConfidence Weighted(以下CW)を実装させていただきます。

CWとは

CWはオンラインの線形分類器です。 その特徴は、weightを単一の値で扱うのではなく、正規分布にしたがった分布で表します。

\[ {\bf w} \sim \mathcal{N}({\bf \mu}, \Sigma) \]

このように表現することで、現在のweightの値にどれだけ信頼がおけるか(Confidence)を持たせることができます。 これがこの手法の名前の由来になっています。 分散が大きいw_iはそれだけ自信がなく、小さいものはほぼその値で間違いないだろう的な感じです。

これによってデータ中に頻繁に登場する(0でないことが多い)特徴ほど分散が低く、あまり出てこない特徴(0であるデータが多い)ほど分散が大きくなります。

CWの更新式

CWのパラメータの更新式は、新たに入ってきたデータを基にした現在の分布に近いものを求めることで更新します。 分布の近さはKLダイバージェンスを最小化することで求めます。

制約条件として、正しい予測を行えるようにするための制約条件を設けます。 制約は正解ラベルは+1または-1なので、予測値の絶対値が0以上である確率をハイパーパラメータη以上とします。

\[ ({\bf \mu}_{i+1}, \Sigma_{i+1}) = \min D_{KL}( \mathcal{N}({\bf \mu}, \Sigma) || \mathcal{N}({\bf \mu}_i, \Sigma_i) ) \\ st. P(y_i{\bf w}{\bf x}_i) \geq 0) \geq \eta \]

途中の計算式は省略しますが、この条件を満たすμとΣの更新式は以下のようになります。

\[ {\bf \mu}_{i+1} = {\bf \mu}_i + \alpha y_i \Sigma_i{\bf x}_i \\ \Sigma_{i+1} = \Sigma_i - \Sigma_i{\bf x}_i \frac{2\alpha \phi}{1+2\alpha \phi {\bf x}_i^T \Sigma_i {\bf x}_i} {\bf x}_i^T \Sigma_i \]

ここで各変数は以下のようになっています。

\[ \alpha = \max \{ 0, \frac{ -(1+2 \phi M_i) + \sqrt{ (1+2 \phi M_i)^2 - 8 \phi (M_i - \phi V_i) } }{ 4 \phi V_i } \} \\ M_i = y_i({\bf x}_i {\bf \mu}_i) \\ V_i = {\bf x}_i^T \Sigma_i {\bf x}_i \\ \phi = {\bf \Phi}^{-1}(\eta) ({\bf \Phi}は標準正規分布の累積密度関数) \]

この式は、CWの出展である「Confidence-Weighted Linear Classification」に記載されているものですが、 明日紹介予定のSCWの論文中に記載されている変形を行ったものもよく見かけるのでそちらの形式でも記載します。

M_iをm_i, V_iをυ_iとして、

\[ \alpha_i = \max \{ 0, \frac{1}{\upsilon_i \zeta}(-m_i \psi + \sqrt{m_t^2 \frac{\phi^4}{4}+\upsilon_i \phi^2 \ zeta}) \} \\ \beta_i = \frac{\alpha_i \phi}{\sqrt{u_i} + \upsilon_i \alpha_i \phi} \\ \psi = 1+\frac{\phi^2}{2} \\ \zeta = 1 + \phi^2 \\ u_i = \frac{1}{4}(-\alpha_i \upsilon_i \phi + \sqrt{\alpha_i^2 \upsilon_t^2 \phi^2 + 4 \upsilon_t})^2 \]

の形式で表して更新式を以下のように記載します。

\[ {\bf \mu}_{i+1} = {\bf \mu}_i - \alpha_i y_i \Sigma_i {\bf x}_i \\ \Sigma_{i+1} = \Sigma_i - \beta_i \Sigma_i {\bf x}_i^T {\bf x}_i \Sigma_i \]

実装

さて、実装してみましょう。 今回もおなじみirisの一部を削って二値分類にしたものを利用します。 また、φの計算にはscipyを利用します。

まず必要なものをインポートします。 pandasは最終結果を見るためだけに使っています。

import numpy as np
import pandas as pd

from scipy.stats import norm
from sklearn.datasets import load_iris
from numpy.random import random

次にデータセットをロードします。

iris = load_iris()

# 二値分類をするので1つのラベル分取り除く
iris_data = iris.data[50:]
# ラベルを-1と+1にする
iris_target = iris.target[50:] *2 - 3

# オンライン学習はデータの順番に依存するのでシャッフルする
index = np.arange(len(iris_data))
np.random.shuffle(index)

# 学習データとテストデータに分割
iris_train_data =  iris_data[index[:80]]
iris_train_target = iris_target[index[:80]]

iris_test_data = iris_data[index[80:]]
iris_test_target = iris_target[index[80:]]

その後ハイパーパラメータのηと、そこから決まる定数を計算しておきます。 μとΣの初期値は、論文中に記載のある

\[ {\bf \mu}_1 = {\bf 0} \\ \Sigma_1 = aI \\ a > 0 \]

として、aの値は0-1の乱数で生成した値を使った対角行列を作成しています。

# ハイパーパラメータ
eta = 0.8

# 定数
phi = norm.cdf(eta) ** -1
psi = 1 + (phi ** 2 / 2)
zeta = 1 + phi ** 2

# 初期化
# μ_1 = 0, Σ_1 = aI (a > 0)より
mu = np.zeros(iris_train_data.shape[1])
sigma = np.diag([random()]*iris_train_data.shape[1])

いよいよもって、学習を開始します。 今回は学習に利用するデータが80件と少なかったため、3周くらい回すことで補っています。

# 学習
# データ件数が少なくて収束しなかったので3周くらいする
for _ in range(3):
    for x, y in zip(iris_train_data, iris_train_target):
        upsilon = np.dot(x, np.dot(sigma, x))
        m = y * np.dot(mu, x)

        gamma = 1 / (upsilon * zeta) * (-m * psi + np.sqrt(m**2 * (phi**4 / 4) + upsilon * phi**2 * zeta))
        alpha = max(0, gamma)

        u = (1 / 4) * (-alpha * upsilon * phi + np.sqrt(alpha**2 * upsilon**2 * phi**2 + 4 * upsilon))**2
        beta = alpha * phi / (np.sqrt(u) + upsilon * alpha * phi)

        mu = mu + alpha * y * np.dot(sigma, x)
        sigma = sigma - beta * np.dot(np.dot(sigma, x), np.dot(mu, sigma))

最後に予測を行ってみましょう。

# 予測値が0以上であれば+1, 未満であれば-1のラベル付けを行う
predict = [(np.dot(x, mu), 1 if np.dot(x, mu) >= 0 else -1, y) for x, y in zip(iris_test_data, iris_test_target)]
pd.DataFrame(predict, columns=['predict_value', 'predict_label', 'target'])

結果は以下のようになりました。

本来であれば、ここから評価を行いますが、ぱっと見学習してるっぽいのでこの記事の目的である実装すること自体はは達成したことにします。

おわりに

今回は、オンライン学習の手法であるCWを実装しました。

数式が多く出てきて怖い思いをした方もいるかもしれませんが、省略した途中の導出部分を除けばほぼ書いてある通りに実装するだけとなります。

明日は、CWの発展である『SCW』の予定です。 お楽しみに!

参考文献