causalmlでT-Learnerを試してみた

2020.04.14

最初に

「T-Learner」への理解を深めるために、causalmlで用意されているサンプルを試しつつ、中身を確認しました。
「T-Learner」とはなんぞや、という方はこちらもご参照ください。

今回参考にしたサンプルは下記の2つです。

また、検証時のバージョンは「0.7.0」です。

import causalml
print(causalml.__version__)
'0.7.0'

目次

1.やることの概要

作成したデータに対して「T-Learner」を適用して「推論」→「結果の可視化」を実施して利用のイメージを掴みます。
データはcausalmlで用意されている関数を使って作成しています。

  • 使うデータの概要
    従属変数と、共変量(8列)と介入効果を示す列、を利用します。
y ((n,)-array): outcome variable.
X ((n,p)-ndarray): independent variables.
treatment ((n,)-array): treatment flag with value 0 or 1.
  • データの中身
print('='*20+'y'+'='*20)
print(y[:10])
print('='*20+'X'+'='*20)
print(X[:10])
print('='*20+'treatment(w)'+'='*20)
print(treatment[:10])
====================y====================
[ 1.00829355  0.05922987  2.11008642  2.39901112  1.38812442  0.18220695
  0.87736274  1.82545868  2.3499108  -0.38969251]
====================X====================
[[0.55856514 0.18356411 0.87890611 0.1045858  0.53001749 0.81463171
  0.8277885  0.49984797]
 [0.44373612 0.03647427 0.18657762 0.08206453 0.13997463 0.79560809
  0.13194341 0.12079058]
 [0.77570475 0.61996097 0.5380459  0.54915099 0.79648212 0.35820395
  0.6524747  0.96421018]
 [0.2814506  0.03344552 0.09795269 0.99106744 0.86448333 0.27776662
  0.08409478 0.25527859]
 [0.57646013 0.17946275 0.46069825 0.99452343 0.95181924 0.15847775
  0.53460512 0.63080691]
 [0.31849994 0.27292749 0.21019924 0.45968885 0.79176603 0.57486293
  0.05509886 0.89371108]
 [0.84660787 0.06497323 0.40103457 0.70353424 0.37864887 0.26828429
  0.82650411 0.3856174 ]
 [0.20097676 0.4382642  0.29526698 0.28513476 0.45423294 0.11343865
  0.85323826 0.47716232]
 [0.61415854 0.56282748 0.91196283 0.93761138 0.79623375 0.25337298
  0.88095175 0.25281966]
 [0.89061062 0.0444367  0.41399132 0.25181316 0.91828317 0.66352316
  0.69310375 0.0446523 ]]
====================treatment(w)====================
[0 0 1 0 0 0 1 1 1 1]

2.T-Learnerで推論する

causalml側で用意されている「xgboost」、「多層パーセプトロン」を利用することでお手軽にCATE/ITEを推論することができます。
パラメータを指定することで、信頼区間も調整できるのが嬉しいですね。

from causalml.inference.meta import XGBTRegressor,MLPTRegressor


# XGB
print('causalmlで用意されているXGBを利用')
learner_t = XGBTRegressor(ate_alpha=0.05)
cate_t = learner_t.estimate_ate(X=X, 
                               treatment=treatment,
                               y=y
                              )
print('信頼区間\n平均:{0}\n下限:{1}\n上限:{2}'.format(cate_t[0][0],cate_t[1][0],cate_t[2][0]))


# 多層パーセプトロン
print('\ncausalmlで用意されている多層パーセプトロンを利用')
learner_t = MLPTRegressor( hidden_layer_sizes=(100,)
                          ,activation='relu'
                          ,solver='adam'
                          ,batch_size='auto'
                          ,learning_rate='constant'
                          ,learning_rate_init=0.001
                          ,max_iter=200
                          ,random_state=77
                          ,ate_alpha=0.05
                         )
cate_t = learner_t.estimate_ate(X=X, treatment=treatment, y=y)
print('信頼区間\n平均:{0}\n下限:{1}\n上限:{2}'.format(cate_t[0][0],cate_t[1][0],cate_t[2][0]))
causalmlで用意されているXGBを利用
信頼区間
平均:0.49939890674948695
下限:0.47411358700186157
上限:0.5246842264971123

causalmlで用意されている多層パーセプトロンを利用
信頼区間
平均:0.6113857429851489
下限:0.5720949207798465
上限:0.6506765651904514

モデルを指定して推論することも可能です。
(どんなモデルでも利用できるわけではないです)

from causalml.inference.meta import BaseTRegressor
from sklearn.linear_model import LinearRegression
from xgboost import XGBRegressor
import lightgbm as lgb

# 線形回帰
print('★線形回帰')
learner_t = BaseTRegressor(learner=LinearRegression()
                          ,ate_alpha=0.05
                          )
cate_t = learner_t.estimate_ate(X=X, treatment=treatment, y=y)
print('信頼区間\n平均:{0}\n下限:{1}\n上限:{2}'.format(cate_t[0][0],cate_t[1][0],cate_t[2][0]))


# XGB
print('\n★XGB')
learner_t = BaseTRegressor(learner=XGBRegressor()
                           ,ate_alpha=0.05
                          )
cate_t = learner_t.estimate_ate(X=X, treatment=treatment, y=y)
print('信頼区間\n平均:{0}\n下限:{1}\n上限:{2}'.format(cate_t[0][0],cate_t[1][0],cate_t[2][0]))


# LightGBM
print('\n★LightGBM')
learner_t = BaseTRegressor(learner=lgb.LGBMRegressor()
                           ,ate_alpha=0.05
                           )
cate_t = learner_t.estimate_ate(X=X, treatment=treatment, y=y)
print('信頼区間\n平均:{0}\n下限:{1}\n上限:{2}'.format(cate_t[0][0],cate_t[1][0],cate_t[2][0]))
★線形回帰
信頼区間
平均:0.669679905126833
下限:0.6291256839907688
上限:0.7102341262628971

★XGB
信頼区間
平均:0.49939890674948695
下限:0.47411358700186157
上限:0.5246842264971123

★LightGBM
信頼区間
平均:0.5173225980970793
下限:0.4862124246376465
上限:0.5484327715565119

ITE(Individual Treatment Effect)も推定できます。
このスコアに基づいて施策を実施する/しないを決定することもできそうです。

# ITEの計算
cate_t = learner_t.fit_predict(X=X, treatment=treatment, y=y)
cate_t
array([[0.74795519],
       [0.80700771],
       [0.91558024],
       ...,
       [1.40732371],
       [1.64977301],
       [1.10887185]])

3.可視化

続いて、推論結果の可視化をしてみます。
まずは、ITEの可視化をします。

alpha=0.2
bins=30
plt.figure(figsize=(12,8))
plt.hist(cate_t, alpha=alpha, bins=bins, label='T Learner')
plt.title('Distribution of CATE Predictions by Meta Learner')
plt.xlabel('Individual Treatment Effect (ITE/CATE)')
plt.ylabel('# of Samples')
_=plt.legend()

単純にプロットしただけですが、これだけ見るとある程度バランス良くスコアが推論されているようです。
続いて、特徴量の重要度を見てみましょう。

feature_names = ['col1', 'col2', 'col3', 'col4', 'col5', 'col6','col7', 'col8']
cate_t = learner_t.fit_predict(X=X, treatment=treatment, y=y)
learner_t.get_importance(X=X, # 共変量
                        tau=cate_t, # ITE
                        normalize=True, 
                        method='auto', 
                        features=feature_names
                        )
{1: col2     0.349076
 col1     0.227213
 col4     0.097103
 col5     0.081013
 col3     0.069834
 col8     0.063978
 col6     0.060504
 col7    0.051280
 dtype: float64}

簡単ですね。
ここでは「tau」引数に渡した値を推論するモデルを作成し、このモデルと「method」引数に指定した値に応じて計算が行われているようです。

Builds a model (using X to predict estimated/actual tau), and then calculates feature importances based on a specified method.

参照:tlearner.py

method引数にどのような値が指定できるのかを見てみましょう。
デフォルト値は「auto」ですので、特に指定しない場合はLightGBMを使った結果を取得できるようです。

Currently supported methods are:
    - auto (calculates importance based on estimator's default implementation of feature importance;
            estimator must be tree-based)
            Note: if none provided, it uses lightgbm's LGBMRegressor as estimator, and "gain" as
            importance type
    - permutation (calculates importance based on mean decrease in accuracy when a feature column is permuted;
                   estimator can be any form)
Hint: for permutation, downsample data for better performance especially if X.shape[1] is large

参照:tlearner.py

折角なので、「permutation importance」も確認してみます。

# 「permutation importance」の計測値を確認
feature_names = ['col1', 'col2', 'col3', 'col4', 'col5', 'col6','col7', 'col8']
cate_t = learner_t.fit_predict(X=X, treatment=treatment, y=y)
learner_t.get_importance(X=X, # 共変量
                        tau=cate_t, # ITE
                        normalize=True, 
                        method='permutation', 
                        features=feature_names
                        )
{1: col2    0.525061
 col1    0.330783
 col4    0.132314
 col5    0.107923
 col3    0.091707
 col8    0.064523
 col6    0.055205
 col7    0.050609
 dtype: float64}

今回、アルゴリズムは違うもののいずれにしても特徴量の重要さの順序、という点では同じ結果が取得できました。 また、下記のように可視化することも可能です。

learner_t.plot_importance(X=X, 
                         tau=cate_t, 
                         normalize=True, 
                         method='auto', 
                         features=feature_names
                         )

また、SHAP値も取得することができます。

shap_tlearner = learner_t.get_shap_values(X=X, 
                                          tau=cate_t
                                         )
print('='*20+'shap値の確認'+'='*20)
print(shap_tlearner)
print('='*20+'shap値の可視化'+'='*20)
learner_t.plot_shap_values(X=X, 
                           tau=cate_t, 
                           features=feature_names
                          )

4.最後に

以上で、「T-Learner」を使ってITE/CATEを求めて結果を可視化するところまでできました。
検証についても着手したかったのですが、介入効果の検証については実験用データでないと正しい値がわからないため、一般的な機械学習のフレームワークを適用するだけでは不十分で、色々試行錯誤(ex.複数のモデル間でITEの予測値にどの程度バラツキがあるか)して進める必要がありそうです。
(今回利用しているデータは検証用に作成しているデータなのでできるのですが、今回は割愛します)

検証方法については「データに対するドメイン知識」が重要そうですね。
(データ分析や機械学習でも、当然ドメイン知識はかなり重要ですが)
検証方法については、こちらにも記載されているので、よかったらご参照ください。

5.参照