Amazon MLのリアルタイムAPIをJavaのプログラムから呼び出す
事前準備
まずは機械学習で使うデータを登録します。今回はチュートリアルで使われている banking.csv を使って
- Datasource
- ML model
- Evaluation
を作成します。作成方法は以下のページが参考になります。
次にリアルタイムAPIを有効にする必要があります。これはManagement Consoleで設定します。 作成したML model summaryのページの下の方にPredictionsという項目がありますが、 Enable real-time predictions のCreate endpointボタンを押下します。
実装する
準備ができたので実装します。まずは以下のページからAWS SDK for Javaをダウンロードしてください。
ダウンロードしたjarファイルにビルドパスを通したら以下のように書いてみてください。 アクセスキーとシークレットキー、ML ModelのID、パラメータの箇所は書き換えてください。ハイライトの行です。
import java.util.HashMap; import java.util.Map; import com.amazonaws.auth.AWSCredentials; import com.amazonaws.auth.BasicAWSCredentials; import com.amazonaws.services.machinelearning.AmazonMachineLearningClient; import com.amazonaws.services.machinelearning.model.GetMLModelRequest; import com.amazonaws.services.machinelearning.model.GetMLModelResult; import com.amazonaws.services.machinelearning.model.PredictRequest; import com.amazonaws.services.machinelearning.model.PredictResult; public class Main { public static void main(String[] args) { String accessKey = "[Access Key]"; String secretKey = "[Secret Key]"; AWSCredentials credentials = new BasicAWSCredentials(accessKey, secretKey); AmazonMachineLearningClient client = new AmazonMachineLearningClient(credentials); String modelId = "[ML Model ID]"; GetMLModelRequest modelRequest = new GetMLModelRequest().withMLModelId(modelId); GetMLModelResult model = client.getMLModel(modelRequest); String predictEndpoint = model.getEndpointInfo().getEndpointUrl(); Map<String, String> input = new HashMap<String, String>(); input.put("age", "38"); input.put("job", "technician"); input.put("marital", "married"); input.put("education", "high.school"); PredictRequest predictRequest = new PredictRequest() .withMLModelId(modelId) .withPredictEndpoint(predictEndpoint) .withRecord(input); PredictResult prediction = client.predict(predictRequest); Map<String, Float> scores = prediction.getPrediction().getPredictedScores(); System.out.println(scores.get("0")); } }
実行したところ以下のようにスコアが表示されました。 この例ではinputというMap型のオブジェクトにパラメータを設定しています。パラメータを変えるとスコアも変わるので試してみてください。
0.37839836