この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。
データアナリティクス事業本部の鈴木です。
XGBoostの学習・推論に使用できるAPIには、以下があります。
- Learning API
- Scikit-Learn API
Scikit-Learn APIはscikit-learnから使うことを想定したAPIです。 私は慣れるまではよく「あれ、どっちのAPI使ったらよかったっけ?」と分からなくなっていたので、イメージがつきやすいよう、Scikit-Learn APIをscikit-learnのパイプラインと組み合わせた例を書いてみました。
Python用の各APIの詳細は、以下をご確認ください。
やってみる
検証した環境
- コンテナ:jupyter/datascience-notebook
- scikit-learn:1.0
- XGBoost:1.5.1
XGBoostはコンテナにインストールされていなかったので、pipでインストールしました。
!pip install xgboost==1.5.1
データの準備
今回は、scikit-learnのdatasetsにある、wine datasetを利用します。
多クラス分類を目的としたデータセットで、178サンプルが格納されています。targetにはclass_0・class_1・class_2の3種類があります。
以下のようにして、PandasのDataFrameに変換しておきます。
import pandas as pd
from sklearn.datasets import load_wine
data = load_wine()
X = pd.DataFrame(data["data"], columns=data["feature_names"])
y = pd.DataFrame(data["target"], columns=["target"])
データは以下のようになります。
X.head()
y.head()
後で訓練・テスト用に使うため、データを分けておきます。
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
パイプラインの構築
今回は分類問題を解きたいので、APIはXGBClassifierを使います。 XGBClassifierの前にStandardScalerで標準化しておきます。XGBoostの場合、必ずしも標準化する必要はありませんが、パイプラインの例示と、ほかのモデルと比較しやすいようにという意図です。
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import xgboost as xgb
cls = xgb.XGBClassifier(
eval_metric="mlogloss",
use_label_encoder=False
)
pipe = Pipeline(steps=[("standerd_scaler", StandardScaler()),
("Classifier", cls)])
パイプラインは以下のようになります。
from sklearn import set_config
set_config(display='diagram')
pipe
訓練と推論の実施
Scikit-Learn APIはfit
やpredict
などsklearn.pipeline.Pipeline
から呼び出されるメソッドが実装されているので、scikit-learnのモデルをパイプラインに組み込んだ際と同じように扱うことができます。
# 訓練
pipe.fit(X_train,
y_train.values.ravel())
# 推論
pipe.predict(X_test)
# モデルのscoreで評価する
pipe.score(X_test, y_test)
## 0.9333333333333333
Cross Validationもsklearn.model_selection.cross_val_score
から実行できます。
from sklearn.model_selection import cross_val_score
scores = cross_val_score(pipe,
X_train,
y_train.values.ravel(),
cv=5,
scoring='accuracy')
np.mean(scores)
## 0.9475783475783477
最後に
scikit-learnをよく使っている方だとScikit-Learn APIの使い方もイメージがつきやすいと思いますが、そうでない方だと「どっちを使うといいんだろう?」となりやすいところかなと思っています。Scikit-Learn APIだとscikit-learnのpipelineモジュールの恩恵が受けられるので、興味がある方はぜひパイプラインと合わせて使ってみてください。