TensorFlowのタイタニックデータセットを使って爆速で学習が完了するTabPFNを動かしてみた

今年arXivに収録された論文で、テーブルデータを超高速に学習できるTabPFNというモデルが登場したので試してみました!
2022.11.09

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

皆さん、こんにちは。クルトンです。

今回はテーブルデータを高速で学習できるTabPFNを試してみました。

最大1024個のデータを渡す事で学習が完了します。 こちらが論文となります。

なお、動作環境はGoogleColabでランタイムはCPUです。

データを用意

データ処理に必要なモジュールをimport

import tensorflow_datasets as tfds
import numpy as np
import pandas as pd

データのロードと中身をチェック

今回はタイタニック号のデータセットを使います。こちらは乗船客が生き残ったかどうかを予測するためのデータセットとなります。

tf_dataset, info = tfds.load("Titanic", as_supervised=True, with_info=True)

データの構造をチェックします。

print(info)
tfds.core.DatasetInfo(
    name='titanic',
    full_name='titanic/4.0.0',
    description="""
    Dataset describing the survival status of individual passengers on the Titanic. Missing values in the original dataset are represented using ?. Float and int missing values are replaced with -1, string missing values are replaced with 'Unknown'.
    """,
    homepage='https://www.openml.org/d/40945',
    data_path='~/tensorflow_datasets/titanic/4.0.0',
    file_format=tfrecord,
    download_size=114.98 KiB,
    dataset_size=382.58 KiB,
    features=FeaturesDict({
        'age': tf.float32,
        'boat': tf.string,
        'body': tf.int32,
        'cabin': tf.string,
        'embarked': ClassLabel(shape=(), dtype=tf.int64, num_classes=4),
        'fare': tf.float32,
        'home.dest': tf.string,
        'name': tf.string,
        'parch': tf.int32,
        'pclass': ClassLabel(shape=(), dtype=tf.int64, num_classes=3),
        'sex': ClassLabel(shape=(), dtype=tf.int64, num_classes=2),
        'sibsp': tf.int32,
        'survived': ClassLabel(shape=(), dtype=tf.int64, num_classes=2),
        'ticket': tf.string,
    }),
    supervised_keys=({'age': 'age', 'boat': 'boat', 'body': 'body', 'cabin': 'cabin', 'embarked': 'embarked', 'fare': 'fare', 'home.dest': 'home.dest', 'name': 'name', 'parch': 'parch', 'pclass': 'pclass', 'sex': 'sex', 'sibsp': 'sibsp', 'ticket': 'ticket'}, 'survived'),
    disable_shuffling=False,
    splits={
        'train': <SplitInfo num_examples=1309, num_shards=1>,
    },
    citation="""@ONLINE {titanic,
    author = "Frank E. Harrell Jr., Thomas Cason",
    title  = "Titanic dataset",
    month  = "oct",
    year   = "2017",
    url    = "https://www.openml.org/d/40945"
    }""",
)

なるほど、features=FeaturesDict({でデータの型を確認してみるとint型やfloat型、str型など複数の型が混じっているようです。

データセットをnumpy.ndarray型に変換します。

dataset=tfds.as_numpy(tf_dataset['train'])

クラス(生き残ったかを表すsurvived)とそれ以外のデータで保持する変数をわけます。

feature_data, label_data= np.array([list(x.values()) for x,y in dataset]), np.array([y for x,y in dataset])

以前行なったTensorFlowデータセットを使ったデータの準備では、ここまでのコードを実行すると準備が大体完了していたのですが、今回は注意が必要です。

次のコードを実行してみてください。

print(feature_data)

以下のような結果になるかと思います。

[[b'30.0' b'Unknown' b'-1' ... b'0' b'0' b'233478']
 [b'37.0' b'Unknown' b'98' ... b'0' b'2' b'3101276']
 [b'28.0' b'9' b'-1' ... b'1' b'0' b'230434']
 ...
 [b'-1.0' b'Unknown' b'-1' ... b'1' b'0' b'330935']
 [b'31.0' b'C D' b'-1' ... b'1' b'1' b'363291']
 [b'36.0' b'Unknown' b'-1' ... b'0' b'1' b'C.A. 34651']]

それぞれの数値や文字列にbの文字が入っています。

ここで、中身をより詳しく見るため、次のコードを実行してみてください。

for x in feature_data[0]:
  print(type(x), x)

以下のような結果が出力されるかと思います。

<class 'numpy.bytes_'> b'30.0'
<class 'numpy.bytes_'> b'Unknown'
<class 'numpy.bytes_'> b'-1'
<class 'numpy.bytes_'> b'Unknown'
<class 'numpy.bytes_'> b'2'
<class 'numpy.bytes_'> b'13.0'
<class 'numpy.bytes_'> b'Sarnia, ON'
<class 'numpy.bytes_'> b'McCrie, Mr. James Matthew'
<class 'numpy.bytes_'> b'0'
<class 'numpy.bytes_'> b'1'
<class 'numpy.bytes_'> b'0'
<class 'numpy.bytes_'> b'0'
<class 'numpy.bytes_'> b'233478'

このように、numpy.bytes_型となっているのですが、TabPFNはスカラー値しか受け取りません。仮にこのデータを使って学習を行うと次のようなエラーが出ます。

ValueError: Unable to convert array of bytes/strings into decimal numbers with dtype='numeric'

よって、データの型を数値であるint型やfloat型に変換する必要があります。

データを数値型に変換

info変数やタイタニックのデータセットについて説明している公式サイトを見ていただくと分かる通り、それぞれのデータがどういう型のものを想定しているかが分かります。

前準備

前準備として次のようなコードを実行します。

feature_data = pd.DataFrame(feature_data)
cast_data_types=[float,str,int,str,int,float,str,str,int,int,int,int,str]

ageやboatなどの列ごとに処理する必要があるためDataFrame型に変換し、それぞれの型の情報をcast_data_types変数に保持しておきます。

numpy.bytes_型から本来想定されていた型への変換

numpy.bytes_型は、decodeメソッドを使う事でstr型へ変換できます。それぞれのデータの要素毎にdecodeメソッドを実行するため、map関数とラムダ式を使って処理し、その後、それぞれの本来想定されているデータの型へcast_data_types変数を使って変換します。

for index in range(len(feature_data.columns)):
  feature_data[index] = pd.Series(map(lambda x: x.decode('utf-8'), feature_data[index]))
  feature_data[index] = feature_data[index].astype(cast_data_types[index])

ここは注意が必要です。DataFrame型ではastypeメソッドを使って変換する事で全てのカラムに対して一斉に変換が出来ます。

しかし、例えばstr型で変換した場合、feature_data[0][0]の要素は30.0でなくb'30.0'と、先頭にbが付いた文字列になってしまいます。 int型で変換しようとしても、人の名前などのstr型になるものは変換ができないのでエラーが出てしまいます。

decodeメソッドを使うと、bという文字は取り除かれた状態でstr型の値が返ってきますので、今回はこのような処理を採用しました。

次に変換が必要なのは、本来str型としてデータが格納されているカラムの処理です。

カテゴリカルデータを数値へ変換

「男性」と「女性」のようにアンケートで選択式になるようなデータはカテゴリカルデータと言います。 今回、そのようなデータを処理するcategory_encodersというモジュールを使用して、str型のデータを数値の型に変換します。

まず必要なモジュールをインストールとimportします。

!pip install category_encoders
import category_encoders as ce

次に変換するために必要なOrdinalEncoderクラスのインスタンスを作ります。

enc = ce.OrdinalEncoder()

fit_transformメソッドを変換したい対象(カラムがstr型になっている箇所)を渡すと数値に変換した後のデータが返ってくるので、feature_data変数それぞれのカラムごとに格納していきます。

tmp_df = enc.fit_transform(feature_data.iloc[:, [i for i, x in enumerate(cast_data_types ) if str == x]])
for index in tmp_df:
  feature_data[index]=tmp_df[index]

以下の処理をしなくとも次以降のコードは動きますが、公式GitHubのチュートリアルで扱っているデータの型と揃えるためnumpy.ndarray型に変換しています。

feature_data = feature_data.values

ここまでで、タイタニックデータセットを使う準備ができました。次からはTabPFNを動かしてみます。

TabPFNを動かす

必要なモジュールをインストールとimport

!pip install tabpfn
from sklearn.model_selection import train_test_split
from tabpfn.scripts.transformer_prediction_interface import TabPFNClassifier
from sklearn.metrics import accuracy_score
import time

TabPFNのモデル、正解率、予測にかかった時間を調べるために必要なモジュールをimportしています。

モデルのインスタンスを作成

インスタンスを生成します。

classifier = TabPFNClassifier()

次からモデルの学習データとテストデータに分割してからモデルの学習を行なっていきます。 ここで注意が必要で、冒頭で書いたように多くのデータを渡して学習するモデルではないため、1024個を超えるようなデータを渡すと次のようなエラーが出ます。

ValueError: ⚠️ WARNING: TabPFN is not made for datasets with a trainingsize > 1024. Prediction might take a while, be less reliable. We advise not to run datasets > 10k samples, which might lead to your machine crashing (due to quadratic memory scaling of TabPFN). Please confirm you want to run by passing overwrite_warning=True to the fit function.

したがって、変則的ですが、学習データよりもテストデータの方がデータ数を多くなるように調整してからモデルを学習させていきます。

学習データを全データのうち10%(200未満のサンプル数)に設定

学習データを全データのうち10%ほどに設定し、サンプル数を確認します。

x_train, x_test, y_train, y_test = train_test_split(feature_data, label_data, test_size=0.9)
print(np.shape(x_train))

学習データは130個でした。

(130, 13)

学習と予測

次の4行のコードを実行すると、学習と予測を行ない、それに掛かった時間とaccuracyを表示します。

start = time.time()
classifier.fit(x_train, y_train)
y_pred = classifier.predict(x_test)
print('学習後に予測に掛かった時間: ', time.time() - start, ', accuracy: ', accuracy_score(y_test, y_pred))

出力結果は次のようになりました。予測が4秒以下で終わっている事やaccuracyが9割を超えている点から、このモデルの凄さを実感しました。

学習後に予測に掛かった時間:  3.5827276706695557 , accuracy:  0.9516539440203562

おまけ

せっかくなので、サンプル数を変えた場合の時間やaccuracyを確認していきます。

train_test_splitメソッドの引数test_sizeの数値を変化させ、学習と予測に使用した上記4行のコードを実行していきます。

学習データを全データのうち20%(300未満のサンプル数)に設定

引数test_sizeを0.8に設定したところ、サンプル数は261個で、実行結果は次のとおりでした。

学習後に予測に掛かった時間:  3.9873201847076416 , accuracy:  0.9637404580152672

ほとんど学習と予測にかかる時間は変化していないようですが、accuracyが0.01ほど上がっています。コードで変えたところは引数test_sizeの数値だけですが、差が出ていますね。

学習データ全データのうち80%(1024超えのサンプル数)に設定

TabPFNは基本的に1024超えのサンプル数を使って学習をするようにはなっていません。理由としては、学習にかなりのメモリを使うので、あまりデータ数が多いと処理ができない場合があるためです。

しかし、fitメソッドで学習するときに次のように引数を一つ加えると可能です。今回はデータ数として多すぎる事がないように1024を少し超えるように設定してから学習させてみました。

classifier.fit(x_train, y_train, overwrite_warning=True)

引数test_sizeを0.8に設定したところ、サンプル数は1047個で、実行結果は次のとおりでした。

学習後に予測に掛かった時間:  11.173425197601318 , accuracy:  0.9885496183206107

試しに1024個を超えるデータを使ってみましたが、学習と予測に掛かる時間が3倍ほどになっています。そのぶん、accuracyが0.98と高い数値が出ています。

基本的には1024個以内にサンプル数が収まるようにしつつ、学習に掛けられる時間を確認しながらサンプル数を決めて渡すといいかもしれません。

終わりに

今回は爆速で動くTabPFNモデルを動かしてみました。今年度の論文という事で最近のモデルなのですが、とてつもなく早い上に高性能な感じでした。

C++を使う事でWhisperをCPU上でも高速に動作させるようなものも出ているようなので、今後は高速化の話がどんどん出てくるのかもしれませんね。

今回はここまで。

それでは、また!

参考にしたサイト