Amazon MLのリアルタイムAPIをJavaのプログラムから呼び出す

Amazon Machine Learning

事前準備

まずは機械学習で使うデータを登録します。今回はチュートリアルで使われている banking.csv を使って

  • Datasource
  • ML model
  • Evaluation

を作成します。作成方法は以下のページが参考になります。

Amazon Machine Learningを試してみた

次にリアルタイムAPIを有効にする必要があります。これはManagement Consoleで設定します。 作成したML model summaryのページの下の方にPredictionsという項目がありますが、 Enable real-time predictions のCreate endpointボタンを押下します。

amazon-ml-realtime-java

実装する

準備ができたので実装します。まずは以下のページからAWS SDK for Javaをダウンロードしてください。

AWS SDK for Java

ダウンロードしたjarファイルにビルドパスを通したら以下のように書いてみてください。 アクセスキーとシークレットキー、ML ModelのID、パラメータの箇所は書き換えてください。ハイライトの行です。

import java.util.HashMap;
import java.util.Map;

import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.services.machinelearning.AmazonMachineLearningClient;
import com.amazonaws.services.machinelearning.model.GetMLModelRequest;
import com.amazonaws.services.machinelearning.model.GetMLModelResult;
import com.amazonaws.services.machinelearning.model.PredictRequest;
import com.amazonaws.services.machinelearning.model.PredictResult;

public class Main {

	public static void main(String[] args) {

		String accessKey = "[Access Key]";
		String secretKey = "[Secret Key]";
		AWSCredentials credentials = new BasicAWSCredentials(accessKey, secretKey);
		AmazonMachineLearningClient client = new AmazonMachineLearningClient(credentials);

		String modelId = "[ML Model ID]";
		GetMLModelRequest modelRequest = new GetMLModelRequest().withMLModelId(modelId);
		GetMLModelResult model = client.getMLModel(modelRequest);
		String predictEndpoint = model.getEndpointInfo().getEndpointUrl();
		
		Map<String, String> input = new HashMap<String, String>();
		input.put("age", "38");
		input.put("job", "technician");
		input.put("marital", "married");
		input.put("education", "high.school");

		PredictRequest predictRequest = new PredictRequest()
		   .withMLModelId(modelId)
		   .withPredictEndpoint(predictEndpoint)
		   .withRecord(input);

		PredictResult prediction = client.predict(predictRequest);
		Map<String, Float> scores = prediction.getPrediction().getPredictedScores();
		System.out.println(scores.get("0"));
	}
}

実行したところ以下のようにスコアが表示されました。 この例ではinputというMap型のオブジェクトにパラメータを設定しています。パラメータを変えるとスコアも変わるので試してみてください。

0.37839836