Amazon SageMakerでもRが使いたい!!
こんにちは、小澤です。
データ分析をするための言語としてよく名前が挙がるのがPython, R(とJulia)あたりでしょうか。 どちらの言語が良いとか悪いとか、そういった話も時々見かけますがそんなことは置いておいて、私は個人的にRが好きです(※じゃあPythonは嫌いかと言われるとそんなことはない)。 好きなので、機械学習なんかをやるときにもとりあえずRで(いや、Alteryxで)EDAを...となることも多々あります。 あるのですが、SageMakerを利用していると基本的に利用する言語はPythonとなります。
はい、今回はそんな時でもRを使いたいというわがままな私のための仕組みを紹介します。
Jupyter NotebookとRカーネル
SageMakerで機械学習プログラムを実装する際、ノートブックインスタンを使います。 このノートブックインスタンではJupyer Notebookを使ってプログラムを記述していくわけですが、これ自体はPythonしか扱えないというわけではありません。 各種カーネルを入れることで、他の言語も利用可能です。
そして、SageMakerのノートブックインスタンスには最初からRカーネルが用意されているのです!!
ステキですね。 でも、Pythonを使う場合はSageMaker Python SDKを使ってプログラムを書いていましたが、Rを使う場合はどうすればいいのでしょうか? SageMaker R SDKのようなものが用意されているとは聞いたことがありません。
その答えは、Rのライブラリにあります。 Rのreticulateパッケージを使うことで、Pythonの処理をそのまま呼び出し可能なのです。
SageMakerでRカーネルのノートブックを作成すると、このパッケージがデフォルトで入った状態となります。 そのため、SageMaker Python SDKをそのままRでも利用可能なのです。
実際に使ってみる
では、実際に使ってみます。 今回は、SageMakerのExampleにある Advanced Functionality > using_r_with_amazon_sagemaker を例に見ていきましょう。
まずは、SageMaker Python SDKを使えるようにしましょう。
library(reticulate) sagemaker <- import('sagemaker')
これだけです。 先ほど話題に上げたreticulateパッケージに含まれるimport関数を利用しています。
続いて、これを使ってSageMakerのSessionの取得と利用するデフォルトのS3バケット名, IAMロールを取得します。
session <- sagemaker$Session() bucket <- session$default_bucket() role_arn <- sagemaker$get_execution_role()
ここまでの処理の流れは以下のようなPythonでの記述と同じような感じになります。
import sagemaker session = sagemaker.Session() bucket = session.default_bucket() role_arn = sagemaker.get_execution_role()
最初にreticulateのimport関数にて、Pythonのimport文を実行します。 そのままimportできるわけではないのでその結果を変数に入れることで、Rから使えるようにしているわけです。
これをPythonのimportとの対応関係で考えると
hoge <- import(hoge)
→import hoge
piyo <- import(hoge)
→import hoge as fuga
fuga <- import(hoge.fuga)
→from hoge import fuga
のような対応関係となります。 さて、このようにして取得したPythonオブジェクトの関数を呼び出す際には「.」ではなく「$」を利用します。 Sessionやデフォルトのパケット名の取得にはこの仕組みを利用しているわけです。
ここまでわかれば普段RもSageMakerも使っている人にはほぼほぼ「大体全部理解した」状態ですね。
SageMakerでは、学習の際に利用するデータがS3に保存されていればいいので、どのようなデータを保存するかという部分に関しては、完全にRのプログラムとして実装できます。 その流れを見ていきましょう。
データはUCIにあるAbaloneデータセットを利用します。
まずは、データを取得します。 また、取得したデータには列名が含まれていないので別途設定しています。
library(readr) data_file <- 'https://archive.ics.uci.edu/ml/machine-learning-databases/abalone/abalone.data' abalone <- read_csv(file = data_file, col_names = FALSE) names(abalone) <- c('sex', 'length', 'diameter', 'height', 'whole_weight', 'shucked_weight', 'viscera_weight', 'shell_weight', 'rings')
結果は以下のようにRのdata frameとして読み込まれます。 まぁ、普通にRのコードですね。
summary関数で要約情報を見たりggplot2を使った可視化もノートブック上で可能です。
abalone$sex <- as.factor(abalone$sex) summary(abalone)
library(ggplot2) ggplot(abalone, aes(x = height, y = rings, color = sex)) + geom_point() + geom_jitter()
dplyrを利用して、外れ値となる値を取り除き、性別をOne-hotなベクトルに変換しています。
library(dplyr) abalone <- abalone %>% filter(height != 0) abalone <- abalone %>% mutate(female = as.integer(ifelse(sex == 'F', 1, 0)), male = as.integer(ifelse(sex == 'M', 1, 0)), infant = as.integer(ifelse(sex == 'I', 1, 0))) %>% select(-sex) abalone <- abalone %>% select(rings:infant, length:shell_weight)
続いて、データをtrain, validation, testの3つに分割します。 学習データを7割り、残りを検証とテストで半分こです。 一番最後の処理は、SageMakerで利用するために目的変数を先頭列に持っていくものです。
abalone_train <- abalone %>% sample_frac(size = 0.7) abalone <- anti_join(abalone, abalone_train) abalone_test <- abalone %>% sample_frac(size = 0.5) abalone_valid <- anti_join(abalone, abalone_test)
必要なデータが揃ったので、これをS3に保存してSageMakerに渡すためのオブジェクトを生成します。
また、SageMakerに渡すファイルはヘッダなしのものとなるため、 col_name = FALSE
としている点にご注意ください。
write_csv(abalone_train, 'abalone_train.csv', col_names = FALSE) write_csv(abalone_valid, 'abalone_valid.csv', col_names = FALSE) s3_train <- session$upload_data(path = 'abalone_train.csv', bucket = bucket, key_prefix = 'r_kernel/abalone') s3_valid <- session$upload_data(path = 'abalone_valid.csv', bucket = bucket, key_prefix = 'r_kernel/abalone') s3_train_input <- sagemaker$s3_input(s3_data = s3_train, content_type = 'csv') s3_valid_input <- sagemaker$s3_input(s3_data = s3_valid, content_type = 'csv')
ここの処理でもreticulateを使ったPythonの関数呼び出しをしていますが、同じノリで使えいていることがわかりますね。 ここから先はSageMakerのEstimatorを利用しますが、使い方はこのノリでそのまま行けます。
# コンテナイメージの文字列作成 registry <- sagemaker$amazon$amazon_estimator$registry(session$boto_region_name, algorithm='xgboost') container <- paste(registry, '/xgboost:latest', sep='') s3_output <- paste0('s3://', bucket, '/r_kernel/output') # Estimatorのインスタンスを生成 # 必要な引数はPythonと同様 estimator <- sagemaker$estimator$Estimator(image_name = container, role = role_arn, train_instance_count = 1L, train_instance_type = 'ml.m5.large', train_volume_size = 30L, train_max_run = 3600L, input_mode = 'File', output_path = s3_output, output_kms_key = NULL, base_job_name = NULL, sagemaker_session = NULL) # ハイパーパラメータのやジョブ名の設定も同様 estimator$set_hyperparameters(num_round = 100L) job_name <- paste('sagemaker-train-xgboost', format(Sys.time(), '%H-%M-%S'), sep = '-') # fit関数呼び出し時にPythonではチャンネルとデータをdict形式で渡す # これをRから利用する場合には、list形式にする必要がある input_data <- list('train' = s3_train_input, 'validation' = s3_valid_input) # fit関数の呼び出しもPythonと同じ estimator$fit(inputs = input_data, job_name = job_name)
これで学習処理が実行されます。 マネジメントコンソール上からもPythonを使った時と同様に動いているのが確認できます。
推論エンドポイントの利用も同じノリです。
# エンドポイントの作成 model_endpoint <- estimator$deploy(initial_instance_count = 1L, instance_type = 'ml.t2.medium') # シリアライザの設定 model_endpoint$content_type <- 'text/csv' model_endpoint$serializer <- sagemaker$predictor$csv_serializer
さて、エンドポイントを利用した推論を行う際にはデータはS3に置くわけではなく、プログラムから直接渡します。 この部分では、Rのdata frameそのままではなく、maxrix型で渡してやるため、変換を行います。
# 目的変数の列の削除 abalone_test <- abalone_test[-1] # エンドポイントの最大件数にあわせてデータを500件に絞り、matrix型に変換する num_predict_rows <- 500 test_sample <- as.matrix(abalone_test[1:num_predict_rows, ]) # 列名の削除 dimnames(test_sample)[[2]] <- NULL
このデータを推論エンドポイントに投げます。
library(stringr) # エンドポイントにデータを投げて推論を行う predictions <- model_endpoint$predict(test_sample) # 結果はカンマ区切りに文字列して帰ってくるので、分割して数値に変換 predictions <- str_split(predictions, pattern = ',', simplify = TRUE) predictions <- as.numeric(predictions) # テストデータにこの推論結果を列として付与 abalone_test <- cbind(predicted_rings = predictions, abalone_test[1:num_predict_rows, ])
これで、以下のように推論結果が付与されたデータが得られます。
最後に、不要な場合は忘れずにエンドポイントを削除しておきましょう。
session$delete_endpoint(model_endpoint$endpoint)
おわりに
今回は、SageMakerノートブックインスタンのRカーネルとreticulateパッケージを使って、RからSageMakerを利用しました。
贅沢なことを言うと、マネージドなRStudion Serverが欲しいと思った今日この頃でしたw