# Source code for azure.ai.ml.entities._job.parallel.parallel_job

# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from azure.ai.ml._restclient.v2022_02_01_preview.models import JobBaseData
from azure.ai.ml._schema.job.parallel_job import ParallelJobSchema
from azure.ai.ml._utils.utils import is_data_binding_expression
from azure.ai.ml.constants import JobType
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
from azure.ai.ml.entities._credentials import (
    AmlTokenConfiguration,
    ManagedIdentityConfiguration,
    UserIdentityConfiguration,
)
from azure.ai.ml.entities._inputs_outputs import Input, Output
from azure.ai.ml.entities._util import load_from_dict
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException

from ..job import Job
from ..job_io_mixin import JobIOMixin
from .parameterized_parallel import ParameterizedParallel

# avoid circular import error
if TYPE_CHECKING:
    from azure.ai.ml.entities._builders import Parallel
    from azure.ai.ml.entities._component.parallel_component import ParallelComponent

module_logger = logging.getLogger(__name__)


class ParallelJob(Job, ParameterizedParallel, JobIOMixin):
    """Parallel job.

    :param name: Name of the job.
    :type name: str
    :param version: Version of the job.
    :type version: str
    :param id: Global ID of the resource, an Azure Resource Manager (ARM) ID.
    :type id: str
    :param type: Type of the job; the only supported value is 'parallel'.
    :type type: str
    :param description: Description of the job.
    :type description: str
    :param tags: Internal use only.
    :type tags: dict
    :param properties: Internal use only.
    :type properties: dict
    :param display_name: Display name of the job.
    :type display_name: str
    :param retry_settings: Retry settings for failed parallel job runs.
    :type retry_settings: BatchRetrySettings
    :param logging_level: The logging level name.
    :type logging_level: str
    :param max_concurrency_per_instance: The maximum degree of parallelism per compute instance.
    :type max_concurrency_per_instance: int
    :param error_threshold: The number of item processing failures that should be ignored.
    :type error_threshold: int
    :param mini_batch_error_threshold: The number of mini-batch processing failures that should be ignored.
    :type mini_batch_error_threshold: int
    :keyword identity: The identity that the job will use while running on compute.
    :paramtype identity: Optional[Union[~azure.ai.ml.ManagedIdentityConfiguration, ~azure.ai.ml.AmlTokenConfiguration,
        ~azure.ai.ml.UserIdentityConfiguration]]
    :param task: The parallel task.
    :type task: ParallelTask
    :param mini_batch_size: The mini batch size.
    :type mini_batch_size: str
    :param partition_keys: The partition keys.
    :type partition_keys: list
    :param input_data: The input data.
    :type input_data: str
    :param inputs: Inputs of the job.
    :type inputs: dict
    :param outputs: Outputs of the job.
    :type outputs: dict
    """

    def __init__(
        self,
        *,
        inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None,
        outputs: Optional[Dict[str, Output]] = None,
        identity: Optional[
            Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration, Dict]
        ] = None,
        **kwargs: Any,
    ):
        kwargs[TYPE] = JobType.PARALLEL

        super().__init__(**kwargs)

        self.inputs = inputs  # type: ignore[assignment]
        self.outputs = outputs  # type: ignore[assignment]
        self.identity = identity

    def _to_dict(self) -> Dict:
        res: dict = ParallelJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
        return res

    def _to_rest_object(self) -> None:
        pass

    @classmethod
    def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "ParallelJob":
        loaded_data = load_from_dict(ParallelJobSchema, data, context, additional_message, **kwargs)
        return ParallelJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)

    @classmethod
    def _load_from_rest(cls, obj: JobBaseData) -> None:
        pass

    def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "ParallelComponent":
        """Translate a parallel job to a component job.

        :param context: Context of the parallel job YAML file.
        :type context: dict
        :return: Translated parallel component.
        :rtype: ParallelComponent
        """
        from azure.ai.ml.entities._component.parallel_component import ParallelComponent

        pipeline_job_dict = kwargs.get("pipeline_job_dict", {})
        context = context or {BASE_PATH_CONTEXT_KEY: Path("./")}

        # Create an anonymous parallel component with default version 1
        init_kwargs = {}
        for key in [
            "mini_batch_size",
            "partition_keys",
            "logging_level",
            "max_concurrency_per_instance",
            "error_threshold",
            "mini_batch_error_threshold",
            "retry_settings",
            "resources",
        ]:
            value = getattr(self, key)
            from azure.ai.ml.entities import BatchRetrySettings, JobResourceConfiguration

            values_to_check: List = []
            if key == "retry_settings" and isinstance(value, BatchRetrySettings):
                values_to_check = [value.max_retries, value.timeout]
            elif key == "resources" and isinstance(value, JobResourceConfiguration):
                values_to_check = [
                    value.locations,
                    value.instance_count,
                    value.instance_type,
                    value.shm_size,
                    value.max_instance_count,
                    value.docker_args,
                ]
            else:
                values_to_check = [value]

            # Note that component-level attributes cannot be data binding expressions,
            # so filter out data binding expression properties here;
            # they will still take effect at node level according to _to_node.
            if any(
                map(
                    lambda x: is_data_binding_expression(x, binding_prefix=["parent", "inputs"], is_singular=False)
                    or is_data_binding_expression(x, binding_prefix=["inputs"], is_singular=False),
                    values_to_check,
                )
            ):
                continue

            init_kwargs[key] = getattr(self, key)

        return ParallelComponent(
            base_path=context[BASE_PATH_CONTEXT_KEY],
            # For parallel_job.task, all attributes are strings for now, so a data binding expression
            # is naturally allowed at the SDK level, but it is unclear whether such a component is valid.
            # Leave that validation to the service side.
            task=self.task,
            inputs=self._to_inputs(inputs=self.inputs, pipeline_job_dict=pipeline_job_dict),
            outputs=self._to_outputs(outputs=self.outputs, pipeline_job_dict=pipeline_job_dict),
            input_data=self.input_data,
            # Keep these if no data binding expression was detected, to preserve the behavior of to_component.
            **init_kwargs,
        )

    def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "Parallel":
        """Translate a parallel job to a pipeline node.

        :param context: Context of the parallel job YAML file.
        :type context: dict
        :return: Translated parallel node.
        :rtype: Parallel
        """
        from azure.ai.ml.entities._builders import Parallel

        component = self._to_component(context, **kwargs)

        return Parallel(
            component=component,
            compute=self.compute,
            # Need to supply the inputs with double curly braces.
            inputs=self.inputs,  # type: ignore[arg-type]
            outputs=self.outputs,  # type: ignore[arg-type]
            mini_batch_size=self.mini_batch_size,
            partition_keys=self.partition_keys,
            input_data=self.input_data,
            # task will be inherited from component & base_path will be set correctly.
            retry_settings=self.retry_settings,
            logging_level=self.logging_level,
            max_concurrency_per_instance=self.max_concurrency_per_instance,
            error_threshold=self.error_threshold,
            mini_batch_error_threshold=self.mini_batch_error_threshold,
            environment_variables=self.environment_variables,
            properties=self.properties,
            identity=self.identity,
            resources=self.resources if self.resources and not isinstance(self.resources, dict) else None,
        )

    def _validate(self) -> None:
        if self.name is None:
            msg = "Job name is required"
            raise ValidationException(
                message=msg,
                no_personal_data_message=msg,
                target=ErrorTarget.JOB,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.MISSING_FIELD,
            )
        if self.compute is None:
            msg = "compute is required"
            raise ValidationException(
                message=msg,
                no_personal_data_message=msg,
                target=ErrorTarget.JOB,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.MISSING_FIELD,
            )
        if self.task is None:
            msg = "task is required"
            raise ValidationException(
                message=msg,
                no_personal_data_message=msg,
                target=ErrorTarget.JOB,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.MISSING_FIELD,
            )
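

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how the data-binding
# filter in _to_component above behaves. is_data_binding_expression is the
# helper already imported at the top of this file; the literal values below
# are hypothetical.
def _data_binding_filter_sketch() -> None:
    """Show which attribute values _to_component keeps versus skips."""
    # A data-binding expression such as "${{inputs.batch_size}}" is filtered
    # out of the component-level init_kwargs; it only takes effect at node
    # level via _to_node.
    assert is_data_binding_expression(
        "${{inputs.batch_size}}", binding_prefix=["inputs"], is_singular=False
    )
    # A plain literal carries no binding, so the attribute stays on the component.
    assert not is_data_binding_expression("4", binding_prefix=["inputs"], is_singular=False)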
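

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes a local
# "parallel_job.yml" whose "type" field is "parallel"; the file name is a
# hypothetical placeholder. load_job dispatches on that type field, reaching
# ParallelJob._load_from_dict, after which the private _to_component /
# _to_node helpers above translate the job for reuse inside a pipeline.
if __name__ == "__main__":
    from azure.ai.ml import load_job

    parallel_job = load_job(source="./parallel_job.yml")  # hypothetical path

    # Translate to an anonymous component and to a pipeline node; the task and
    # base_path are inherited from the component (see _to_node above).
    component = parallel_job._to_component()
    node = parallel_job._to_node()
    module_logger.info("component: %s, node: %s", component.name, type(node).__name__)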