Source code for azure.ai.ml.entities._job.distribution

# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=unused-argument

from typing import Any, Dict, Optional, Union

from azure.ai.ml._restclient.v2023_04_01_preview.models import (
    DistributionConfiguration as RestDistributionConfiguration,
)
from azure.ai.ml._restclient.v2023_04_01_preview.models import DistributionType as RestDistributionType
from azure.ai.ml._restclient.v2023_04_01_preview.models import Mpi as RestMpi
from azure.ai.ml._restclient.v2023_04_01_preview.models import PyTorch as RestPyTorch
from azure.ai.ml._restclient.v2023_04_01_preview.models import Ray as RestRay
from azure.ai.ml._restclient.v2023_04_01_preview.models import TensorFlow as RestTensorFlow
from azure.ai.ml._utils._experimental import experimental
from azure.ai.ml.constants import DistributionType
from azure.ai.ml.entities._mixins import RestTranslatableMixin

SDK_TO_REST = {
    DistributionType.MPI: RestDistributionType.MPI,
    DistributionType.TENSORFLOW: RestDistributionType.TENSOR_FLOW,
    DistributionType.PYTORCH: RestDistributionType.PY_TORCH,
    DistributionType.RAY: RestDistributionType.RAY,
}


class DistributionConfiguration(RestTranslatableMixin):
    """Distribution configuration for a component or job.

    This class is not meant to be instantiated directly. Instead, use one of its subclasses.
    """

    def __init__(self, **kwargs: Any) -> None:
        self.type: Any = None

    @classmethod
    def _from_rest_object(
        cls, obj: Optional[Union[RestDistributionConfiguration, Dict]]
    ) -> Optional["DistributionConfiguration"]:
        """Constructs a DistributionConfiguration object from a REST object

        This function works for distribution property of a Job object and of a Component object()

        Distribution of Job when returned by MFE, is a RestDistributionConfiguration

        Distribution of Component when returned by MFE, is a Dict.
        e.g. {'type': 'Mpi', 'process_count_per_instance': '1'}

        So in the job distribution case, we need to call as_dict() first and get type from "distribution_type" property.
        In the componenet case, we need to extract type from key "type"


        :param obj: The object to translate
        :type obj: Optional[Union[RestDistributionConfiguration, Dict]]
        :return: The distribution configuration
        :rtype: DistributionConfiguration
        """
        if obj is None:
            return None

        if isinstance(obj, dict):
            data = obj
        else:
            data = obj.as_dict()

        type_str = data.pop("distribution_type", None) or data.pop("type", None)
        klass = DISTRIBUTION_TYPE_MAP[type_str.lower()]
        res: DistributionConfiguration = klass(**data)
        return res

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, DistributionConfiguration):
            return NotImplemented
        res: bool = self._to_rest_object() == other._to_rest_object()
        return res


class MpiDistribution(DistributionConfiguration):
    """MPI distribution configuration.

    :keyword process_count_per_instance: The number of processes per node.
    :paramtype process_count_per_instance: Optional[int]
    :ivar type: Specifies the type of distribution. Set automatically to "mpi" for this class.
    :vartype type: str

    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_misc.py
            :start-after: [START mpi_distribution_configuration]
            :end-before: [END mpi_distribution_configuration]
            :language: python
            :dedent: 8
            :caption: Configuring a CommandComponent with an MpiDistribution.
    """

    def __init__(self, *, process_count_per_instance: Optional[int] = None, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.type = DistributionType.MPI
        self.process_count_per_instance = process_count_per_instance

    def _to_rest_object(self) -> RestMpi:
        return RestMpi(process_count_per_instance=self.process_count_per_instance)
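
# A minimal, hedged usage sketch: assuming the azure.ai.ml `command()` builder accepts
# `distribution` and `instance_count` keywords, an MPI command job could look roughly like
# the following (the compute target name and environment reference are illustrative only):
#
#     from azure.ai.ml import command
#
#     job = command(
#         command="python train.py",
#         environment="AzureML-sklearn-1.0-ubuntu20.04-py38-cpu@latest",
#         compute="cpu-cluster",
#         instance_count=2,
#         distribution=MpiDistribution(process_count_per_instance=2),
#     )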


class PyTorchDistribution(DistributionConfiguration):
    """PyTorch distribution configuration.

    :keyword process_count_per_instance: The number of processes per node.
    :paramtype process_count_per_instance: Optional[int]
    :ivar type: Specifies the type of distribution. Set automatically to "pytorch" for this class.
    :vartype type: str

    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_misc.py
            :start-after: [START pytorch_distribution_configuration]
            :end-before: [END pytorch_distribution_configuration]
            :language: python
            :dedent: 8
            :caption: Configuring a CommandComponent with a PyTorchDistribution.
    """

    def __init__(self, *, process_count_per_instance: Optional[int] = None, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.type = DistributionType.PYTORCH
        self.process_count_per_instance = process_count_per_instance

    def _to_rest_object(self) -> RestPyTorch:
        return RestPyTorch(process_count_per_instance=self.process_count_per_instance)


class TensorFlowDistribution(DistributionConfiguration):
    """TensorFlow distribution configuration.

    :vartype distribution_type: str or ~azure.mgmt.machinelearningservices.models.DistributionType
    :keyword parameter_server_count: The number of parameter server tasks. Defaults to 0.
    :paramtype parameter_server_count: Optional[int]
    :keyword worker_count: The number of workers. Defaults to the instance count.
    :paramtype worker_count: Optional[int]
    :ivar parameter_server_count: Number of parameter server tasks.
    :vartype parameter_server_count: int
    :ivar worker_count: Number of workers. If not specified, will default to the instance count.
    :vartype worker_count: int
    :ivar type: Specifies the type of distribution. Set automatically to "tensorflow" for this class.
    :vartype type: str

    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_misc.py
            :start-after: [START tensorflow_distribution_configuration]
            :end-before: [END tensorflow_distribution_configuration]
            :language: python
            :dedent: 8
            :caption: Configuring a CommandComponent with a TensorFlowDistribution.
    """

    def __init__(
        self, *, parameter_server_count: Optional[int] = 0, worker_count: Optional[int] = None, **kwargs: Any
    ) -> None:
        super().__init__(**kwargs)
        self.type = DistributionType.TENSORFLOW
        self.parameter_server_count = parameter_server_count
        self.worker_count = worker_count

    def _to_rest_object(self) -> RestTensorFlow:
        return RestTensorFlow(parameter_server_count=self.parameter_server_count, worker_count=self.worker_count)
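
# A hedged sketch of constructing a TensorFlowDistribution directly. Leaving `worker_count`
# unset lets it default to the job's instance count, as noted in the docstring above:
#
#     tf_dist = TensorFlowDistribution(worker_count=3, parameter_server_count=1)
#     rest_tf = tf_dist._to_rest_object()  # RestTensorFlow(parameter_server_count=1, worker_count=3)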


@experimental
class RayDistribution(DistributionConfiguration):
    """Ray distribution configuration.

    :vartype distribution_type: str or ~azure.mgmt.machinelearningservices.models.DistributionType
    :ivar port: The port of the head Ray process.
    :vartype port: int
    :ivar address: The address of the Ray head node.
    :vartype address: str
    :ivar include_dashboard: Provide this argument to start the Ray dashboard GUI.
    :vartype include_dashboard: bool
    :ivar dashboard_port: The port to bind the dashboard server to.
    :vartype dashboard_port: int
    :ivar head_node_additional_args: Additional arguments passed to ray start on the head node.
    :vartype head_node_additional_args: str
    :ivar worker_node_additional_args: Additional arguments passed to ray start on worker nodes.
    :vartype worker_node_additional_args: str
    :ivar type: Specifies the type of distribution. Set automatically to "Ray" for this class.
    :vartype type: str
    """

    def __init__(
        self,
        *,
        port: Optional[int] = None,
        address: Optional[str] = None,
        include_dashboard: Optional[bool] = None,
        dashboard_port: Optional[int] = None,
        head_node_additional_args: Optional[str] = None,
        worker_node_additional_args: Optional[str] = None,
        **kwargs: Any
    ):
        super().__init__(**kwargs)
        self.type = DistributionType.RAY
        self.port = port
        self.address = address
        self.include_dashboard = include_dashboard
        self.dashboard_port = dashboard_port
        self.head_node_additional_args = head_node_additional_args
        self.worker_node_additional_args = worker_node_additional_args

    def _to_rest_object(self) -> RestRay:
        return RestRay(
            port=self.port,
            address=self.address,
            include_dashboard=self.include_dashboard,
            dashboard_port=self.dashboard_port,
            head_node_additional_args=self.head_node_additional_args,
            worker_node_additional_args=self.worker_node_additional_args,
        )
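
# A hedged sketch of an (experimental) RayDistribution. The port values below are
# illustrative, borrowed from Ray's common defaults, not values required by this class:
#
#     ray_dist = RayDistribution(
#         port=6379,
#         include_dashboard=True,
#         dashboard_port=8265,
#         head_node_additional_args="--num-cpus=4",
#     )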


DISTRIBUTION_TYPE_MAP = {
    DistributionType.MPI: MpiDistribution,
    DistributionType.TENSORFLOW: TensorFlowDistribution,
    DistributionType.PYTORCH: PyTorchDistribution,
    DistributionType.RAY: RayDistribution,
}
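
# A hedged round-trip sketch of the dispatch above. `_from_rest_object` lower-cases the
# incoming type string, so a Component-style dict and a Job-style REST object resolve to
# the corresponding subclass (for the dict case, "1" is passed through as a string):
#
#     DistributionConfiguration._from_rest_object(
#         {"type": "Mpi", "process_count_per_instance": "1"}
#     )  # -> MpiDistribution
#
#     DistributionConfiguration._from_rest_object(
#         RestPyTorch(process_count_per_instance=2)
#     )  # as_dict() yields {"distribution_type": "PyTorch", ...} -> PyTorchDistribution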