# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
# pylint: disable=protected-access
import json
import os
import re
import uuid
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, cast
from marshmallow.exceptions import ValidationError as SchemaValidationError
from azure.ai.ml._artifacts._artifact_utilities import _upload_and_generate_remote_uri
from azure.ai.ml._azure_environments import _get_aml_resource_id_from_metadata, _resource_to_scopes
from azure.ai.ml._exception_helper import log_and_raise_error
from azure.ai.ml._restclient.v2020_09_01_dataplanepreview.models import BatchJobResource
from azure.ai.ml._restclient.v2023_10_01 import AzureMachineLearningServices as ServiceClient102023
from azure.ai.ml._schema._deployment.batch.batch_job import BatchJobSchema
from azure.ai.ml._scope_dependent_operations import (
OperationConfig,
OperationsContainer,
OperationScope,
_ScopeDependentOperations,
)
from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
from azure.ai.ml._utils._arm_id_utils import is_ARM_id_for_resource, remove_aml_prefix
from azure.ai.ml._utils._azureml_polling import AzureMLPolling
from azure.ai.ml._utils._endpoint_utils import convert_v1_dataset_to_v2, validate_response
from azure.ai.ml._utils._http_utils import HttpPipeline
from azure.ai.ml._utils._logger_utils import OpsLogger
from azure.ai.ml._utils.utils import (
_get_mfe_base_url_from_discovery_service,
is_private_preview_enabled,
modified_operation_client,
)
from azure.ai.ml.constants._common import (
ARM_ID_FULL_PREFIX,
AZUREML_REGEX_FORMAT,
BASE_PATH_CONTEXT_KEY,
HTTP_PREFIX,
LONG_URI_REGEX_FORMAT,
PARAMS_OVERRIDE_KEY,
SHORT_URI_REGEX_FORMAT,
AssetTypes,
AzureMLResourceType,
InputTypes,
LROConfigurations,
)
from azure.ai.ml.constants._endpoint import EndpointInvokeFields, EndpointYamlFields
from azure.ai.ml.entities import BatchEndpoint, BatchJob
from azure.ai.ml.entities._inputs_outputs import Input
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, MlException, ValidationErrorType, ValidationException
from azure.core.credentials import TokenCredential
from azure.core.exceptions import HttpResponseError, ServiceRequestError, ServiceResponseError
from azure.core.paging import ItemPaged
from azure.core.polling import LROPoller
from azure.core.tracing.decorator import distributed_trace
from ._operation_orchestrator import OperationOrchestrator
if TYPE_CHECKING:
from azure.ai.ml.operations import DatastoreOperations
ops_logger = OpsLogger(__name__)
module_logger = ops_logger.module_logger
class BatchEndpointOperations(_ScopeDependentOperations):
"""BatchEndpointOperations.
You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it
for you and attaches it as an attribute.
:param operation_scope: Scope variables for the operations classes of an MLClient object.
:type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope
:param operation_config: Common configuration for operations classes of an MLClient object.
:type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig
:param service_client_10_2023: Service client to allow end users to operate on Azure Machine Learning Workspace
resources.
    :type service_client_10_2023: ~azure.ai.ml._restclient.v2023_10_01.AzureMachineLearningServices
:param all_operations: All operations classes of an MLClient object.
:type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer
:param credentials: Credential to use for authentication.
:type credentials: ~azure.core.credentials.TokenCredential
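
    A minimal access sketch (assuming an authenticated :class:`~azure.ai.ml.MLClient`; the
    placeholder values are illustrative):

    .. code-block:: python

        from azure.ai.ml import MLClient
        from azure.identity import DefaultAzureCredential

        # MLClient constructs BatchEndpointOperations internally and exposes
        # it as the batch_endpoints attribute.
        ml_client = MLClient(
            credential=DefaultAzureCredential(),
            subscription_id="<subscription-id>",
            resource_group_name="<resource-group>",
            workspace_name="<workspace>",
        )
        batch_endpoint_ops = ml_client.batch_endpoints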
"""
def __init__(
self,
operation_scope: OperationScope,
operation_config: OperationConfig,
service_client_10_2023: ServiceClient102023,
all_operations: OperationsContainer,
credentials: Optional[TokenCredential] = None,
**kwargs: Any,
):
super(BatchEndpointOperations, self).__init__(operation_scope, operation_config)
ops_logger.update_filter()
self._batch_operation = service_client_10_2023.batch_endpoints
self._batch_deployment_operation = service_client_10_2023.batch_deployments
self._batch_job_endpoint = kwargs.pop("service_client_09_2020_dataplanepreview").batch_job_endpoint
self._all_operations = all_operations
self._credentials = credentials
self._init_kwargs = kwargs
self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline")
@property
def _datastore_operations(self) -> "DatastoreOperations":
from azure.ai.ml.operations import DatastoreOperations
return cast(DatastoreOperations, self._all_operations.all_operations[AzureMLResourceType.DATASTORE])
@distributed_trace
@monitor_with_activity(ops_logger, "BatchEndpoint.List", ActivityType.PUBLICAPI)
def list(self) -> ItemPaged[BatchEndpoint]:
"""List endpoints of the workspace.
:return: A list of endpoints
:rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.BatchEndpoint]
.. admonition:: Example:
.. literalinclude:: ../samples/ml_samples_misc.py
:start-after: [START batch_endpoint_operations_list]
:end-before: [END batch_endpoint_operations_list]
:language: python
:dedent: 8
:caption: List example.
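
        A minimal usage sketch (assuming an authenticated ``MLClient`` named ``ml_client``):

        .. code-block:: python

            # list() returns a lazily evaluated ItemPaged; iterate to fetch pages on demand.
            for endpoint in ml_client.batch_endpoints.list():
                print(endpoint.name)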
"""
return self._batch_operation.list(
resource_group_name=self._resource_group_name,
workspace_name=self._workspace_name,
cls=lambda objs: [BatchEndpoint._from_rest_object(obj) for obj in objs],
**self._init_kwargs,
)
@distributed_trace
@monitor_with_activity(ops_logger, "BatchEndpoint.Get", ActivityType.PUBLICAPI)
def get(
self,
name: str,
) -> BatchEndpoint:
"""Get a Endpoint resource.
:param name: Name of the endpoint.
:type name: str
:return: Endpoint object retrieved from the service.
:rtype: ~azure.ai.ml.entities.BatchEndpoint
.. admonition:: Example:
.. literalinclude:: ../samples/ml_samples_misc.py
:start-after: [START batch_endpoint_operations_get]
:end-before: [END batch_endpoint_operations_get]
:language: python
:dedent: 8
:caption: Get endpoint example.
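
        A minimal usage sketch (assuming an ``MLClient`` named ``ml_client``; the endpoint
        name is hypothetical):

        .. code-block:: python

            endpoint = ml_client.batch_endpoints.get(name="my-batch-endpoint")
            print(endpoint.scoring_uri)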
"""
# first get the endpoint
endpoint = self._batch_operation.get(
resource_group_name=self._resource_group_name,
workspace_name=self._workspace_name,
endpoint_name=name,
**self._init_kwargs,
)
endpoint_data = BatchEndpoint._from_rest_object(endpoint)
return endpoint_data
@distributed_trace
@monitor_with_activity(ops_logger, "BatchEndpoint.BeginDelete", ActivityType.PUBLICAPI)
def begin_delete(self, name: str) -> LROPoller[None]:
"""Delete a batch Endpoint.
:param name: Name of the batch endpoint.
:type name: str
:return: A poller to track the operation status.
:rtype: ~azure.core.polling.LROPoller[None]
.. admonition:: Example:
.. literalinclude:: ../samples/ml_samples_misc.py
:start-after: [START batch_endpoint_operations_delete]
:end-before: [END batch_endpoint_operations_delete]
:language: python
:dedent: 8
:caption: Delete endpoint example.
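
        A minimal usage sketch (assuming an ``MLClient`` named ``ml_client``; the endpoint
        name is hypothetical):

        .. code-block:: python

            # begin_delete returns an LROPoller; result() blocks until deletion completes.
            poller = ml_client.batch_endpoints.begin_delete(name="my-batch-endpoint")
            poller.result()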
"""
path_format_arguments = {
"endpointName": name,
"resourceGroupName": self._resource_group_name,
"workspaceName": self._workspace_name,
}
delete_poller = self._batch_operation.begin_delete(
resource_group_name=self._resource_group_name,
workspace_name=self._workspace_name,
endpoint_name=name,
polling=AzureMLPolling(
LROConfigurations.POLL_INTERVAL,
path_format_arguments=path_format_arguments,
**self._init_kwargs,
),
polling_interval=LROConfigurations.POLL_INTERVAL,
**self._init_kwargs,
)
return delete_poller
@distributed_trace
@monitor_with_activity(ops_logger, "BatchEndpoint.BeginCreateOrUpdate", ActivityType.PUBLICAPI)
def begin_create_or_update(self, endpoint: BatchEndpoint) -> LROPoller[BatchEndpoint]:
"""Create or update a batch endpoint.
:param endpoint: The endpoint entity.
:type endpoint: ~azure.ai.ml.entities.BatchEndpoint
:return: A poller to track the operation status.
:rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.BatchEndpoint]
.. admonition:: Example:
.. literalinclude:: ../samples/ml_samples_misc.py
:start-after: [START batch_endpoint_operations_create_or_update]
:end-before: [END batch_endpoint_operations_create_or_update]
:language: python
:dedent: 8
:caption: Create endpoint example.
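
        A minimal usage sketch (assuming an ``MLClient`` named ``ml_client``; the endpoint
        name and description are hypothetical):

        .. code-block:: python

            from azure.ai.ml.entities import BatchEndpoint

            endpoint = BatchEndpoint(name="my-batch-endpoint", description="Example endpoint")
            # result() blocks until provisioning finishes and returns the BatchEndpoint.
            created = ml_client.batch_endpoints.begin_create_or_update(endpoint).result()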
"""
try:
location = self._get_workspace_location()
endpoint_resource = endpoint._to_rest_batch_endpoint(location=location)
poller = self._batch_operation.begin_create_or_update(
resource_group_name=self._resource_group_name,
workspace_name=self._workspace_name,
endpoint_name=endpoint.name,
body=endpoint_resource,
polling=True,
**self._init_kwargs,
cls=lambda response, deserialized, headers: BatchEndpoint._from_rest_object(deserialized),
)
return poller
except Exception as ex:
if isinstance(ex, (ValidationException, SchemaValidationError)):
log_and_raise_error(ex)
raise ex
@distributed_trace
@monitor_with_activity(ops_logger, "BatchEndpoint.Invoke", ActivityType.PUBLICAPI)
def invoke( # pylint: disable=too-many-statements
self,
endpoint_name: str,
*,
deployment_name: Optional[str] = None,
inputs: Optional[Dict[str, Input]] = None,
**kwargs: Any,
) -> BatchJob:
"""Invokes the batch endpoint with the provided payload.
:param endpoint_name: The endpoint name.
:type endpoint_name: str
        :keyword deployment_name: (Optional) The name of a specific deployment to invoke. By default,
            requests are routed to any of the deployments according to the traffic rules.
:paramtype deployment_name: str
        :keyword inputs: (Optional) A dictionary mapping input names to existing data assets,
            public URI files, or folders to use with the deployment.
:paramtype inputs: Dict[str, Input]
:raises ~azure.ai.ml.exceptions.ValidationException: Raised if deployment cannot be successfully validated.
Details will be provided in the error message.
:raises ~azure.ai.ml.exceptions.AssetException: Raised if BatchEndpoint assets
(e.g. Data, Code, Model, Environment) cannot be successfully validated.
Details will be provided in the error message.
:raises ~azure.ai.ml.exceptions.ModelException: Raised if BatchEndpoint model cannot be successfully validated.
Details will be provided in the error message.
:raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if local path provided points to an empty directory.
:return: The invoked batch deployment job.
:rtype: ~azure.ai.ml.entities.BatchJob
.. admonition:: Example:
.. literalinclude:: ../samples/ml_samples_misc.py
:start-after: [START batch_endpoint_operations_invoke]
:end-before: [END batch_endpoint_operations_invoke]
:language: python
:dedent: 8
:caption: Invoke endpoint example.
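
        A minimal usage sketch (assuming an ``MLClient`` named ``ml_client``; the endpoint,
        deployment, input name, and data asset reference are hypothetical):

        .. code-block:: python

            from azure.ai.ml import Input

            job = ml_client.batch_endpoints.invoke(
                endpoint_name="my-batch-endpoint",
                deployment_name="my-deployment",
                inputs={
                    "heart_dataset": Input(type="uri_folder", path="azureml:heart-dataset@latest")
                },
            )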
"""
outputs = kwargs.get("outputs", None)
job_name = kwargs.get("job_name", None)
params_override = kwargs.get("params_override", None) or []
experiment_name = kwargs.get("experiment_name", None)
input = kwargs.get("input", None) # pylint: disable=redefined-builtin
# Until this bug is resolved https://msdata.visualstudio.com/Vienna/_workitems/edit/1446538
if deployment_name:
self._validate_deployment_name(endpoint_name, deployment_name)
if input and isinstance(input, Input):
if HTTP_PREFIX not in input.path:
self._resolve_input(input, os.getcwd())
            # MFE expects a dictionary as input_data, which is why we use
            # "UriFolder" or "UriFile" as keys depending on the input type.
if input.type == "uri_folder":
params_override.append({EndpointYamlFields.BATCH_JOB_INPUT_DATA: {"UriFolder": input}})
elif input.type == "uri_file":
params_override.append({EndpointYamlFields.BATCH_JOB_INPUT_DATA: {"UriFile": input}})
else:
                msg = (
                    "Unsupported input type. Please use a dictionary of either a path on the datastore, "
                    "a public URI, a registered data asset, or a local folder path."
                )
raise ValidationException(
message=msg,
target=ErrorTarget.BATCH_ENDPOINT,
no_personal_data_message=msg,
error_category=ErrorCategory.USER_ERROR,
error_type=ValidationErrorType.INVALID_VALUE,
)
elif inputs:
for key, input_data in inputs.items():
if (
isinstance(input_data, Input)
and input_data.type
not in [InputTypes.NUMBER, InputTypes.BOOLEAN, InputTypes.INTEGER, InputTypes.STRING]
and HTTP_PREFIX not in input_data.path
):
self._resolve_input(input_data, os.getcwd())
params_override.append({EndpointYamlFields.BATCH_JOB_INPUT_DATA: inputs})
properties = {}
if outputs:
params_override.append({EndpointYamlFields.BATCH_JOB_OUTPUT_DATA: outputs})
if job_name:
params_override.append({EndpointYamlFields.BATCH_JOB_NAME: job_name})
if experiment_name:
properties["experimentName"] = experiment_name
if properties:
params_override.append({EndpointYamlFields.BATCH_JOB_PROPERTIES: properties})
        # Batch job doesn't have a Python class; load a REST object using params override.
context = {
BASE_PATH_CONTEXT_KEY: Path(".").parent,
PARAMS_OVERRIDE_KEY: params_override,
}
batch_job = BatchJobSchema(context=context).load(data={})
# update output datastore to arm id if needed
# TODO: Unify datastore name -> arm id logic, TASK: 1104172
request = {}
if (
batch_job.output_dataset
and batch_job.output_dataset.datastore_id
and (not is_ARM_id_for_resource(batch_job.output_dataset.datastore_id))
):
v2_dataset_dictionary = convert_v1_dataset_to_v2(batch_job.output_dataset, batch_job.output_file_name)
batch_job.output_dataset = None
batch_job.output_file_name = None
request = BatchJobResource(properties=batch_job).serialize()
request["properties"]["outputData"] = v2_dataset_dictionary
else:
request = BatchJobResource(properties=batch_job).serialize()
endpoint = self._batch_operation.get(
resource_group_name=self._resource_group_name,
workspace_name=self._workspace_name,
endpoint_name=endpoint_name,
**self._init_kwargs,
)
headers = EndpointInvokeFields.DEFAULT_HEADER
ml_audience_scopes = _resource_to_scopes(_get_aml_resource_id_from_metadata())
module_logger.debug("ml_audience_scopes used: `%s`\n", ml_audience_scopes)
key = self._credentials.get_token(*ml_audience_scopes).token if self._credentials is not None else ""
headers[EndpointInvokeFields.AUTHORIZATION] = f"Bearer {key}"
headers[EndpointInvokeFields.REPEATABILITY_REQUEST_ID] = str(uuid.uuid4())
if deployment_name:
headers[EndpointInvokeFields.MODEL_DEPLOYMENT] = deployment_name
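        # Retry transient connection failures (up to 5 attempts) before surfacing an error.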
retry_attempts = 0
while retry_attempts < 5:
try:
response = self._requests_pipeline.post(
endpoint.properties.scoring_uri,
json=request,
headers=headers,
)
except (ServiceRequestError, ServiceResponseError):
retry_attempts += 1
continue
break
if retry_attempts == 5:
            retry_msg = "Max retry attempts reached while trying to connect to the server. Please check the connection and invoke again."  # pylint: disable=line-too-long
raise MlException(message=retry_msg, no_personal_data_message=retry_msg, target=ErrorTarget.BATCH_ENDPOINT)
validate_response(response)
batch_job = json.loads(response.text())
return BatchJobResource.deserialize(batch_job)
@distributed_trace
@monitor_with_activity(ops_logger, "BatchEndpoint.ListJobs", ActivityType.PUBLICAPI)
def list_jobs(self, endpoint_name: str) -> ItemPaged[BatchJob]:
"""List jobs under the provided batch endpoint deployment. This is only valid for batch endpoint.
:param endpoint_name: The endpoint name
:type endpoint_name: str
:return: List of jobs
:rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.BatchJob]
.. admonition:: Example:
.. literalinclude:: ../samples/ml_samples_misc.py
:start-after: [START batch_endpoint_operations_list_jobs]
:end-before: [END batch_endpoint_operations_list_jobs]
:language: python
:dedent: 8
:caption: List jobs example.
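
        A minimal usage sketch (assuming an ``MLClient`` named ``ml_client``; the endpoint
        name is hypothetical):

        .. code-block:: python

            for job in ml_client.batch_endpoints.list_jobs(endpoint_name="my-batch-endpoint"):
                print(job.name)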
"""
workspace_operations = self._all_operations.all_operations[AzureMLResourceType.WORKSPACE]
mfe_base_uri = _get_mfe_base_url_from_discovery_service(
workspace_operations, self._workspace_name, self._requests_pipeline
)
with modified_operation_client(self._batch_job_endpoint, mfe_base_uri):
result = self._batch_job_endpoint.list(
endpoint_name=endpoint_name,
resource_group_name=self._resource_group_name,
workspace_name=self._workspace_name,
**self._init_kwargs,
)
        # This is necessary as the paged result needs to be resolved inside the context manager
return list(result)
def _get_workspace_location(self) -> str:
return str(
self._all_operations.all_operations[AzureMLResourceType.WORKSPACE].get(self._workspace_name).location
)
def _validate_deployment_name(self, endpoint_name: str, deployment_name: str) -> None:
deployments_list = self._batch_deployment_operation.list(
endpoint_name=endpoint_name,
resource_group_name=self._resource_group_name,
workspace_name=self._workspace_name,
cls=lambda objs: [obj.name for obj in objs],
**self._init_kwargs,
)
if deployments_list:
if deployment_name not in deployments_list:
msg = f"Deployment name {deployment_name} not found for this endpoint"
raise ValidationException(
message=msg.format(deployment_name),
no_personal_data_message=msg.format("[deployment_name]"),
target=ErrorTarget.DEPLOYMENT,
error_category=ErrorCategory.USER_ERROR,
error_type=ValidationErrorType.RESOURCE_NOT_FOUND,
)
else:
msg = "No deployment exists for this endpoint"
raise ValidationException(
message=msg,
no_personal_data_message=msg,
error_category=ErrorCategory.USER_ERROR,
target=ErrorTarget.DEPLOYMENT,
error_type=ValidationErrorType.RESOURCE_NOT_FOUND,
)
def _resolve_input(self, entry: Input, base_path: str) -> None:
# We should not verify anything that is not of type Input
if not isinstance(entry, Input):
return
# Input path should not be empty
if not entry.path:
msg = "Input path can't be empty for batch endpoint invoke"
raise MlException(message=msg, no_personal_data_message=msg)
if entry.type in [InputTypes.NUMBER, InputTypes.BOOLEAN, InputTypes.INTEGER, InputTypes.STRING]:
return
try:
if entry.path.startswith(ARM_ID_FULL_PREFIX):
if not is_ARM_id_for_resource(entry.path, AzureMLResourceType.DATA):
raise ValidationException(
message="Invalid input path",
target=ErrorTarget.BATCH_ENDPOINT,
no_personal_data_message="Invalid input path",
error_type=ValidationErrorType.INVALID_VALUE,
)
elif os.path.isabs(entry.path): # absolute local path, upload, transform to remote url
if entry.type == AssetTypes.URI_FOLDER and not os.path.isdir(entry.path):
raise ValidationException(
message="There is no folder on target path: {}".format(entry.path),
target=ErrorTarget.BATCH_ENDPOINT,
no_personal_data_message="There is no folder on target path",
error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND,
)
if entry.type == AssetTypes.URI_FILE and not os.path.isfile(entry.path):
raise ValidationException(
message="There is no file on target path: {}".format(entry.path),
target=ErrorTarget.BATCH_ENDPOINT,
no_personal_data_message="There is no file on target path",
error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND,
)
# absolute local path
entry.path = _upload_and_generate_remote_uri(
self._operation_scope,
self._datastore_operations,
entry.path,
)
if entry.type == AssetTypes.URI_FOLDER and entry.path and not entry.path.endswith("/"):
entry.path = entry.path + "/"
elif ":" in entry.path or "@" in entry.path: # Check registered file or folder datastore
# If we receive a datastore path in long/short form we don't need
# to get the arm asset id
if re.match(SHORT_URI_REGEX_FORMAT, entry.path) or re.match(LONG_URI_REGEX_FORMAT, entry.path):
return
if is_private_preview_enabled() and re.match(AZUREML_REGEX_FORMAT, entry.path):
return
asset_type = AzureMLResourceType.DATA
entry.path = remove_aml_prefix(entry.path)
orchestrator = OperationOrchestrator(
self._all_operations, self._operation_scope, self._operation_config
)
entry.path = orchestrator.get_asset_arm_id(entry.path, asset_type)
else: # relative local path, upload, transform to remote url
local_path = Path(base_path, entry.path).resolve()
entry.path = _upload_and_generate_remote_uri(
self._operation_scope,
self._datastore_operations,
local_path,
)
if entry.type == AssetTypes.URI_FOLDER and entry.path and not entry.path.endswith("/"):
entry.path = entry.path + "/"
except (MlException, HttpResponseError) as e:
raise e
except Exception as e:
raise ValidationException(
message=f"Supported input path value are: path on the datastore, public URI, "
"a registered data asset, or a local folder path.\n"
f"Met {type(e)}:\n{e}",
target=ErrorTarget.BATCH_ENDPOINT,
no_personal_data_message="Supported input path value are: path on the datastore, "
"public URI, a registered data asset, or a local folder path.",
error=e,
error_type=ValidationErrorType.INVALID_VALUE,
) from e