Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions sagemaker-serve/src/sagemaker/serve/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2508,7 +2508,7 @@ def _build_single_modelbuilder(
containers=[container_def],
enable_network_isolation=True,
tags=[
{"key": "sagemaker-studio:jumpstart-model-id",
{"key": "sagemaker-sdk:jumpstart-model-id",
"value": base_model.hub_content_name},
],
)
Expand Down Expand Up @@ -4848,10 +4848,10 @@ def _deploy_nova_model(
)

tags = [
{"key": "sagemaker-studio:jumpstart-model-id", "value": base_model.hub_content_name},
{"key": "sagemaker-sdk:jumpstart-model-id", "value": base_model.hub_content_name},
]
if base_model.recipe_name:
tags.append({"key": "sagemaker-studio:recipe-name", "value": base_model.recipe_name})
tags.append({"key": "sagemaker-sdk:recipe-name", "value": base_model.recipe_name})

endpoint = Endpoint.create(
endpoint_name=endpoint_name,
Expand Down
4 changes: 2 additions & 2 deletions sagemaker-train/src/sagemaker/train/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sagemaker.core.resources import TrainingJob, ModelPackageGroup, ModelPackage
from sagemaker.core.shapes import VpcConfig
from sagemaker.train.defaults import TrainDefaults
from sagemaker.train.utils import _get_unique_name, _get_studio_tags
from sagemaker.train.utils import _get_unique_name, _get_jumpstart_tags
from sagemaker.train.configs import StoppingCondition
from sagemaker.train.common_utils.finetune_utils import (
_get_fine_tuning_options_and_model_arn,
Expand Down Expand Up @@ -251,7 +251,7 @@ def train(self,
)

vpc_config = self.networking if self.networking else None
tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name())
tags = _get_jumpstart_tags(self._model_name, get_sagemaker_hub_name())

# Build TrainingJob.create() arguments
create_args = {
Expand Down
4 changes: 2 additions & 2 deletions sagemaker-train/src/sagemaker/train/multi_turn_rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from sagemaker.train.common_utils.recipe_utils import _list_hub_models_by_recipe, _is_nova_model
from sagemaker.train.constants import get_sagemaker_hub_name
from sagemaker.train.defaults import TrainDefaults
from sagemaker.train.utils import _get_unique_name, _get_studio_tags
from sagemaker.train.utils import _get_unique_name, _get_jumpstart_tags

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -279,7 +279,7 @@ def train(
self.training_dataset = training_dataset
job_config_doc = self._build_job_config_document()

tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name())
tags = _get_jumpstart_tags(self._model_name, get_sagemaker_hub_name())

try:
job = Job.create(
Expand Down
4 changes: 2 additions & 2 deletions sagemaker-train/src/sagemaker/train/rlaif_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sagemaker.core.resources import TrainingJob, ModelPackageGroup, MlflowTrackingServer, ModelPackage
from sagemaker.core.shapes import VpcConfig
from sagemaker.train.defaults import TrainDefaults
from sagemaker.train.utils import _get_unique_name, _get_studio_tags
from sagemaker.train.utils import _get_unique_name, _get_jumpstart_tags
from sagemaker.train.common_utils.recipe_utils import _get_hub_content_metadata
from sagemaker.ai_registry.dataset import DataSet
from sagemaker.ai_registry.evaluator import Evaluator
Expand Down Expand Up @@ -268,7 +268,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
)

vpc_config = self.networking if self.networking else None
tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name())
tags = _get_jumpstart_tags(self._model_name, get_sagemaker_hub_name())

# Build TrainingJob.create() arguments
create_args = {
Expand Down
4 changes: 2 additions & 2 deletions sagemaker-train/src/sagemaker/train/rlvr_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sagemaker.core.resources import TrainingJob, ModelPackageGroup, MlflowTrackingServer, ModelPackage
from sagemaker.core.shapes import VpcConfig
from sagemaker.train.defaults import TrainDefaults
from sagemaker.train.utils import _get_unique_name, _get_studio_tags
from sagemaker.train.utils import _get_unique_name, _get_jumpstart_tags
from sagemaker.ai_registry.dataset import DataSet
from sagemaker.ai_registry.evaluator import Evaluator
from sagemaker.train.configs import StoppingCondition
Expand Down Expand Up @@ -259,7 +259,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
)

vpc_config = self.networking if self.networking else None
tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name())
tags = _get_jumpstart_tags(self._model_name, get_sagemaker_hub_name())

# Build TrainingJob.create() arguments
create_args = {
Expand Down
4 changes: 2 additions & 2 deletions sagemaker-train/src/sagemaker/train/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sagemaker.core.resources import TrainingJob, ModelPackageGroup, ModelPackage
from sagemaker.core.shapes import VpcConfig
from sagemaker.train.defaults import TrainDefaults
from sagemaker.train.utils import _get_unique_name, _get_studio_tags
from sagemaker.train.utils import _get_unique_name, _get_jumpstart_tags
from sagemaker.ai_registry.dataset import DataSet
from sagemaker.train.configs import StoppingCondition
from sagemaker.train.common_utils.finetune_utils import (
Expand Down Expand Up @@ -250,7 +250,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
)

vpc_config = self.networking if self.networking else None
tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name())
tags = _get_jumpstart_tags(self._model_name, get_sagemaker_hub_name())

# Build TrainingJob.create() arguments
create_args = {
Expand Down
6 changes: 3 additions & 3 deletions sagemaker-train/src/sagemaker/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,14 +240,14 @@ def _run_clone_command_silent(repo_url, dest_dir):
logger.error(f"Error output:\n{e}")
raise

def _get_studio_tags(model_id: str, hub_name: str):
def _get_jumpstart_tags(model_id: str, hub_name: str):
return [
{
"key": "sagemaker-studio:jumpstart-model-id",
"key": "sagemaker-sdk:jumpstart-model-id",
"value": model_id
},
{
"key": "sagemaker-studio:jumpstart-hub-name",
"key": "sagemaker-sdk:jumpstart-hub-name",
"value": hub_name
}
]
Expand Down
4 changes: 2 additions & 2 deletions sagemaker-train/tests/unit/train/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,8 @@ def test_train_with_tags(self, mock_training_job_create, mock_model_package_conf
mock_training_job_create.assert_called_once()
call_kwargs = mock_training_job_create.call_args[1]
assert call_kwargs["tags"] == [
{"key": "sagemaker-studio:jumpstart-model-id", "value": "test-model"},
{"key": "sagemaker-studio:jumpstart-hub-name", "value": "SageMakerPublicHub"}
{"key": "sagemaker-sdk:jumpstart-model-id", "value": "test-model"},
{"key": "sagemaker-sdk:jumpstart-hub-name", "value": "SageMakerPublicHub"}
]

@patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group')
Expand Down
4 changes: 2 additions & 2 deletions sagemaker-train/tests/unit/train/test_rlaif_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,8 @@ def test_train_with_tags(self, mock_training_job_create, mock_model_package_conf
mock_training_job_create.assert_called_once()
call_kwargs = mock_training_job_create.call_args[1]
assert call_kwargs["tags"] == [
{"key": "sagemaker-studio:jumpstart-model-id", "value": "test-model"},
{"key": "sagemaker-studio:jumpstart-hub-name", "value": "SageMakerPublicHub"}
{"key": "sagemaker-sdk:jumpstart-model-id", "value": "test-model"},
{"key": "sagemaker-sdk:jumpstart-hub-name", "value": "SageMakerPublicHub"}
]

@patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group')
Expand Down
4 changes: 2 additions & 2 deletions sagemaker-train/tests/unit/train/test_rlvr_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,8 @@ def test_train_with_tags(self, mock_training_job_create, mock_model_package_conf
mock_training_job_create.assert_called_once()
call_kwargs = mock_training_job_create.call_args[1]
assert call_kwargs["tags"] == [
{"key": "sagemaker-studio:jumpstart-model-id", "value": "test-model"},
{"key": "sagemaker-studio:jumpstart-hub-name", "value": "SageMakerPublicHub"}
{"key": "sagemaker-sdk:jumpstart-model-id", "value": "test-model"},
{"key": "sagemaker-sdk:jumpstart-hub-name", "value": "SageMakerPublicHub"}
]

@patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group')
Expand Down
4 changes: 2 additions & 2 deletions sagemaker-train/tests/unit/train/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,8 @@ def test_train_with_tags(self, mock_training_job_create, mock_model_package_conf
mock_training_job_create.assert_called_once()
call_kwargs = mock_training_job_create.call_args[1]
assert call_kwargs["tags"] == [
{"key": "sagemaker-studio:jumpstart-model-id", "value": "test-model"},
{"key": "sagemaker-studio:jumpstart-hub-name", "value": "SageMakerPublicHub"}
{"key": "sagemaker-sdk:jumpstart-model-id", "value": "test-model"},
{"key": "sagemaker-sdk:jumpstart-hub-name", "value": "SageMakerPublicHub"}
]

def test_process_hyperparameters_removes_constructor_handled_keys(self):
Expand Down
Loading