# Source code for azure.ai.ml.entities._job.job_service

# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=protected-access

import logging
from typing import Any, Dict, Optional, cast

from typing_extensions import Literal

from azure.ai.ml._restclient.v2023_04_01_preview.models import AllNodes
from azure.ai.ml._restclient.v2023_04_01_preview.models import JobService as RestJobService
from azure.ai.ml.constants._job.job import JobServiceTypeNames
from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException

module_logger = logging.getLogger(__name__)


class JobServiceBase(RestTranslatableMixin, DictMixin):
    """Base class for job service configuration.

    This class should not be instantiated directly. Instead, use one of its subclasses.

    :keyword endpoint: The endpoint URL.
    :paramtype endpoint: Optional[str]
    :keyword type: The endpoint type. Accepted values are "jupyter_lab", "ssh", "tensor_board", and "vs_code".
    :paramtype type: Optional[Literal["jupyter_lab", "ssh", "tensor_board", "vs_code"]]
    :keyword port: The port for the endpoint.
    :paramtype port: Optional[int]
    :keyword nodes: Indicates whether the service has to run in all nodes.
    :paramtype nodes: Optional[Literal["all"]]
    :keyword properties: Additional properties to set on the endpoint.
    :paramtype properties: Optional[dict[str, str]]
    :keyword status: The status of the endpoint.
    :paramtype status: Optional[str]
    :keyword kwargs: A dictionary of additional configuration parameters. They are accepted but
        not used by this constructor.
    :paramtype kwargs: dict
    :raises ~azure.ai.ml.exceptions.ValidationException: If ``nodes`` is neither "all" nor None, or
        if ``type`` is set to an unrecognized job service type.
    """

    def __init__(  # pylint: disable=unused-argument
        self,
        *,
        endpoint: Optional[str] = None,
        type: Optional[  # pylint: disable=redefined-builtin
            Literal["jupyter_lab", "ssh", "tensor_board", "vs_code"]
        ] = None,
        nodes: Optional[Literal["all"]] = None,
        status: Optional[str] = None,
        port: Optional[int] = None,
        properties: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> None:
        self.endpoint = endpoint
        # Subclasses overwrite ``type`` with their own constant after this constructor returns.
        self.type: Any = type
        self.nodes = nodes
        self.status = status
        self.port = port
        self.properties = properties
        self._validate_nodes()
        self._validate_type_name()

    def _validate_nodes(self) -> None:
        """Ensure ``nodes`` is either "all" or None.

        :raises ~azure.ai.ml.exceptions.ValidationException: If ``nodes`` has any other value.
        """
        if self.nodes not in ["all", None]:
            msg = f"nodes should be either 'all' or None, but received '{self.nodes}'."
            raise ValidationException(
                message=msg,
                no_personal_data_message=msg,
                target=ErrorTarget.JOB,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.INVALID_VALUE,
            )

    def _validate_type_name(self) -> None:
        """Ensure a set ``type`` is a key of ``JobServiceTypeNames.ENTITY_TO_REST``.

        A falsy ``type`` (e.g. None) is accepted without checking.

        :raises ~azure.ai.ml.exceptions.ValidationException: If ``type`` is set but not recognized.
        """
        if self.type and self.type not in JobServiceTypeNames.ENTITY_TO_REST:
            msg = (
                f"type should be one of " f"{JobServiceTypeNames.NAMES_ALLOWED_FOR_PUBLIC}, but received '{self.type}'."
            )
            raise ValidationException(
                message=msg,
                no_personal_data_message=msg,
                target=ErrorTarget.JOB,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.INVALID_VALUE,
            )

    def _to_rest_job_service(self, updated_properties: Optional[Dict[str, str]] = None) -> RestJobService:
        """Build the REST ``JobService`` model from this entity.

        :param updated_properties: Properties to send instead of ``self.properties``. An empty or
            None value falls back to ``self.properties``.
        :type updated_properties: Optional[Dict[str, str]]
        :return: The REST representation of this job service.
        :rtype: RestJobService
        """
        return RestJobService(
            endpoint=self.endpoint,
            # Unknown entity type names map to None rather than raising.
            job_service_type=JobServiceTypeNames.ENTITY_TO_REST.get(self.type, None) if self.type else None,
            nodes=AllNodes() if self.nodes else None,
            status=self.status,
            port=self.port,
            properties=updated_properties if updated_properties else self.properties,
        )

    @classmethod
    def _to_rest_job_services(
        cls,
        services: Optional[Dict],
    ) -> Optional[Dict[str, RestJobService]]:
        """Convert a name -> job service entity mapping to a name -> REST model mapping.

        :param services: Mapping of service name to job service entity, or None.
        :type services: Optional[Dict]
        :return: Mapping of service name to REST job service, or None when ``services`` is None.
        :rtype: Optional[Dict[str, RestJobService]]
        """
        if services is None:
            return None

        return {name: service._to_rest_object() for name, service in services.items()}

    @classmethod
    def _from_rest_job_service_object(cls, obj: RestJobService) -> "JobServiceBase":
        """Create an instance of ``cls`` from the fields shared by every job service type.

        Type-specific fields (such as SSH public keys or the TensorBoard log directory) are not
        read here; subclasses populate them in their own ``_from_rest_object``.

        :param obj: The REST job service to convert.
        :type obj: RestJobService
        :return: A new instance of ``cls``.
        :rtype: JobServiceBase
        """
        return cls(
            endpoint=obj.endpoint,
            type=(
                JobServiceTypeNames.REST_TO_ENTITY.get(obj.job_service_type, None)  # type: ignore[arg-type]
                if obj.job_service_type
                else None
            ),
            nodes="all" if obj.nodes else None,
            status=obj.status,
            port=obj.port,
            properties=obj.properties,
        )

    @classmethod
    def _from_rest_job_services(cls, services: Optional[Dict[str, RestJobService]]) -> Optional[Dict]:
        """Resolve a name -> REST job service mapping into type-specific job service entities.

        Services with an unrecognized ``job_service_type`` become plain ``JobService`` objects.

        :param services: Mapping of service name to REST job service, or None.
        :type services: Optional[Dict[str, RestJobService]]
        :return: Mapping of service name to job service entity, or None when ``services`` is None.
        :rtype: Optional[Dict]
        """
        if services is None:
            return None

        result: dict = {}
        for name, service in services.items():
            if service.job_service_type == JobServiceTypeNames.RestNames.JUPYTER_LAB:
                result[name] = JupyterLabJobService._from_rest_object(service)
            elif service.job_service_type == JobServiceTypeNames.RestNames.SSH:
                result[name] = SshJobService._from_rest_object(service)
            elif service.job_service_type == JobServiceTypeNames.RestNames.TENSOR_BOARD:
                result[name] = TensorBoardJobService._from_rest_object(service)
            elif service.job_service_type == JobServiceTypeNames.RestNames.VS_CODE:
                result[name] = VsCodeJobService._from_rest_object(service)
            else:
                result[name] = JobService._from_rest_object(service)
        return result


class JobService(JobServiceBase):
    """Generic job service configuration, kept for backward compatibility.

    Prefer the type-specific subclasses over using this class directly.

    :keyword endpoint: The endpoint URL.
    :paramtype endpoint: Optional[str]
    :keyword type: The endpoint type. Accepted values are "jupyter_lab", "ssh", "tensor_board", and "vs_code".
    :paramtype type: Optional[Literal["jupyter_lab", "ssh", "tensor_board", "vs_code"]]
    :keyword port: The port for the endpoint.
    :paramtype port: Optional[int]
    :keyword nodes: Indicates whether the service has to run in all nodes.
    :paramtype nodes: Optional[Literal["all"]]
    :keyword properties: Additional properties to set on the endpoint.
    :paramtype properties: Optional[dict[str, str]]
    :keyword status: The status of the endpoint.
    :paramtype status: Optional[str]
    :keyword kwargs: A dictionary of additional configuration parameters.
    :paramtype kwargs: dict
    """

    def _to_rest_object(self) -> RestJobService:
        # No type-specific properties to merge; send the entity's own properties unchanged.
        return self._to_rest_job_service()

    @classmethod
    def _from_rest_object(cls, obj: RestJobService) -> "JobService":
        entity = cls._from_rest_job_service_object(obj)
        return cast(JobService, entity)
class SshJobService(JobServiceBase):
    """Job service configuration that exposes SSH access to a job.

    :ivar type: Specifies the type of job service. Set automatically to "ssh" for this class.
    :vartype type: str
    :keyword endpoint: The endpoint URL.
    :paramtype endpoint: Optional[str]
    :keyword port: The port for the endpoint.
    :paramtype port: Optional[int]
    :keyword nodes: Indicates whether the service has to run in all nodes.
    :paramtype nodes: Optional[Literal["all"]]
    :keyword properties: Additional properties to set on the endpoint.
    :paramtype properties: Optional[dict[str, str]]
    :keyword status: The status of the endpoint.
    :paramtype status: Optional[str]
    :keyword ssh_public_keys: The SSH public key used to access the job container. It is sent to the
        service as the "sshPublicKeys" property.
    :paramtype ssh_public_keys: Optional[str]
    :keyword kwargs: A dictionary of additional configuration parameters.
    :paramtype kwargs: dict

    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_misc.py
            :start-after: [START ssh_job_service_configuration]
            :end-before: [END ssh_job_service_configuration]
            :language: python
            :dedent: 8
            :caption: Configuring a SshJobService configuration on a command job.
    """

    def __init__(
        self,
        *,
        endpoint: Optional[str] = None,
        nodes: Optional[Literal["all"]] = None,
        status: Optional[str] = None,
        port: Optional[int] = None,
        ssh_public_keys: Optional[str] = None,
        properties: Optional[Dict[str, str]] = None,
        **kwargs: Any,  # pylint: disable=unused-argument
    ) -> None:
        super().__init__(
            properties=properties,
            port=port,
            status=status,
            nodes=nodes,
            endpoint=endpoint,
            **kwargs,
        )
        # The service type is fixed for this class, overriding anything the base received.
        self.type = JobServiceTypeNames.EntityNames.SSH
        self.ssh_public_keys = ssh_public_keys

    @classmethod
    def _from_rest_object(cls, obj: RestJobService) -> "SshJobService":
        entity = cast(SshJobService, cls._from_rest_job_service_object(obj))
        # The public key travels inside the generic properties bag on the REST side.
        entity.ssh_public_keys = _get_property(obj.properties, "sshPublicKeys")
        return entity

    def _to_rest_object(self) -> RestJobService:
        return self._to_rest_job_service(
            _append_or_update_properties(self.properties, "sshPublicKeys", self.ssh_public_keys)
        )
class TensorBoardJobService(JobServiceBase):
    """Job service configuration that runs TensorBoard for a job.

    :ivar type: Specifies the type of job service. Set automatically to "tensor_board" for this class.
    :vartype type: str
    :keyword endpoint: The endpoint URL.
    :paramtype endpoint: Optional[str]
    :keyword port: The port for the endpoint.
    :paramtype port: Optional[int]
    :keyword nodes: Indicates whether the service has to run in all nodes.
    :paramtype nodes: Optional[Literal["all"]]
    :keyword properties: Additional properties to set on the endpoint.
    :paramtype properties: Optional[dict[str, str]]
    :keyword status: The status of the endpoint.
    :paramtype status: Optional[str]
    :keyword log_dir: The directory path for the log file. It is sent to the service as the
        "logDir" property.
    :paramtype log_dir: Optional[str]
    :keyword kwargs: A dictionary of additional configuration parameters.
    :paramtype kwargs: dict

    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_misc.py
            :start-after: [START ssh_job_service_configuration]
            :end-before: [END ssh_job_service_configuration]
            :language: python
            :dedent: 8
            :caption: Configuring TensorBoardJobService configuration on a command job.
    """

    def __init__(
        self,
        *,
        endpoint: Optional[str] = None,
        nodes: Optional[Literal["all"]] = None,
        status: Optional[str] = None,
        port: Optional[int] = None,
        log_dir: Optional[str] = None,
        properties: Optional[Dict[str, str]] = None,
        **kwargs: Any,  # pylint: disable=unused-argument
    ) -> None:
        super().__init__(
            properties=properties,
            port=port,
            status=status,
            nodes=nodes,
            endpoint=endpoint,
            **kwargs,
        )
        # The service type is fixed for this class, overriding anything the base received.
        self.type = JobServiceTypeNames.EntityNames.TENSOR_BOARD
        self.log_dir = log_dir

    @classmethod
    def _from_rest_object(cls, obj: RestJobService) -> "TensorBoardJobService":
        board_service = cast(TensorBoardJobService, cls._from_rest_job_service_object(obj))
        # The log directory travels inside the generic properties bag on the REST side.
        board_service.log_dir = _get_property(obj.properties, "logDir")
        return board_service

    def _to_rest_object(self) -> RestJobService:
        return self._to_rest_job_service(
            _append_or_update_properties(self.properties, "logDir", self.log_dir)
        )
class JupyterLabJobService(JobServiceBase):
    """Job service configuration that exposes JupyterLab for a job.

    :ivar type: Specifies the type of job service. Set automatically to "jupyter_lab" for this class.
    :vartype type: str
    :keyword endpoint: The endpoint URL.
    :paramtype endpoint: Optional[str]
    :keyword port: The port for the endpoint.
    :paramtype port: Optional[int]
    :keyword nodes: Indicates whether the service has to run in all nodes.
    :paramtype nodes: Optional[Literal["all"]]
    :keyword properties: Additional properties to set on the endpoint.
    :paramtype properties: Optional[dict[str, str]]
    :keyword status: The status of the endpoint.
    :paramtype status: Optional[str]
    :keyword kwargs: A dictionary of additional configuration parameters.
    :paramtype kwargs: dict

    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_misc.py
            :start-after: [START ssh_job_service_configuration]
            :end-before: [END ssh_job_service_configuration]
            :language: python
            :dedent: 8
            :caption: Configuring JupyterLabJobService configuration on a command job.
    """

    def __init__(
        self,
        *,
        endpoint: Optional[str] = None,
        nodes: Optional[Literal["all"]] = None,
        status: Optional[str] = None,
        port: Optional[int] = None,
        properties: Optional[Dict[str, str]] = None,
        **kwargs: Any,  # pylint: disable=unused-argument
    ) -> None:
        super().__init__(
            properties=properties,
            port=port,
            status=status,
            nodes=nodes,
            endpoint=endpoint,
            **kwargs,
        )
        # The service type is fixed for this class, overriding anything the base received.
        self.type = JobServiceTypeNames.EntityNames.JUPYTER_LAB

    @classmethod
    def _from_rest_object(cls, obj: RestJobService) -> "JupyterLabJobService":
        entity = cls._from_rest_job_service_object(obj)
        return cast(JupyterLabJobService, entity)

    def _to_rest_object(self) -> RestJobService:
        # No type-specific properties to merge; send the entity's own properties unchanged.
        return self._to_rest_job_service()
class VsCodeJobService(JobServiceBase):
    """Job service configuration that exposes VS Code access to a job.

    :ivar type: Specifies the type of job service. Set automatically to "vs_code" for this class.
    :vartype type: str
    :keyword endpoint: The endpoint URL.
    :paramtype endpoint: Optional[str]
    :keyword port: The port for the endpoint.
    :paramtype port: Optional[int]
    :keyword nodes: Indicates whether the service has to run in all nodes.
    :paramtype nodes: Optional[Literal["all"]]
    :keyword properties: Additional properties to set on the endpoint.
    :paramtype properties: Optional[dict[str, str]]
    :keyword status: The status of the endpoint.
    :paramtype status: Optional[str]
    :keyword kwargs: A dictionary of additional configuration parameters.
    :paramtype kwargs: dict

    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_misc.py
            :start-after: [START ssh_job_service_configuration]
            :end-before: [END ssh_job_service_configuration]
            :language: python
            :dedent: 8
            :caption: Configuring a VsCodeJobService configuration on a command job.
    """

    def __init__(
        self,
        *,
        endpoint: Optional[str] = None,
        nodes: Optional[Literal["all"]] = None,
        status: Optional[str] = None,
        port: Optional[int] = None,
        properties: Optional[Dict[str, str]] = None,
        **kwargs: Any,  # pylint: disable=unused-argument
    ) -> None:
        super().__init__(
            properties=properties,
            port=port,
            status=status,
            nodes=nodes,
            endpoint=endpoint,
            **kwargs,
        )
        # The service type is fixed for this class, overriding anything the base received.
        self.type = JobServiceTypeNames.EntityNames.VS_CODE

    @classmethod
    def _from_rest_object(cls, obj: RestJobService) -> "VsCodeJobService":
        entity = cls._from_rest_job_service_object(obj)
        return cast(VsCodeJobService, entity)

    def _to_rest_object(self) -> RestJobService:
        # No type-specific properties to merge; send the entity's own properties unchanged.
        return self._to_rest_job_service()
def _append_or_update_properties( properties: Optional[Dict[str, str]], key: str, value: Optional[str] ) -> Dict[str, str]: if value and not properties: properties = {key: value} if value and properties: properties.update({key: value}) return properties if properties is not None else {} def _get_property(properties: Dict[str, str], key: str) -> Optional[str]: return properties.get(key, None) if properties else None