SageMakerの新しいSDKとしてSageMaker Coreが発表されました。
こんちには。
データ事業本部 インテグレーション部 機械学習チームの中村( @nokomoro3 )です。
SageMakerの新しいSDKとしてSageMaker Coreが発表されました。
Pydanticを用いたモデルが多く使われており、パラメータ設定時のIDEの補完が聞きやすく、SageMakerのAPIとも1対1となり関連付けが分かりやすくなっているようです。
今回はこのSageMaker Coreと従来のSageMaker SDKを比較して、どのように変更されているかを確認していきたいと思います。
それぞれのドキュメントは以下をご確認ください。
- SageMaker SDK
- SageMaker Core
比較に使用する題材
機械学習のタスクとしては、XGBoostでirisデータセットを分類するモデルを使用します。
- SageMaker SDKの例
- SageMaker Coreの例
機械学習プロセスとしては以下を一通り行います。
- トレーニングジョブによるモデル学習
- モデル学習でできたものをエンドポイントにデプロイ
使用環境
実行環境は以下で行いました。
- WSL2 (Ubuntu 24.04 LTS) / Windows 10
- Python 3.12.3 / uv 0.4.15
Pythonパッケージは以下で行います(特にpydanticはバージョンが異なるとSageMaker Coreが動かないケースがありましたのでお気を付けください)。
sagemaker==2.230.0
sagemaker-core==1.0.3
pydantic==2.9.0
pydantic-core==2.23.2
scikit-learn==1.5.1
比較してみた
実行ロールの準備
あらかじめSageMakerの実行ロールを作成しておく必要があります。
信頼ポリシーは以下のようになります。
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Principal": {
"Service": "sagemaker.amazonaws.com"
},
"Action": "sts:AssumeRole"
}
]
}
アタッチが必要なIAMポリシーは以下となります。
arn:aws:iam::aws:policy/AmazonSageMakerFullAccess
初期設定
最初はsessionや実行ロール、バケット、コンテナイメージの取得する処理です。
- SageMaker SDK
import sagemaker
from sagemaker.session import Session, get_execution_role
boto_session = boto3.Session()
sagemaker_session = Session(boto_session=boto_session)
# role = get_execution_role(sagemaker_session=sagemaker_session) # SageMaker Studio内ではこちらでもOK
role = "{実行ロールのARN}"
print(role)
bucket = sagemaker_session.default_bucket()
print(bucket)
image = sagemaker.image_uris.retrieve(framework="xgboost", region="ap-northeast-1", version="latest")
print(image)
- SageMaker Core
import sagemaker
from sagemaker_core.helper.session_helper import Session, get_execution_role
boto_session = boto3.Session()
sagemaker_session = Session(boto_session=boto_session)
# role = get_execution_role(sagemaker_session=sagemaker_session) # SageMaker Studio内ではこちらでもOK
role = "{実行ロールのARN}"
print(role)
bucket = sagemaker_session.default_bucket()
print(bucket)
image = sagemaker.image_uris.retrieve(framework="xgboost", region="ap-northeast-1", version="latest")
print(image)
ほぼ違いはありませんが、SageMaker Coreには image_uris.retrieve
がまだないため、SageMaker SDKの関数を使用する必要がありました。この辺りは、SageMaker Coreの今後のアップデートにより改善が考えられるかもしれません。
また、今回はローカル環境(WSL2)から実施していますので、 get_execution_role
を使うことができないため直接実行ロールのARNを指定しています。
データセットの準備
こちらは全く同じ処理を用います。
scikit-learnからデータセットを取得し、trainとtestに分割してcsvを作成した後、S3にアップロードまでを行います。
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import pandas as pd
import os
# Get IRIS Data
iris = load_iris()
iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)
iris_df["target"] = iris.target
# Prepare Data
os.makedirs("./data", exist_ok=True)
iris_df = iris_df[["target"] + [col for col in iris_df.columns if col != "target"]]
train_data, test_data = train_test_split(iris_df, test_size=0.2, random_state=42)
train_data.to_csv("./data/train.csv", index=False, header=False)
# Upload Data
prefix = "DEMO-scikit-iris"
TRAIN_DATA = "train.csv"
DATA_DIRECTORY = "data"
train_input = sagemaker_session.upload_data(
DATA_DIRECTORY, bucket=bucket, key_prefix="{}/{}".format(prefix, DATA_DIRECTORY)
)
s3_input_path = "s3://{}/{}/data/{}".format(bucket, prefix, TRAIN_DATA)
s3_output_path = "s3://{}/{}/output".format(bucket, prefix)
print(s3_input_path)
print(s3_output_path)
また、推論用に正解カラムをドロップしたデータも作成しておきます。
test_data.to_csv("./data/test.csv", index=False, header=False)
# Remove the target column from the testing data. We will use this to call invoke_endpoint later
test_data_no_target = test_data.drop("target", axis=1)
トレーニングジョブの作成
ここからがいよいよ比較の本番です。まずはトレーニングジョブの比較になります。
- SageMaker SDK
from sagemaker.inputs import TrainingInput
estimator = sagemaker.estimator.Estimator(
sagemaker_session=sagemaker_session,
base_job_name="xgboost-iris",
image_uri=image,
role=role,
instance_count=1,
instance_type='ml.m4.xlarge',
volume_size=30,
output_path=s3_output_path,
max_run=600)
estimator.set_hyperparameters(
objective="multi:softmax",
num_class=3,
num_round=10,
eval_metric="merror")
train_input = TrainingInput(s3_input_path, content_type="csv")
estimator.fit({'train': train_input})
SageMaker SDKはEstimatorというクラスを中心に処理されます。様々なものが同じ関数の引数として与えられているため、何の目的のパラメータなのか推測が難しくなっているかと思います。
またfitの引数はdict型となっており、何が設定可能なのかを調べる必要があります。fitを実行すると、トレーニングジョブが作成される、というのも経験がないと推測が難しくなっています。
トレーニングジョブ名も base_job_name
を元に自動で付与されるため、少し制御が難しい印象があります。
- SageMaker Core
import time
from sagemaker_core.resources import (
TrainingJob,
AlgorithmSpecification,
Channel,
DataSource,
S3DataSource,
OutputDataConfig,
ResourceConfig,
StoppingCondition,
)
job_name = "xgboost-iris-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
training_job = TrainingJob.create(
training_job_name=job_name,
hyper_parameters={
"objective": "multi:softmax",
"num_class": "3",
"num_round": "10",
"eval_metric": "merror",
},
algorithm_specification=AlgorithmSpecification(
training_image=image, training_input_mode="File"
),
role_arn=role,
input_data_config=[
Channel(
channel_name="train",
content_type="csv",
compression_type="None",
record_wrapper_type="None",
data_source=DataSource(
s3_data_source=S3DataSource(
s3_data_type="S3Prefix",
s3_uri=s3_input_path,
s3_data_distribution_type="FullyReplicated",
)
),
)
],
output_data_config=OutputDataConfig(s3_output_path=s3_output_path),
resource_config=ResourceConfig(
instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=30
),
stopping_condition=StoppingCondition(max_runtime_in_seconds=600),
)
# Wait for TrainingJob with SageMakerCore
training_job.wait()
SageMaker Coreの場合、関数実行の記載としては少し長くなってしまいますが、Pydanticモデル経由で引数を渡すため、どのパラメータが何のためのものか分かりやすくなっています。
トレーニングジョブのcreateを実行すると、トレーニングジョブが作成される、ということも明確になっています。
トレーニングジョブ名も自分自身で任意に指定が可能となっているようです。
またcreate自体は非同期となっているため、処理待ち用の関数も準備されています。(SageMaker SDKでも引数に wait=False
を指定すれば非同期実行が可能ですが、その後待ち合わせする関数がありませんでした)
エンドポイントの作成
次はエンドポイント作成の比較です。
- SageMaker SDK
SageMaker SDKはかなりシンプルなコードで記述量が少ないですが、裏側ではCreateModel、CreateEndpointConfig、CreateEndopointがそれぞれ自動で実行されます。
より上位レイヤのAPIといったイメージでコード量は少なく使いやすいケースもあると思いますが、裏側のリソースを意識して使用する必要があります。
また一部のリソースを流用したい場合などの柔軟なケースでは、使用が難しくなりそうです。
from sagemaker.model import Model
from time import gmtime, strftime
model_data_url = estimator.model_data
key = f'xgboost-iris-{strftime("%H-%M-%S", gmtime())}'
print("Key:", key)
model = Model(name=key, image_uri=image, model_data=model_data_url, role=role)
model.deploy(
initial_instance_count=1,
instance_type="ml.m5.xlarge",
endpoint_name=key
)
- SageMaker Core
SageMaker Coreの場合は、各リソースごとに独立したクラスに分かれており、それぞれのcreateメソッドを呼ぶ形となっています。
透過性が高いAPIとなっており、各クラスを流用することも簡単にできるようになっています。その代わりにコード量は少し増える傾向がありそうです。
from sagemaker_core.shapes import ContainerDefinition, ProductionVariant
from sagemaker_core.resources import Model, EndpointConfig, Endpoint
from time import gmtime, strftime
# Get model_data_url from training_job object
model_data_url = training_job.model_artifacts.s3_model_artifacts
key = f'xgboost-iris-{strftime("%H-%M-%S", gmtime())}'
print("Key:", key)
model = Model.create(
model_name=key,
primary_container=ContainerDefinition(
image=image,
model_data_url=model_data_url,
),
execution_role_arn=role,
)
endpoint_config = EndpointConfig.create(
endpoint_config_name=key,
production_variants=[
ProductionVariant(
variant_name=key,
initial_instance_count=1,
instance_type="ml.m5.xlarge",
model_name=model, # Pass `Model`` object created above
)
],
)
endpoint: Endpoint = Endpoint.create(
endpoint_name=key,
endpoint_config_name=endpoint_config, # Pass `EndpointConfig` object created above
)
endpoint.wait_for_status("InService")
エンドポイントでの推論
次に作成したエンドポイントで推論を実施する部分を比較します。
- SageMaker SDK
SageMaker SDKでは、Predictorというクラスを使う必要があります。
エンドポイント作成後にModelクラスからエンドポイント名を受け取ってPredictorクラスを作成する必要がありますので、リソースと直感的に紐づけにくく、仕様を理解しながら使う必要がありました。
シリアライザとデータフレームの関係も少し直接的ではないため分かりにくいかと思います。
from sagemaker.predictor import Predictor
from sagemaker.serializers import CSVSerializer
predictor = Predictor(endpoint_name=model.endpoint_name, serializer=CSVSerializer())
response = predictor.predict(test_data_no_target.values)
print("Endpoint Response:", response.decode("utf-8"))
- SageMaker Core
SageMaker Coreでは、createの戻り値として得られるendpointにinvokeする形で推論を呼ぶことができ、エンドポイントに対して推論リクエストを投げることと直感的に紐づけやすいです。
またシリアライザも推論処理とは分けられているため、関係性も分かりやすくなっているかと思います。
from sagemaker.base_serializers import CSVSerializer
from sagemaker.deserializers import CSVDeserializer
deserializer = CSVDeserializer()
serializer = CSVSerializer()
invoke_result = endpoint.invoke(
body=serializer.serialize(test_data_no_target),
content_type="text/csv",
accept="text/csv",
)
deserialized_result = deserializer.deserialize(invoke_result["Body"], invoke_result["ContentType"])
print("Endpoint Response:", deserialized_result)
エンドポイントの削除
最後にエンドポイントの削除をしておきます。ここはコード的には大きな差分はなさそうです。
- SageMaker SDK
predictor.delete_endpoint()
model.delete_model()
- SageMaker Core
endpoint.delete()
endpoint.wait_for_delete()
endpoint_config.delete()
model.delete()
まとめ
いかがでしたでしょうか。元々のSageMaker SDKは慣れるために経験が必要な部分が多かったのですが、SageMaker Coreの登場によってかなり使いやすくなったと感じました。
まだ発表されたばかりですので、今後の機能拡充などにも期待したいと思います。
本記事がSageMakerをお使いになられる方の参考になれば幸いです。