# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
# pylint: disable=protected-access
import logging
from typing import Any, Dict, Optional, cast
from typing_extensions import Literal
from azure.ai.ml._restclient.v2023_04_01_preview.models import AllNodes
from azure.ai.ml._restclient.v2023_04_01_preview.models import JobService as RestJobService
from azure.ai.ml.constants._job.job import JobServiceTypeNames
from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
# Module-level logger, named after this module via ``__name__``.
module_logger = logging.getLogger(__name__)
class JobServiceBase(RestTranslatableMixin, DictMixin):
    """Base class for job service configuration.

    This class should not be instantiated directly. Instead, use one of its subclasses.

    :keyword endpoint: The endpoint URL.
    :paramtype endpoint: Optional[str]
    :keyword type: The endpoint type. Accepted values are "jupyter_lab", "ssh", "tensor_board", and "vs_code".
    :paramtype type: Optional[Literal["jupyter_lab", "ssh", "tensor_board", "vs_code"]]
    :keyword port: The port for the endpoint.
    :paramtype port: Optional[int]
    :keyword nodes: Indicates whether the service has to run in all nodes.
    :paramtype nodes: Optional[Literal["all"]]
    :keyword properties: Additional properties to set on the endpoint.
    :paramtype properties: Optional[dict[str, str]]
    :keyword status: The status of the endpoint.
    :paramtype status: Optional[str]
    :keyword kwargs: A dictionary of additional configuration parameters.
    :paramtype kwargs: dict
    :raises ~azure.ai.ml.exceptions.ValidationException: Raised if ``nodes`` is neither "all" nor None, or if
        ``type`` is set to a value that is not a known job service type.
    """

    def __init__(  # pylint: disable=unused-argument
        self,
        *,
        endpoint: Optional[str] = None,
        type: Optional[  # pylint: disable=redefined-builtin
            Literal["jupyter_lab", "ssh", "tensor_board", "vs_code"]
        ] = None,
        nodes: Optional[Literal["all"]] = None,
        status: Optional[str] = None,
        port: Optional[int] = None,
        properties: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> None:
        self.endpoint = endpoint
        self.type: Any = type
        self.nodes = nodes
        self.status = status
        self.port = port
        self.properties = properties
        # Validation runs after assignment because both checks read the instance attributes.
        self._validate_nodes()
        self._validate_type_name()

    def _validate_nodes(self) -> None:
        """Ensure ``self.nodes`` is either "all" or None.

        :raises ~azure.ai.ml.exceptions.ValidationException: Raised for any other value.
        """
        if self.nodes not in ("all", None):
            msg = f"nodes should be either 'all' or None, but received '{self.nodes}'."
            raise ValidationException(
                message=msg,
                no_personal_data_message=msg,
                target=ErrorTarget.JOB,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.INVALID_VALUE,
            )

    def _validate_type_name(self) -> None:
        """Ensure ``self.type``, when set, is a key of ``JobServiceTypeNames.ENTITY_TO_REST``.

        A falsy ``type`` (for example None) is accepted without further checks.

        :raises ~azure.ai.ml.exceptions.ValidationException: Raised for an unknown type name.
        """
        if self.type and self.type not in JobServiceTypeNames.ENTITY_TO_REST:
            msg = f"type should be one of {JobServiceTypeNames.NAMES_ALLOWED_FOR_PUBLIC}, but received '{self.type}'."
            raise ValidationException(
                message=msg,
                no_personal_data_message=msg,
                target=ErrorTarget.JOB,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.INVALID_VALUE,
            )

    def _to_rest_job_service(self, updated_properties: Optional[Dict[str, str]] = None) -> RestJobService:
        """Convert this entity into a ``RestJobService``.

        :param updated_properties: Properties to send instead of ``self.properties``. Ignored when None or empty.
        :type updated_properties: Optional[Dict[str, str]]
        :return: The REST representation of this job service.
        :rtype: RestJobService
        """
        return RestJobService(
            endpoint=self.endpoint,
            # The entity type name is translated to its REST counterpart; an unset type stays None.
            job_service_type=JobServiceTypeNames.ENTITY_TO_REST.get(self.type) if self.type else None,
            nodes=AllNodes() if self.nodes else None,
            status=self.status,
            port=self.port,
            properties=updated_properties or self.properties,
        )

    @classmethod
    def _to_rest_job_services(
        cls,
        services: Optional[Dict],
    ) -> Optional[Dict[str, RestJobService]]:
        """Convert a name-to-service mapping into a name-to-``RestJobService`` mapping.

        :param services: Mapping of service name to job service entity, or None.
        :type services: Optional[Dict]
        :return: The converted mapping, or None when ``services`` is None.
        :rtype: Optional[Dict[str, RestJobService]]
        """
        if services is None:
            return None

        return {name: service._to_rest_object() for name, service in services.items()}

    @classmethod
    def _from_rest_job_service_object(cls, obj: RestJobService) -> "JobServiceBase":
        """Instantiate ``cls`` from a ``RestJobService``.

        :param obj: The REST object to convert.
        :type obj: RestJobService
        :return: A new instance of ``cls`` populated from ``obj``.
        :rtype: JobServiceBase
        """
        return cls(
            endpoint=obj.endpoint,
            type=(
                JobServiceTypeNames.REST_TO_ENTITY.get(obj.job_service_type, None)  # type: ignore[arg-type]
                if obj.job_service_type
                else None
            ),
            nodes="all" if obj.nodes else None,
            status=obj.status,
            port=obj.port,
            properties=obj.properties,
        )

    @classmethod
    def _from_rest_job_services(
        cls, services: Optional[Dict[str, RestJobService]]
    ) -> Optional[Dict]:
        """Resolve Dict[str, RestJobService] to Dict[str, Specific JobService].

        :param services: Mapping of service name to REST job service, or None.
        :type services: Optional[Dict[str, RestJobService]]
        :return: Mapping of service name to the matching job service entity, or None when ``services`` is None.
            Services whose REST type matches none of the specific classes are converted to ``JobService``.
        :rtype: Optional[Dict]
        """
        if services is None:
            return None

        rest_names = JobServiceTypeNames.RestNames
        # Ordered (REST type name, entity class) pairs; matched with ``==`` in this order.
        typed_services = (
            (rest_names.JUPYTER_LAB, JupyterLabJobService),
            (rest_names.SSH, SshJobService),
            (rest_names.TENSOR_BOARD, TensorBoardJobService),
            (rest_names.VS_CODE, VsCodeJobService),
        )

        result: dict = {}
        for name, service in services.items():
            service_cls = next(
                (candidate for rest_name, candidate in typed_services if service.job_service_type == rest_name),
                JobService,  # Fallback for any other (or missing) service type.
            )
            result[name] = service_cls._from_rest_object(service)
        return result
[docs]
class JobService(JobServiceBase):
    """Generic job service configuration, kept for backward compatibility.

    This class is not intended to be used directly. Prefer the subclass of the base class that matches the
    kind of service you need.

    :keyword endpoint: The endpoint URL.
    :paramtype endpoint: Optional[str]
    :keyword type: The endpoint type. Accepted values are "jupyter_lab", "ssh", "tensor_board", and "vs_code".
    :paramtype type: Optional[Literal["jupyter_lab", "ssh", "tensor_board", "vs_code"]]
    :keyword port: The port for the endpoint.
    :paramtype port: Optional[int]
    :keyword nodes: Indicates whether the service has to run in all nodes.
    :paramtype nodes: Optional[Literal["all"]]
    :keyword properties: Additional properties to set on the endpoint.
    :paramtype properties: Optional[dict[str, str]]
    :keyword status: The status of the endpoint.
    :paramtype status: Optional[str]
    :keyword kwargs: A dictionary of additional configuration parameters.
    :paramtype kwargs: dict
    """

    @classmethod
    def _from_rest_object(cls, obj: RestJobService) -> "JobService":
        """Create an instance from its REST representation.

        :param obj: The REST object to convert.
        :type obj: RestJobService
        :return: The converted job service.
        :rtype: JobService
        """
        service = cls._from_rest_job_service_object(obj)
        return cast(JobService, service)

    def _to_rest_object(self) -> RestJobService:
        """Convert this job service to its REST representation.

        :return: The REST object built from this instance's attributes.
        :rtype: RestJobService
        """
        rest_service = self._to_rest_job_service()
        return rest_service
[docs]
class SshJobService(JobServiceBase):
    """Job service configuration for SSH access.

    :ivar type: Specifies the type of job service. Set automatically to "ssh" for this class.
    :vartype type: str
    :keyword endpoint: The endpoint URL.
    :paramtype endpoint: Optional[str]
    :keyword port: The port for the endpoint.
    :paramtype port: Optional[int]
    :keyword nodes: Indicates whether the service has to run in all nodes.
    :paramtype nodes: Optional[Literal["all"]]
    :keyword properties: Additional properties to set on the endpoint.
    :paramtype properties: Optional[dict[str, str]]
    :keyword status: The status of the endpoint.
    :paramtype status: Optional[str]
    :keyword ssh_public_keys: The SSH Public Key to access the job container.
    :paramtype ssh_public_keys: Optional[str]
    :keyword kwargs: A dictionary of additional configuration parameters.
    :paramtype kwargs: dict

    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_misc.py
            :start-after: [START ssh_job_service_configuration]
            :end-before: [END ssh_job_service_configuration]
            :language: python
            :dedent: 8
            :caption: Configuring a SshJobService configuration on a command job.
    """

    def __init__(
        self,
        *,
        endpoint: Optional[str] = None,
        nodes: Optional[Literal["all"]] = None,
        status: Optional[str] = None,
        port: Optional[int] = None,
        ssh_public_keys: Optional[str] = None,
        properties: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> None:
        # Extra keyword arguments are forwarded untouched to the base initializer.
        super().__init__(endpoint=endpoint, nodes=nodes, status=status, port=port, properties=properties, **kwargs)
        # The service type is fixed for this class and overrides whatever the base initializer stored.
        self.type = JobServiceTypeNames.EntityNames.SSH
        self.ssh_public_keys = ssh_public_keys

    @classmethod
    def _from_rest_object(cls, obj: RestJobService) -> "SshJobService":
        """Create an instance from its REST representation.

        The SSH public key is read from the "sshPublicKeys" entry of the REST object's properties.

        :param obj: The REST object to convert.
        :type obj: RestJobService
        :return: The converted SSH job service.
        :rtype: SshJobService
        """
        service = cast(SshJobService, cls._from_rest_job_service_object(obj))
        public_keys = _get_property(obj.properties, "sshPublicKeys")
        service.ssh_public_keys = public_keys
        return service

    def _to_rest_object(self) -> RestJobService:
        """Convert this job service to its REST representation.

        The SSH public key is passed along under the "sshPublicKeys" property key.

        :return: The REST object built from this instance's attributes.
        :rtype: RestJobService
        """
        return self._to_rest_job_service(
            _append_or_update_properties(self.properties, "sshPublicKeys", self.ssh_public_keys)
        )
[docs]
class TensorBoardJobService(JobServiceBase):
    """Job service configuration for TensorBoard.

    :ivar type: Specifies the type of job service. Set automatically to "tensor_board" for this class.
    :vartype type: str
    :keyword endpoint: The endpoint URL.
    :paramtype endpoint: Optional[str]
    :keyword port: The port for the endpoint.
    :paramtype port: Optional[int]
    :keyword nodes: Indicates whether the service has to run in all nodes.
    :paramtype nodes: Optional[Literal["all"]]
    :keyword properties: Additional properties to set on the endpoint.
    :paramtype properties: Optional[dict[str, str]]
    :keyword status: The status of the endpoint.
    :paramtype status: Optional[str]
    :keyword log_dir: The directory path for the log file.
    :paramtype log_dir: Optional[str]
    :keyword kwargs: A dictionary of additional configuration parameters.
    :paramtype kwargs: dict

    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_misc.py
            :start-after: [START ssh_job_service_configuration]
            :end-before: [END ssh_job_service_configuration]
            :language: python
            :dedent: 8
            :caption: Configuring TensorBoardJobService configuration on a command job.
    """

    # NOTE(review): the example above includes the SSH sample snippet; confirm whether a
    # TensorBoard-specific sample exists that should be referenced instead.

    def __init__(
        self,
        *,
        endpoint: Optional[str] = None,
        nodes: Optional[Literal["all"]] = None,
        status: Optional[str] = None,
        port: Optional[int] = None,
        log_dir: Optional[str] = None,
        properties: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> None:
        # Extra keyword arguments are forwarded untouched to the base initializer.
        super().__init__(endpoint=endpoint, nodes=nodes, status=status, port=port, properties=properties, **kwargs)
        # The service type is fixed for this class and overrides whatever the base initializer stored.
        self.type = JobServiceTypeNames.EntityNames.TENSOR_BOARD
        self.log_dir = log_dir

    @classmethod
    def _from_rest_object(cls, obj: RestJobService) -> "TensorBoardJobService":
        """Create an instance from its REST representation.

        The log directory is read from the "logDir" entry of the REST object's properties.

        :param obj: The REST object to convert.
        :type obj: RestJobService
        :return: The converted TensorBoard job service.
        :rtype: TensorBoardJobService
        """
        service = cast(TensorBoardJobService, cls._from_rest_job_service_object(obj))
        log_directory = _get_property(obj.properties, "logDir")
        service.log_dir = log_directory
        return service

    def _to_rest_object(self) -> RestJobService:
        """Convert this job service to its REST representation.

        The log directory is passed along under the "logDir" property key.

        :return: The REST object built from this instance's attributes.
        :rtype: RestJobService
        """
        return self._to_rest_job_service(_append_or_update_properties(self.properties, "logDir", self.log_dir))
[docs]
class JupyterLabJobService(JobServiceBase):
    """Job service configuration for JupyterLab.

    :ivar type: Specifies the type of job service. Set automatically to "jupyter_lab" for this class.
    :vartype type: str
    :keyword endpoint: The endpoint URL.
    :paramtype endpoint: Optional[str]
    :keyword port: The port for the endpoint.
    :paramtype port: Optional[int]
    :keyword nodes: Indicates whether the service has to run in all nodes.
    :paramtype nodes: Optional[Literal["all"]]
    :keyword properties: Additional properties to set on the endpoint.
    :paramtype properties: Optional[dict[str, str]]
    :keyword status: The status of the endpoint.
    :paramtype status: Optional[str]
    :keyword kwargs: A dictionary of additional configuration parameters.
    :paramtype kwargs: dict

    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_misc.py
            :start-after: [START ssh_job_service_configuration]
            :end-before: [END ssh_job_service_configuration]
            :language: python
            :dedent: 8
            :caption: Configuring JupyterLabJobService configuration on a command job.
    """

    # NOTE(review): the example above includes the SSH sample snippet; confirm whether a
    # JupyterLab-specific sample exists that should be referenced instead.

    def __init__(
        self,
        *,
        endpoint: Optional[str] = None,
        nodes: Optional[Literal["all"]] = None,
        status: Optional[str] = None,
        port: Optional[int] = None,
        properties: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> None:
        # Extra keyword arguments are forwarded untouched to the base initializer.
        super().__init__(endpoint=endpoint, nodes=nodes, status=status, port=port, properties=properties, **kwargs)
        # The service type is fixed for this class and overrides whatever the base initializer stored.
        self.type = JobServiceTypeNames.EntityNames.JUPYTER_LAB

    @classmethod
    def _from_rest_object(cls, obj: RestJobService) -> "JupyterLabJobService":
        """Create an instance from its REST representation.

        :param obj: The REST object to convert.
        :type obj: RestJobService
        :return: The converted JupyterLab job service.
        :rtype: JupyterLabJobService
        """
        service = cls._from_rest_job_service_object(obj)
        return cast(JupyterLabJobService, service)

    def _to_rest_object(self) -> RestJobService:
        """Convert this job service to its REST representation.

        :return: The REST object built from this instance's attributes.
        :rtype: RestJobService
        """
        rest_service = self._to_rest_job_service()
        return rest_service
[docs]
class VsCodeJobService(JobServiceBase):
    """Job service configuration for VS Code.

    :ivar type: Specifies the type of job service. Set automatically to "vs_code" for this class.
    :vartype type: str
    :keyword endpoint: The endpoint URL.
    :paramtype endpoint: Optional[str]
    :keyword port: The port for the endpoint.
    :paramtype port: Optional[int]
    :keyword nodes: Indicates whether the service has to run in all nodes.
    :paramtype nodes: Optional[Literal["all"]]
    :keyword properties: Additional properties to set on the endpoint.
    :paramtype properties: Optional[dict[str, str]]
    :keyword status: The status of the endpoint.
    :paramtype status: Optional[str]
    :keyword kwargs: A dictionary of additional configuration parameters.
    :paramtype kwargs: dict

    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_misc.py
            :start-after: [START ssh_job_service_configuration]
            :end-before: [END ssh_job_service_configuration]
            :language: python
            :dedent: 8
            :caption: Configuring a VsCodeJobService configuration on a command job.
    """

    # NOTE(review): the example above includes the SSH sample snippet; confirm whether a
    # VS Code-specific sample exists that should be referenced instead.

    def __init__(
        self,
        *,
        endpoint: Optional[str] = None,
        nodes: Optional[Literal["all"]] = None,
        status: Optional[str] = None,
        port: Optional[int] = None,
        properties: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> None:
        # Extra keyword arguments are forwarded untouched to the base initializer.
        super().__init__(endpoint=endpoint, nodes=nodes, status=status, port=port, properties=properties, **kwargs)
        # The service type is fixed for this class and overrides whatever the base initializer stored.
        self.type = JobServiceTypeNames.EntityNames.VS_CODE

    @classmethod
    def _from_rest_object(cls, obj: RestJobService) -> "VsCodeJobService":
        """Create an instance from its REST representation.

        :param obj: The REST object to convert.
        :type obj: RestJobService
        :return: The converted VS Code job service.
        :rtype: VsCodeJobService
        """
        service = cls._from_rest_job_service_object(obj)
        return cast(VsCodeJobService, service)

    def _to_rest_object(self) -> RestJobService:
        """Convert this job service to its REST representation.

        :return: The REST object built from this instance's attributes.
        :rtype: RestJobService
        """
        rest_service = self._to_rest_job_service()
        return rest_service
def _append_or_update_properties(
properties: Optional[Dict[str, str]], key: str, value: Optional[str]
) -> Dict[str, str]:
if value and not properties:
properties = {key: value}
if value and properties:
properties.update({key: value})
return properties if properties is not None else {}
def _get_property(properties: Dict[str, str], key: str) -> Optional[str]:
return properties.get(key, None) if properties else None