Amazon MLのリアルタイムAPIをJavaのプログラムから呼び出す

2016.12.12

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

事前準備

まずは機械学習で使うデータを登録します。今回はチュートリアルで使われている banking.csv を使って

  • Datasource
  • ML model
  • Evaluation

を作成します。作成方法は以下のページが参考になります。

Amazon Machine Learningを試してみた

次にリアルタイムAPIを有効にする必要があります。これはManagement Consoleで設定します。 作成したML model summaryのページの下の方にPredictionsという項目がありますが、 Enable real-time predictions のCreate endpointボタンを押下します。

amazon-ml-realtime-java

実装する

準備ができたので実装します。まずは以下のページからAWS SDK for Javaをダウンロードしてください。

AWS SDK for Java

ダウンロードしたjarファイルにビルドパスを通したら以下のように書いてみてください。 アクセスキーとシークレットキー、ML ModelのID、パラメータの箇所は書き換えてください。ハイライトの行です。

import java.util.HashMap;
import java.util.Map;

import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.services.machinelearning.AmazonMachineLearningClient;
import com.amazonaws.services.machinelearning.model.GetMLModelRequest;
import com.amazonaws.services.machinelearning.model.GetMLModelResult;
import com.amazonaws.services.machinelearning.model.PredictRequest;
import com.amazonaws.services.machinelearning.model.PredictResult;

public class Main {

	public static void main(String[] args) {

		String accessKey = "[Access Key]";
		String secretKey = "[Secret Key]";
		AWSCredentials credentials = new BasicAWSCredentials(accessKey, secretKey);
		AmazonMachineLearningClient client = new AmazonMachineLearningClient(credentials);

		String modelId = "[ML Model ID]";
		GetMLModelRequest modelRequest = new GetMLModelRequest().withMLModelId(modelId);
		GetMLModelResult model = client.getMLModel(modelRequest);
		String predictEndpoint = model.getEndpointInfo().getEndpointUrl();
		
		Map<String, String> input = new HashMap<String, String>();
		input.put("age", "38");
		input.put("job", "technician");
		input.put("marital", "married");
		input.put("education", "high.school");

		PredictRequest predictRequest = new PredictRequest()
		   .withMLModelId(modelId)
		   .withPredictEndpoint(predictEndpoint)
		   .withRecord(input);

		PredictResult prediction = client.predict(predictRequest);
		Map<String, Float> scores = prediction.getPrediction().getPredictedScores();
		System.out.println(scores.get("0"));
	}
}

実行したところ以下のようにスコアが表示されました。 この例ではinputというMap型のオブジェクトにパラメータを設定しています。パラメータを変えるとスコアも変わるので試してみてください。

0.37839836