この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。
平田です。
AMLのManagement Console画面は非常にわかりやすく、モデル構築などを行う上では不満はないのですが、流石に何度も構築作業を繰り返すと、段々と飽きが来てしまいます。
特に最近、AMLでたくさんモデル構築を行う機会があったのですが、Management Consoleからひたすらポチポチと同じ作業を繰り返すのは辛い上にミスを誘発するので、あまりいいやり方ではありませんでした。
ということで今回は、AWS APIを使った構築の自動化を目指すべく、AWS SDKを使って、Amazon MLのTutorialをスクリプトで実行してみました。
Management Consoleとの違い
作成したコードは本記事の末尾に記載してあります。
ここでは、Management Consoleから作成した場合との違いや注意点などを紹介していきたいと思います。
schemaの存在
Management ConsoleでDataSourceを作成する際に、読み込んだデータから各項目の形式(CATEGORICAL, NUMERICなど)を設定しますが、APIを利用する場合にはschemaと呼ばれるデータ形式の定義を渡してあげる必要があります。
schemaはjsonで書かれており、各項目のデータ形式のほか、ターゲット変数の指定や1行目をヘッダとして読み込むかなどの設定も指定できます。
schemaの指定方法は、直接APIのパラメータに文字列そのものを指定するほかに、S3バケット上に配置して、そのURIを指定する方法があります。今回はサンプルとして提供されているbanking.csv.schemaをS3バケットから取り込んでいます。
{
"version": "1.0",
"targetAttributeName": "y",
"dataFormat": "CSV",
"dataFileContainsHeader": true,
"attributes": [
{
"attributeName" : "age",
"attributeType": "NUMERIC"
},
{
"attributeName": "campaign",
"attributeType": "NUMERIC"
},
{
"attributeName": "cons_conf_idx",
"attributeType": "NUMERIC"
},
...
]
}
Data Splittingは事前に行う
Management Consoleでは、Model構築のタイミングで事前に作成したDataSourceを2分割し、一方を訓練用、もう一方を評価用として再度DataSourceを作成しなおします。
しかし、APIを利用する場合は、DataSource作成の段階ではじめから分割したDataSourfceとして準備できます。
分割の割合は、APIのパラメータに以下のようにjsonを指定することで、DataSource作成時に反映されます。 randomSeedを同じ値に設定すれば、同じ混ざり方でデータの分割ができるため、一つのデータソースから2分割してそれぞれをデータソースとすることができます。
{
"randomSeed":"random_seed_value",
"splitting":
{ "percentBegin":0,"percentEnd":70}
}
APIならRDSから接続できる
Management Consoleではデータソースの指定はredshiftかS3の2択になるのですが、実はAPIではもう一つ、RDSからの接続も用意されています。
RDSのインスタンス情報やユーザ情報、クエリをパラメータに渡すことで、redshiftと同様に直接DataSourceを作成することができます。
RDSからS3経由でDataSourceを作成している方は、是非一度お試しください!
AWS SDK for Ruby v2での注意点
今回はAWS SDK for Ruby v2を利用してみましたが、スクリプトを書いている時に一点だけ問題があったので、それについてもここで紹介します。
wait_untilが使えない
AWS APIの多くは非同期APIであるため、その場では結果は返ってきません。そのため、定期的にエンティティのステータスを監視し、処理が正常に終了したかを確認する必要があります。
AWS APIにはClient#wait_untilというエンティティの状態を監視するためのメソッドが用意されているのですが、Machine Learningに関してはまだ未実装(?)のようです。
client.waiter_names #=> []
なので、今回はステータスを定期的にチェックするよう自分で実装してお茶を濁しています。 もしこれからAWS APIをつかってMachine Learningの構築スクリプトを作成する人がいれば、この点注意して頂ければと思います。
まとめ
今回はとりあえずAMLの一通りの流れをAWS APIで動かしてみました。機械学習の予測精度を上げたい場合、いろいろなパターンのモデルをとにかく試してみる、という場面があります。
そんな時のために自動化のノウハウを今から積んでおきたいと思います。
作成したスクリプト
require 'aws-sdk'
require 'securerandom'
ml = Aws::MachineLearning::Client.new(region: 'us-east-1')
hash_value = SecureRandom.hex(8)
# 訓練用データセットの設定
ds_train = {
data_source_id: 'ds-train-' + hash_value,
data_source_name: 'ds-train-banking',
data_spec: {
data_location_s3: "s3://aml-sample-data/banking.csv",
data_schema_location_s3: "s3://aml-sample-data/banking.csv.schema",
data_rearrangement: "{\"randomSeed\":\"#{hash_value}\", \"splitting\":{\"percentBegin\":0,\"percentEnd\":70}}"
},
compute_statistics: true
}
# 評価用データセットの設定
ds_eval = {
data_source_id: 'ds-eval-' + hash_value,
data_source_name: 'ds-eval-banking',
data_spec: {
data_location_s3: "s3://aml-sample-data/banking.csv",
data_schema_location_s3: "s3://aml-sample-data/banking.csv.schema",
data_rearrangement: "{\"randomSeed\":\"#{hash_value}\", \"splitting\":{\"percentBegin\":70,\"percentEnd\":100}}"
},
compute_statistics: true
}
# バッチ予測用データセットの設定
ds_bp = {
data_source_id: 'ds-bp-' + hash_value,
data_source_name: 'ds-bp-banking',
data_spec: {
data_location_s3: "s3://aml-sample-data/banking-batch.csv",
data_schema_location_s3: "s3://aml-sample-data/banking.csv.schema"
},
compute_statistics: true
}
# モデルの設定
model = {
ml_model_id: 'model-' + hash_value,
ml_model_name: 'model-banking',
ml_model_type: 'BINARY',
training_data_source_id: ds_train[:data_source_id]
}
# 評価の設定
eval = {
evaluation_id: 'eval-' + hash_value,
evaluation_name: 'eval-banking',
ml_model_id: model[:ml_model_id], # required
evaluation_data_source_id: ds_eval[:data_source_id]
}
# バッチ予測の設定
bp = {
batch_prediction_id: 'bp-' + hash_value,
batch_prediction_name: "bp-banking",
ml_model_id: model[:ml_model_id],
batch_prediction_data_source_id: ds_bp[:data_source_id],
output_uri: "s3://cm-aml-test/result",
}
# データセット作成
ml.create_data_source_from_s3(ds_train)
ml.create_data_source_from_s3(ds_eval)
ml.create_data_source_from_s3(ds_bp)
# ml.wait_until(...)
# モデル構築
ml.create_ml_model(model)
# ml.wait_until(...)
# 評価作成
ml.create_evaluation(eval)
# ml.wait_until(...)
# バッチ予測
ml.create_batch_prediction(bp)