# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import os
from typing import Any, Dict, List, Optional, Union, cast
from marshmallow import Schema
from azure.ai.ml._schema.component.command_component import CommandComponentSchema
from azure.ai.ml.constants._common import COMPONENT_TYPE
from azure.ai.ml.constants._component import NodeType
from azure.ai.ml.entities._assets import Environment
from azure.ai.ml.entities._job.distribution import (
DistributionConfiguration,
MpiDistribution,
PyTorchDistribution,
RayDistribution,
TensorFlowDistribution,
)
from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration
from azure.ai.ml.entities._job.parameterized_command import ParameterizedCommand
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
from ..._restclient.v2022_10_01.models import ComponentVersion
from ..._schema import PathAwareSchema
from ..._utils.utils import get_all_data_binding_expressions, parse_args_description_from_docstring
from .._util import convert_ordered_dict_to_dict, validate_attribute_type
from .._validation import MutableValidationResult
from ._additional_includes import AdditionalIncludesMixin
from .component import Component
# pylint: disable=protected-access
class CommandComponent(Component, ParameterizedCommand, AdditionalIncludesMixin):
    """Command component version, used to define a Command Component or Job.

    :keyword name: The name of the Command job or component.
    :paramtype name: Optional[str]
    :keyword version: The version of the Command job or component.
    :paramtype version: Optional[str]
    :keyword description: The description of the component. Defaults to None.
    :paramtype description: Optional[str]
    :keyword tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None.
    :paramtype tags: Optional[dict]
    :keyword display_name: The display name of the component.
    :paramtype display_name: Optional[str]
    :keyword command: The command to be executed.
    :paramtype command: Optional[str]
    :keyword code: The source code to run the job. Can be a local path or "http:", "https:", or "azureml:" url pointing
        to a remote location.
    :paramtype code: Optional[Union[str, os.PathLike]]
    :keyword environment: The environment that the job will run in.
    :paramtype environment: Optional[Union[str, ~azure.ai.ml.entities.Environment]]
    :keyword distribution: The configuration for distributed jobs. Defaults to None.
    :paramtype distribution: Optional[Union[~azure.ai.ml.PyTorchDistribution, ~azure.ai.ml.MpiDistribution,
        ~azure.ai.ml.TensorFlowDistribution, ~azure.ai.ml.RayDistribution]]
    :keyword resources: The compute resource configuration for the command.
    :paramtype resources: Optional[~azure.ai.ml.entities.JobResourceConfiguration]
    :keyword inputs: A mapping of input names to input data sources used in the job. Defaults to None.
    :paramtype inputs: Optional[dict[str, Union[
        ~azure.ai.ml.Input,
        str,
        bool,
        int,
        float,
        Enum,
        ]]]
    :keyword outputs: A mapping of output names to output data sources used in the job. Defaults to None.
    :paramtype outputs: Optional[dict[str, Union[str, ~azure.ai.ml.Output]]]
    :keyword instance_count: The number of instances or nodes to be used by the compute target. Mutually exclusive
        with ``resources``. Defaults to None, in which case no instance count is set by this constructor.
    :paramtype instance_count: Optional[int]
    :keyword is_deterministic: Specifies whether the Command will return the same output given the same input.
        Defaults to True. When True, if a Command (component) is deterministic and has been run before in the
        current workspace with the same input and settings, it will reuse results from a previous submitted job
        when used as a node or step in a pipeline. In that scenario, no compute resources will be used.
    :paramtype is_deterministic: Optional[bool]
    :keyword additional_includes: A list of shared additional files to be included in the component. Defaults to None.
    :paramtype additional_includes: Optional[List[str]]
    :keyword properties: The job property dictionary. Defaults to None.
    :paramtype properties: Optional[dict[str, str]]
    :raises ~azure.ai.ml.exceptions.ValidationException: Raised if CommandComponent cannot be successfully validated.
        Details will be provided in the error message.

    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_command_configurations.py
            :start-after: [START command_component_definition]
            :end-before: [END command_component_definition]
            :language: python
            :dedent: 8
            :caption: Creating a CommandComponent.
    """

    def __init__(
        self,
        *,
        name: Optional[str] = None,
        version: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[Dict] = None,
        display_name: Optional[str] = None,
        command: Optional[str] = None,
        code: Optional[Union[str, os.PathLike]] = None,
        environment: Optional[Union[str, Environment]] = None,
        distribution: Optional[
            Union[
                Dict,
                MpiDistribution,
                TensorFlowDistribution,
                PyTorchDistribution,
                RayDistribution,
                DistributionConfiguration,
            ]
        ] = None,
        resources: Optional[JobResourceConfiguration] = None,
        inputs: Optional[Dict] = None,
        outputs: Optional[Dict] = None,
        instance_count: Optional[int] = None,  # promoted property from resources.instance_count
        is_deterministic: bool = True,
        additional_includes: Optional[List] = None,
        properties: Optional[Dict] = None,
        **kwargs: Any,
    ) -> None:
        # validate init params are valid type
        # NOTE: this must stay the first statement so that locals() holds only the init parameters.
        validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map())

        kwargs[COMPONENT_TYPE] = NodeType.COMMAND

        # Component backend doesn't support environment_variables yet,
        # this is to support the case of CommandComponent being the trial of
        # a SweepJob, where environment_variables is stored as part of trial
        environment_variables = kwargs.pop("environment_variables", None)

        super().__init__(
            name=name,
            version=version,
            description=description,
            tags=tags,
            display_name=display_name,
            inputs=inputs,
            outputs=outputs,
            is_deterministic=is_deterministic,
            properties=properties,
            **kwargs,
        )

        # No validation on value passed here because in pipeline job, required code&environment maybe absent
        # and fill in later with job defaults.
        self.command = command
        self.code = code
        self.environment_variables = environment_variables
        self.environment = environment
        self.resources = resources  # type: ignore[assignment]
        self.distribution = distribution

        # check mutual exclusivity of promoted properties
        if self.resources is not None and instance_count is not None:
            msg = "instance_count and resources are mutually exclusive"
            raise ValidationException(
                message=msg,
                target=ErrorTarget.COMPONENT,
                no_personal_data_message=msg,
                error_category=ErrorCategory.USER_ERROR,
            )
        self.instance_count = instance_count
        self.additional_includes = additional_includes or []

    def _to_ordered_dict_for_yaml_dump(self) -> Dict:
        """Dump the component content into a sorted yaml string.

        :return: The ordered dict
        :rtype: Dict
        """
        obj: dict = super()._to_ordered_dict_for_yaml_dump()
        # dict dumped base on schema will transfer code to an absolute path, while we want to keep its original value
        if self.code and isinstance(self.code, str):
            obj["code"] = self.code
        return obj

    @property
    def instance_count(self) -> Optional[int]:
        """The number of instances or nodes to be used by the compute target.

        :return: The number of instances or nodes, or None when ``resources`` is unset or is a plain dict.
        :rtype: Optional[int]
        """
        if self.resources and not isinstance(self.resources, dict):
            return self.resources.instance_count
        return None

    @instance_count.setter
    def instance_count(self, value: Optional[int]) -> None:
        """Sets the number of instances or nodes to be used by the compute target.

        A falsy value (``None`` or ``0``) leaves the current resources untouched.

        :param value: The number of instances of nodes to be used by the compute target.
        :type value: Optional[int]
        """
        if not value:
            return
        if not self.resources:
            self.resources = JobResourceConfiguration(instance_count=value)
        elif not isinstance(self.resources, dict):
            # dict-form resources are deliberately left as-is; only the object form is updated.
            self.resources.instance_count = value

    @classmethod
    def _attr_type_map(cls) -> dict:
        """Return the accepted types for the init attributes that are type-checked.

        :return: A mapping of attribute name to the type (or tuple of types) it may hold.
        :rtype: dict
        """
        return {
            "environment": (str, Environment),
            "environment_variables": dict,
            "resources": (dict, JobResourceConfiguration),
            "code": (str, os.PathLike),
        }

    def _to_dict(self) -> Dict:
        """Convert the component to a plain dict.

        Keys produced by the parent ``_to_dict`` take precedence over those in ``_other_parameter``.

        :return: The component as a dict.
        :rtype: Dict
        """
        merged = {**self._other_parameter, **super(CommandComponent, self)._to_dict()}
        return cast(dict, convert_ordered_dict_to_dict(merged))

    @classmethod
    def _from_rest_object_to_init_params(cls, obj: ComponentVersion) -> Dict:
        """Build the ``__init__`` keyword arguments from a REST component version.

        Note that ``distribution`` is popped from ``obj.properties.component_spec`` before delegating to the parent.

        :param obj: The REST object to convert.
        :type obj: ComponentVersion
        :return: The keyword arguments used to construct the component.
        :rtype: Dict
        """
        # put it here as distribution is shared by some components, e.g. command
        distribution = obj.properties.component_spec.pop("distribution", None)
        init_kwargs: dict = super()._from_rest_object_to_init_params(obj)
        if distribution:
            init_kwargs["distribution"] = DistributionConfiguration._from_rest_object(distribution)
        return init_kwargs

    def _get_environment_id(self) -> Optional[str]:
        """Return the environment id of the component's environment.

        :return: ``environment.id`` when the environment is defined inline as an ``Environment`` object,
            otherwise the ``environment`` attribute itself.
        :rtype: Optional[str]
        """
        # handle case when environment is defined inline
        if isinstance(self.environment, Environment):
            _id: Optional[str] = self.environment.id
            return _id
        return self.environment

    # region SchemaValidatableMixin
    @classmethod
    def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
        """Create the schema used to validate this component.

        :param context: The context passed through to the schema.
        :type context: Any
        :return: A ``CommandComponentSchema`` bound to ``context``.
        :rtype: Union[PathAwareSchema, Schema]
        """
        return CommandComponentSchema(context=context)

    def _customized_validate(self) -> MutableValidationResult:
        """Run the parent validation and then the command-component specific checks.

        :return: The combined validation result.
        :rtype: MutableValidationResult
        """
        validation_result = super(CommandComponent, self)._customized_validate()
        self._append_diagnostics_and_check_if_origin_code_reliable_for_local_path_validation(validation_result)
        validation_result.merge_with(self._validate_command())
        validation_result.merge_with(self._validate_early_available_output())
        return validation_result

    def _validate_command(self) -> MutableValidationResult:
        """Check that every data binding expression found in ``command`` resolves on this component.

        :return: A validation result with one error at ``command`` listing all invalid expressions, if any.
        :rtype: MutableValidationResult
        """
        validation_result = self._create_empty_validation_result()
        # command
        if self.command:
            invalid_expressions = [
                expression
                for expression in get_all_data_binding_expressions(self.command, is_singular=False)
                if not self._is_valid_data_binding_expression(expression)
            ]
            if invalid_expressions:
                validation_result.append_error(
                    yaml_path="command",
                    message="Invalid data binding expression: {}".format(", ".join(invalid_expressions)),
                )
        return validation_result

    def _validate_early_available_output(self) -> MutableValidationResult:
        """Check that every output flagged as early available is of primitive type.

        :return: A validation result with one error per offending output, at ``outputs.<name>``.
        :rtype: MutableValidationResult
        """
        validation_result = self._create_empty_validation_result()
        for name, output in self.outputs.items():
            if output.early_available is True and output._is_primitive_type is not True:
                msg = (
                    f"Early available output {name!r} requires output is primitive type, "
                    f"got {output._is_primitive_type!r}."
                )
                validation_result.append_error(message=msg, yaml_path=f"outputs.{name}")
        return validation_result

    def _is_valid_data_binding_expression(self, data_binding_expression: str) -> bool:
        """Check whether a dotted data binding expression can be resolved starting from this component.

        Each dot-separated segment is looked up first as an attribute and then as an item of the
        object reached so far.

        :param data_binding_expression: The dotted expression to resolve, e.g. ``inputs.foo``.
        :type data_binding_expression: str
        :return: True if every segment resolves, False otherwise.
        :rtype: bool
        """
        current_obj: Any = self
        for item in data_binding_expression.split("."):
            if hasattr(current_obj, item):
                current_obj = getattr(current_obj, item)
            else:
                try:
                    current_obj = current_obj[item]
                # Kept broad on purpose: the failure type depends on whatever object is being indexed.
                except Exception:  # pylint: disable=W0718
                    return False
        return True

    # endregion

    @classmethod
    def _parse_args_description_from_docstring(cls, docstring: str) -> Dict:
        """Parse argument descriptions out of a docstring.

        :param docstring: The docstring to parse.
        :type docstring: str
        :return: The result of ``parse_args_description_from_docstring`` for ``docstring``.
        :rtype: Dict
        """
        res: dict = parse_args_description_from_docstring(docstring)
        return res

    def __str__(self) -> str:
        """Return the YAML form of the component, falling back to the parent ``__str__`` on failure.

        :return: The string representation of the component.
        :rtype: str
        """
        try:
            as_yaml: str = self._to_yaml()
            return as_yaml
        # Only ordinary errors trigger the fallback; KeyboardInterrupt and SystemExit must propagate.
        except Exception:  # pylint: disable=W0718
            fallback: str = super(CommandComponent, self).__str__()
            return fallback