機械学習を解釈するって何だろう?そんなことを考えながらSHAPに想いを馳せてみる
おはこんハロチャオ~!なにもんなんじゃ?じょんすみすです。
当エントリー『クラスメソッド 機械学習チーム アドベントカレンダー 2022』は25日目となります。
アドベントカレンダーもついに最終日となり、クリスマスや年末に想いを馳せている人も多いかと思いますが、 私はどうすればAIの気持ちを理解できるかと想いを馳せておりました。
彼らの気持ちを理解するために今日はSHAP(SHapley Additive exPlanations)のことを考えていきたいと思います。
SHAPが何をしているのかについて
人もAIも、多くのことをそのまま覚えておくのはあまり好きでありません。 何もかもすべてをそのまま覚えておいて使うのではなく、全体的な傾向としての特徴を把握しておきます(え?k-Nearest Neighborの話はしてないよ)。
考え方
人が平均値でものごとを見て"普通"の基準とするのと同様、AIだって同じことを考えたいです。 そうです、平均的な特徴量から生まれる平均的な結果...そう、AIが出す結果の期待値から特徴量がどう変われば結果もどう変わるのか、からその特徴量が結果にどの程度影響を与えたのかを考えるのがSHAPです。
例として、HP(H), 攻撃(A), 防御(B), 特攻(C), 特防(D), 素早さ(S)の6つの特徴量があるとしてそれぞれの値からそのポケモンをパーティメンバーとして採用するかを考えてみたいと思います。
ある機械学習システムではこの6つの特徴量を与えると、そのポケモンを採用すべき確率が0~1の範囲で出力されます。 これに対して、なぜその判断をしたのかを考えてみたいと思います。
まず、ベースラインとして全てのポケモンの採用率を0.5(50%)とします。 これに対してあるポケモンが採用する確率が30%となった場合にその特徴として以下のような状況だったとします。
- Aがとても高いの
- Cが高め
- Sが低い
としたときに、攻撃特化の速攻アタッカーにしたいときにはCは低くていいのでSを高くしてほしかったという要望が生まれるかもしれません。そのため
- Aの高さで0.3上昇
- Cの高さで0.1下降
- Sの低さで0.4下降
のような、 0.5 + 0.3 - 0.1 - 0.4 = 0.3(30%)
と各特徴がどの程度その結果に影響を与えたのかを考えます。
同様にほかの80%採用されるポケモンがいた場合、
- Hが高いので0.1上昇
- Bが高いので0.1上昇
- Dが高いので0.1上昇
0.5 + 0.1 + 0.1 + 0.1 = 0.8(80%)
のように考えます。
このようにそのポケモンを採用するかをポケモンの種類ごと(=予測したいターゲットごと)に、その推論結果に対して、ベースラインからその特徴量が与えた影響の内訳を考えるのがSHAPの基本的な考え方となります。 こうすることで、AIが推論した結果に対しての「なぜ」その結果を出したかの説明が可能になるという訳です。
分解の仕方に必要なシャープレイ値
各特徴がどれだけ影響を与えているのかに分解する方法を考えていきます。 SHAPではその名称の元にもなっているShapleyというものを使います。
これは、協力ゲーム理論におけるシャープレイ値というものをベースにしたものです。 "協力"の名の通り、複数のプレイヤー同士が協力することで得られる結果に影響があるような状況を想定します。
例題で見てみましょう。 今、Aさん, Bさん, Cさんの3人である仕事をするのに以下のような状況だとします。
担当者 | 出来上がる進捗 |
---|---|
Aさん | 5 |
Bさん | 3 |
Cさん | 2 |
Aさん・Bさん | 10 |
Aさん・Cさん | 8 |
Bさん・Cさん | 7 |
Aさん・Bさん・Cさん | 15 |
複数人で協力した際に、個々の足し算よりも大きな成果が出ていることが分かります。 この時、複数人で協力した際の各人の貢献度はどのようになるでしょうか? この計算をするために限界貢献度という概念を使います。 仕事をするに際して、Aさんが新たに加わった場合にどの程度進捗が増えるを全てのパターンで考えます。
- Aさんのみで仕事をする
- 進捗が0 → 5に増える
- 5 - 0 = 5がAさんの限界貢献度
- Bさんがいる状態でAさんも仕事をする
- 進捗が3 → 10に増える
- 10 - 3 = 7がAさんの限界貢献度
- Cさんがいる状態でAさんも仕事をする
- 進捗が2 → 8に増える
- 8 - 2 = 6がAさんの限界貢献度
- Bさん・Cさんがいる状態でAさんも仕事をする : 進捗が7→15に増える
- 15 - 7 = 8がAさんの限界貢献度
これをBさん・Cさんに対しても同様に計算することができます。
また、これはAさん作業に加わる順番に依存して限界貢献度が変わるということを意味してみます。 3人の場合、それぞれがどの順番で関わるかは3! = 6通りのパターンに対して各人の限界貢献度が計算できます。
参加準 | Aさんの限界貢献度 | Bさんの限界貢献度 | Cさんの限界貢献度 |
---|---|---|---|
A → B → C | 5 | 5 | 5 |
A → C → B | 5 | 5 | 3 |
B → A → C | 7 | 3 | 5 |
B → C → A | 8 | 3 | 4 |
C → A → B | 6 | 7 | 2 |
C → B → A | 7 | 5 | 2 |
この6パターンそれぞれの各人の限界貢献度の平均を取ることで求められるのがシャープレイ値です。 Aさんの場合 (5 + 5 + 7 + 8 + 6 + 7) / 6 ≒ 6.33 となります。
これを一般化して、 n
人参加における i
さんのシャープレイ値 φ
は以下のような式で表されます。
ここで N\{i}
は空集合を含むプレーヤーの集合の組み合わせ、 S
はその中の個々の集合です。
|S|
は集合 S
に含まれる要素数を表しています。
ν(・)
は報酬を表す関数となり、 プレーヤー i
を含む ν(S ∪ {i})
から含まない ν(S)
の値を引いたものが限界貢献度を表しています。
|S|!(n - |S| -1)!
は組み合わせの数を表しており、先ほどのAさんの例であれば、 S ={B, C}
に対してBさんが先のパターンとCさんが先のパターンの2種類となります。
特徴量とシャープレイ値
SHAPでは特徴量の貢献度を計算するのに、各特徴量をプレーヤーとしたシャープレイ値を計算すればいいだろうという発想になります。 シャープレイ値を計算することで、ある特徴量が加わったことにより推論結果の変化がどの程度あったかとしています。
この計算で良く利用されるパターンとして、各特徴量の期待値をあらかじめ計算しておきます。 その値を使って確率の周辺化を行い、特徴量x_iのみを推論対象として与えられたものに変化させたときの推論結果の変化を見ます。 これをx_1, x_2, ..., x_nと同様にやっていくことで、各特徴量の変化に対する推論結果の変化が計算できるわけです。
なお、期待値を入れたものを変化させる順をx_1 → x_2とした場合と、x_2 → x_1とした場合とでx_1を与えた時の変化量に差が生じます。 そのため、各特徴量に対して全ての順番で与えた際の平均値を取った値をSHAP値として利用します。
それを全ての特徴量に対して行うと組み合わせの数だけ計算が必要となるため、 一般的なライブラリでは近似的な計算を行う場合もあります。
Pythonでの実装
SHAPの仕組みが分かったところで具体的な実装をしてみましょう。
と、言いたいところですが、Pythonでは shap
パッケージ、Rでは DALEX
パッケージを使うなどすれば利用できます。
今回は、Pythonの shap
パッケージを使って動かしてみます。
まずは、対象となるモデルの作成を行います。
import shap from sklearn.datasets import fetch_california_housing from sklearn.ensemble import RandomForestRegressor # カリフォルニアの住宅価格予測用データセットを取得 housing = fetch_california_housing(as_frame=True) # Random Forestのモデルを作成 # 簡単のため、ハイパーパラメータの設定とデータの分割は省略している reg = RandomForestRegressor() reg.fit(housing['data'], housing['target'])
次に、決定木ベースの手法用のモデルのSHAP値を計算するTreeExplainerを使って計算する。
explainer = shap.TreeExplainer(model=reg, data=housing['data']) # データ量が多いと計算に時間がかかるため、100件に絞って算出 shap_values = explainer(housing['data'].head(n=100))
これでSHAP値が求められました。 いくつかの情報をグラフ化してみましょう。
まずは、単一のデータに対する特徴量の貢献度を見てみます。
shap.plots.waterfall(shap_values[0])
これは単一の推論対象に対する貢献度なので、 shape_values
の別な添え字のデータに対して行えば同様に他の予測結果に対する貢献度が確認できます。
複数のデータに対して貢献度を取得している場合、それらの統計的な性質を求めることもできます。 まずは、各データの貢献度の絶対値の平均を取って変数重要度を見ていましょう(貢献度はプラス, マイナスの双方があるので打ち消し合わないように絶対値を取っています)。
shap.plots.bar(shap_values=shap_values)
また、平均値として取得するのではなく、分布をみることでどちらの方向に貢献しているケースが多いのかも合わせて確認できます。
shap.plots.beeswarm(shap_values)
このように shap
ライブラリを利用すると比較的容易にSHAP値の取得やその結果に基づくプロットが可能です。
おわりに
今回は、AIの気持ちになってみようと思ってExplainable AIを実現するための手法の一つであるSHAPを見てみました。
SHAPは様々なモデルで利用可能な手法です。 そのため、推論結果の理由の説明が必要になった際には利用しやすいですね。 一方、組み合わせの計算が必要となるため非常に重い処理であるということも注意点として覚えておくといいでしょう。