Amazon SageMakerでAWS IoT Analyticsにあるデータセットから異常検出モデルを学習させてみる – Amazon SageMaker Advent Calendar 2018
こんにちは、大阪DI部の大澤です。 この記事は「クラスメソッド Amazon SageMaker Advent Calendar」の25日目の記事となります。ついにクリスマスがやってきました。
今回はAmazon SageMakerで異常検出モデル(ランダムカットフォレスト)をAWS IoT Analyticsのデータセットで学習させてみます。 AWS IoTを介してデバイス(産業用機械)から定期的に送られてくるデータ(温度データ)をIoT Analyticsに蓄積し、SageMakerで異常検出モデルを作成するという流れの想定です。
やってみる
概要
- IoT Analyticsのチャネル、パイプライン、データストアを作る
- 産業用機械の温度データセットをチャネルに入れる
- データセットを学習用とテスト用の2つにわけて作る
- データセットを読み込んで前処理する
- ランダムカットフォレストで異常値検出モデルを作る
- バッチ変換でテスト用データセットを推論し、モデルの確認を行う
今回使用するデータセットについて
今回使用するデータセットは産業用機械の温度データで、NAB(The Numenta Anomaly Benchmark)という異常検知のベンチマーク用ソフトウェアで使用されているものです。
環境
- SageMakerのノートブックインスタンス
- Jupyter Lab
- カーネル: conda_python3
注意
今回は各処理のスクリプトを紹介します。紹介するスクリプトは以下のリポジトリに入っているものです。試す際には以下のリポジトリをクローンし、iot_analytics_for_blog/work.ipynb
を進めてもらえればと思います。work.ipynb
では記述された各JupyterノートブックをPapermillを使って実行します。iot_analytics_for_blog/notebook/
配下の各ノートブックに処理を記述しています。必要に応じてご参照ください。
前準備
AWS IoT Analytics関連のリソースを作る
AWS IoT Analytics関連のリソースを作成します。各リソースにおけるデータの保持期間は1日で設定します。
まずは各リソース名を定義します。
channel_name = 'channel-name' pipeline_name = 'pipeline-name' datastore_name = 'datastore-name' train_dataset_name = 'dataset_train' test_dataset_name = 'dataset_test'
boto3のIoT Analytics用のクライアントを取得します。
import boto3 iota_client = boto3.client('iotanalytics')
チャネルを作ります。チャネルは生データを保存するための場所です。
response = iota_client.create_channel( channelName=channel_name, retentionPeriod={ 'unlimited': False, 'numberOfDays': 1 } ) print(response)
データストアを作成します。データストアはチャネルからパイプラインを介して処理されたデータを保存する場所です。
response = iota_client.create_datastore( datastoreName=datastore_name, retentionPeriod={ 'unlimited': False, 'numberOfDays': 1 }, ) print(response)
パイプラインを作ります。パイプラインはチャネルからデータストアへの流れを制御できます。アクティビティとして処理を定義することで、簡易的なデータの変換処理を行えます。今回はチャネルに流れてきたデータをそのままデータストアに流すだけのパイプラインです。
datastore_activity_name = 'datastore_activity' response = iota_client.create_pipeline( pipelineName=pipeline_name, pipelineActivities=[ { 'channel': { 'name': 'channel_activity', 'channelName': channel_name, 'next': datastore_activity_name } }, { 'datastore': { 'name': datastore_activity_name, 'datastoreName': datastore_name }, } ] ) print(response)
次にデータセットを作成します。データストアからAthena互換のSQLクエリを使用してデータを抽出します。
今回は学習用とテスト用の2つのデータセットを作成します。データの抽出の際にはtimestamp
アトリビュートの境界点を定めて、その前後でデータセットを分けます。
bound_ts = '2013-12-10 00:00:00' dataset_creation_params = [ { 'dataset_name': train_dataset_name, 'action_name': 'train_query', 'sql_condition': 'WHERE CAST("timestamp" AS TIMESTAMP) < TIMESTAMP \'{}\''.format(bound_ts) },{ 'dataset_name': test_dataset_name, 'action_name': 'test_query', 'sql_condition': 'WHERE CAST("timestamp" AS TIMESTAMP) >= TIMESTAMP \'{}\''.format(bound_ts) } ] for param_dic in dataset_creation_params: response = iota_client.create_dataset( datasetName=param_dic['dataset_name'], actions=[ { 'actionName': param_dic['action_name'], 'queryAction': { 'sqlQuery': 'SELECT * FROM "{}" {}'.format( datastore_name, param_dic['sql_condition'] ) }, } ], retentionPeriod={ 'unlimited': False, 'numberOfDays': 1 } ) print(response)
AWS IoT Analytics チャネルにデータをいれる
先ほど作成したチャネルにデータを送り入れます。 今回のようにチャネルにデータを流し込む他に、AWS IoTのトピックからIoTルールでデバイスデータをチャネルへ流し込むことができます。
産業用機械の温度データのCSVがあるURIからダウンロードし、読み込んんだデータをチャネルに流し込みます。
import pandas as pd from more_itertools import chunked import boto3 # 温度データのURL csv_data_source = 'https://raw.githubusercontent.com/numenta/NAB/master/data/realKnownCause/machine_temperature_system_failure.csv'' # データをまとめるサイズ data_split_size = 100 iota_client = boto3.client('iotanalytics') # データを読み込む df = pd.read_csv(csv_data_source) # データを分割して、チャネルへ送り込む for df_sub in chunked(df.iterrows(), data_split_size): format_messages = [] for index, row in df_sub: format_messages.append(dict( messageId = str(index+1), payload = row.to_json() )) response = iota_client.batch_put_message(channelName=channel_name, messages=format_messages) print(response)
データセットの作成
学習用とテスト用データセットをそれぞれ作成するために、抽出処理開始のリクエストを投げます。 リクエストを投げてから数秒から数分後に、データの抽出処理が完了します。
import boto3 iota_client = boto3.client('iotanalytics') for dataset_name in [train_dataset_name, test_dataset_name]: response = iota_client.create_dataset_content(datasetName=dataset_name) print(response)
異常検出モデルの作成
AWS IoT データセットからデータを取り出す
データセットから学習用とテスト用のデータを取り出します。get_dataset_contentを実行することでデータセットのデータの署名つきURIを取得できます。そのURIからデータをダウンロードできます。
import boto3 from urllib.request import urlretrieve working_dir = './working/' train_dataset_path = path.join(working_dir, 'train_dataset.csv') test_dataset_path = path.join(working_dir, 'test_dataset.csv') dataset_params = [ dict( dataset_name = train_dataset_name, file_path = train_dataset_path ), dict( dataset_name = test_dataset_name, file_path = test_dataset_path ), ] iota_client = boto3.client('iotanalytics') for dataset_param in dataset_params: # データセットの情報を取得 response = iota_client.get_dataset_content(datasetName = dataset_param['dataset_name']) # データセットの署名つきURLを取りだす data_uri = response['entries'][0]['dataURI'] # ファイルデータを取り出す urlretrieve(data_uri, dataset_param['file_path'])
データの前処理
データをモデルの学習に実行できる形に変換します。この処理は学習用とテスト用それぞれで実行します。テスト用に実行する際にはneed_label_flg
をTrueにしてラベルデータを加える必要があります。
label_data_source = 'https://raw.githubusercontent.com/numenta/NAB/master/labels/raw/known_labels_v1.0.json' dataset_path = train_dataset_path # 処理するデータの保存場所 data_s3_path = 's3://bucket-name/sagemaker/iot-analytics/machine-temperature/train.csv' shingle_size = 12 * 24 # テスト用データの作成にはラベルデータが必要 need_label_flg = False
データをpandasのデータフレームとして読み込みます。
import pandas as pd df = pd.read_csv(dataset_path)
タイムスタンプをindexにし、不要なカラムを削除します。
df.index = pd.to_datetime(df.timestamp) df = df.drop(columns=['timestamp', '__dt']).sort_index()
1日分の温度データをまとめて1つのデータポイントとして扱うために、シングリング処理を施します。
def shingle(data, shingle_size): import numpy as np num_data = len(data) shingled_data = np.zeros((num_data-shingle_size, shingle_size)) for n in range(num_data - shingle_size): shingled_data[n] = data[n:(n+shingle_size)] return shingled_data shingled_data = shingle(df.value, shingle_size)
テストデータの場合にはラベルデータを付与します。
if need_label_flg: import numpy as np from urllib import request import json with request.urlopen(label_data_source) as f: label_data = json.loads(f.read().decode()) anomaly_dates = label_data['realKnownCause/machine_temperature_system_failure.csv'] anomaly_datetimes = [pd.to_datetime(dt) for dt in anomaly_dates] is_anomaly = [int(timestamp in anomaly_datetimes) for timestamp in df.index] df['is_anomaly'] = pd.Series(is_anomaly, index=df.index) # シングリングすることでシングルサイズ分のデータが無くなるので、データフレームも合わせておく shingled_df = df.iloc[shingle_size:] # 各行の先頭に異常値かどうかのラベルをつける(異常値:1, 正常値:0) labeled_data = [np.insert(row, 0, shingled_df.is_anomaly.iloc[i]) for i, row in enumerate(shingled_data)]
処理したデータをローカルに保存し、それをS3にアップロードします。
local_path = '/tmp/data.csv' import numpy as np np.savetxt( local_path, labeled_data if need_label_flg else shingled_data, delimiter=',' ) !aws s3 cp $local_path $data_s3_path
学習
SageMakerの組み込みアルゴリズムであるランダムカットフォレストのモデルを学習させます。
import sagemaker import boto3 from sagemaker.estimator import Estimator from sagemaker.amazon.amazon_estimator import get_image_uri execution_role = sagemaker.get_execution_role() model_artifact_path = 's3://bucket-name/model_artifact_path' base_job_name = 'job_name' # ランダムカットフォレスト用のコンテナイメージ training_image = get_image_uri(boto3.Session().region_name, 'randomcutforest') # 学習用処理の設定 rcf = Estimator( role=execution_role, train_instance_count=1, train_instance_type='ml.m4.xlarge', output_path=model_artifact_path, base_job_name=base_job_name, image_name=training_image ) # ハイパーパラメータの設定 hyperparameters = dict( num_samples_per_tree=256, num_trees=100, feature_dim=shingle_size ) rcf.set_hyperparameters(**hyperparameters) # 教師データ train_s3_data = sagemaker.s3_input( s3_data = train_data_s3_path, content_type = 'text/csv;label_size=0', distribution = 'ShardedByS3Key' ) # テストデータ test_s3_data = sagemaker.s3_input( s3_data = test_data_s3_path, content_type = 'text/csv;label_size=1', distribution = 'FullyReplicated' ) # 学習開始 rcf.fit({'train': train_s3_data, 'test': test_s3_data}, wait=True)
モデルの学習は数分で完了します。
モデルの確認
先ほど学習させた異常検出モデルの確認を行います。
- データをバッチ変換
- 変換結果を可視化
先ほどの学習ジョブ名やテストデータの場所などを確認しておきます。
training_job_name = rcf.latest_training_job.name labeled_test_data_s3_path = test_s3_data output_data_s3_path = 's3://bucket-name/machine_temperature_iot/transform'
バッチ変換用にデータ形式を変更
学習ジョブのテストデータに使っていたものをバッチ変換の入力として使える形式に変換し、S3にアップロードします。
import pandas as pd from os import path transform_input_data_local_path = '/tmp/test_transform.csv' pd.read_csv(labeled_test_data_s3_path, header=None)\ .drop(columns=0)\ .to_csv(transform_input_data_local_path, header=None, index=None) transform_input_data_s3_path = path.join(path.dirname(labeled_test_data_s3_path), 'test_transform.csv') !aws s3 cp $transform_input_data_local_path $transform_input_data_s3_path
バッチ変換
先ほど作成した入力データをバッチ変換するジョブを実行します。
from sagemaker.estimator import Estimator model = Estimator.attach(training_job_name=training_job_name) transformer = model.transformer( instance_count=1, instance_type='ml.m4.xlarge', strategy='MultiRecord', assemble_with='Line', output_path=output_data_s3_path ) transformer.transform( transform_input_data_s3_path, content_type='text/csv', split_type='Line' ) transformer.wait()
バッチ変換結果の取得
temp_path = '/tmp/transform_output/' !aws s3 sync $transformer.output_path $temp_path # バッチ変換の結果ファイルは入力ファイル名に.outが付いている output_path = path.join(temp_path, path.basename(transform_input_data_s3_path))+'.out' !head $output_path
異常度スコアを取り出して、リストを作成します。
import json with open(output_path) as f: lines = f.readlines() scores = list(map(lambda l : json.loads(l)['score'], lines))
温度データと異常度スコア、異常値かどうかのラベル、閾値用のカラムを持つデータフレームを作成します。
import pandas as pd df = pd.read_csv(labeled_test_data_s3_path, header=None) # 異常度スコアを入れる df_format = pd.DataFrame(data={ 'is_anomaly': df.iloc[:, 0], # 異常値かどうか 'value': df.iloc[:, shingle_size], # 温度の値 'score': scores, # 異常度スコア 'anomaly_threshold': [0]*len(df) # 異常度の閾値(この後計算する) })
異常度スコアの閾値の計算
異常度スコアの平均から1標準偏差以上離れた値を異常度として扱うことにします。
score_mean = df_format.score.mean() score_std = df_format.score.std() score_cutoff = score_mean + 1 * score_std df_format['anomaly_threshold'] = pd.Series([score_cutoff]*len(df_format), df_format.index)
結果を可視化
異常度スコアと温度データをグラフにプロットしてみます。
import numpy as np import matplotlib.pyplot as plt import seaborn as sns # pyplotで描画する図を綺麗にする sns.set() # 図がインラインで描画されるようにする(JupyterLabでは不要) %matplotlib inline fig, ax1 = plt.subplots() ax2 = ax1.twinx() ax1.plot(df_format.value, color='C0', alpha=0.7) ax2.plot(df_format.score, color='C1', alpha=0.7) # 異常値のラベルデータ anomalies_true = df_format[df_format.is_anomaly == 1] ax1.plot(anomalies_true.value, 'ko' ) # 推論による異常値 anomalies_infer = df_format.score[df_format.score >= score_cutoff] ax2.plot(anomalies_infer, 'ro' ) ax2.plot(df_format.anomaly_threshold, 'r', alpha=0.5) ax1.grid(which='major', axis='both') ax1.set_ylabel('Machine Temperature', color='C0') ax2.set_ylabel('Anomaly Score', color='C1') ax1.tick_params('y', colors='C0') ax2.tick_params('y', colors='C1') ax1.set_ylim(0, max(df_format.value)) ax2.set_ylim(min(df_format.score), 1.5*max(max(df_format.score), score_cutoff)+1) fig.set_figwidth(12) plt.show()
黒丸が異常値で赤丸が推論によって異常値と判定した点です。今回は異常度スコアの閾値を非常に低く設定したため、異常値でないポイントも多く異常値だと判定しています。
AWS IoT Analytics関連リソースの削除
作業が終わったので、不要になったIoT Analytics関連のリソースを削除します。
import boto3 iota_client = boto3.client('iotanalytics') # データセットの削除 for dataset_name in [train_dataset_name, test_dataset_name]: response = iota_client.delete_dataset( datasetName=dataset_name ) print(response) # パイプラインの削除 response = iota_client.delete_pipeline( pipelineName=pipeline_name ) print(response) # データストアの削除 response = iota_client.delete_datastore( datastoreName=datastore_name ) print(response) # チャネルの削除 response = iota_client.delete_channel( channelName=channel_name ) print(response)
さいごに
今回は 「クラスメソッド Amazon SageMaker Advent Calendar」 の25日目として、 Amazon SageMakerで異常検出モデルをAWS IoT Analyticsのデータセットで学習させてみた内容についてお伝えしました。IoT Analyticsを使うことでより簡単にデバイスデータの蓄積と処理が行え、分析へ繋げやすくなります。そういったところからSageMakerとIoT Analyticsを連携させることで、デバイスデータの蓄積から分析が楽になりそうです。
データインテグレーション部(DI部)の機械学習チームでお送りしてきました、「クラスメソッド Amazon SageMaker Advent Calendar 2018」も今日で終わりです。SageMakerを触ろうとしている方、すでに触っている方々の参考に少しでもなれば幸いです。 お読みくださりありがとうございました〜!Merry Christmas!
参考
- AWS IoT Analyticsをマネジメントコンソールから触る内容