# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
# pylint: disable=protected-access
import json
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from azure.ai.ml._artifacts._artifact_utilities import _check_and_upload_path
# cspell:disable-next-line
from azure.ai.ml._restclient.azure_ai_assets_v2024_04_01.azureaiassetsv20240401 import (
MachineLearningServicesClient as AzureAiAssetsClient042024,
)
# cspell:disable-next-line
from azure.ai.ml._restclient.azure_ai_assets_v2024_04_01.azureaiassetsv20240401.models import Index as RestIndex
from azure.ai.ml._restclient.v2023_04_01_preview.models import ListViewType
from azure.ai.ml._scope_dependent_operations import (
OperationConfig,
OperationsContainer,
OperationScope,
_ScopeDependentOperations,
)
from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
from azure.ai.ml._utils._asset_utils import (
_resolve_label_to_asset,
_validate_auto_delete_setting_in_data_output,
_validate_workspace_managed_datastore,
)
from azure.ai.ml._utils._http_utils import HttpPipeline
from azure.ai.ml._utils._logger_utils import OpsLogger
from azure.ai.ml._utils.utils import _get_base_urls_from_discovery_service
from azure.ai.ml.constants._common import AssetTypes, AzureMLResourceType, WorkspaceDiscoveryUrlKey
from azure.ai.ml.dsl import pipeline
from azure.ai.ml.entities import PipelineJob, PipelineJobSettings
from azure.ai.ml.entities._assets import Index
from azure.ai.ml.entities._credentials import ManagedIdentityConfiguration, UserIdentityConfiguration
from azure.ai.ml.entities._indexes import (
AzureAISearchConfig,
GitSource,
IndexDataSource,
LocalSource,
ModelConfiguration,
)
from azure.ai.ml.entities._indexes.data_index_func import index_data as index_data_func
from azure.ai.ml.entities._indexes.entities.data_index import (
CitationRegex,
Data,
DataIndex,
Embedding,
IndexSource,
IndexStore,
)
from azure.ai.ml.entities._indexes.utils._open_ai_utils import build_connection_id, build_open_ai_protocol
from azure.ai.ml.entities._inputs_outputs import Input
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
from azure.ai.ml.operations._datastore_operations import DatastoreOperations
from azure.core.credentials import TokenCredential
ops_logger = OpsLogger(__name__)
module_logger = ops_logger.module_logger
[docs]
class IndexOperations(_ScopeDependentOperations):
"""Represents a client for performing operations on index assets.
You should not instantiate this class directly. Instead, you should create MLClient and use this client via the
property MLClient.index
"""
def __init__(
self,
*,
operation_scope: OperationScope,
operation_config: OperationConfig,
credential: TokenCredential,
datastore_operations: DatastoreOperations,
all_operations: OperationsContainer,
**kwargs: Any,
):
super().__init__(operation_scope, operation_config)
ops_logger.update_filter()
self._credential = credential
# Dataplane service clients are lazily created as they are needed
self.__azure_ai_assets_client: Optional[AzureAiAssetsClient042024] = None
# Kwargs to propagate to dataplane service clients
self._service_client_kwargs: Dict[str, Any] = kwargs.pop("_service_client_kwargs", {})
self._all_operations = all_operations
self._datastore_operation: DatastoreOperations = datastore_operations
self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline")
# Maps a label to a function which given an asset name,
# returns the asset associated with the label
self._managed_label_resolver: Dict[str, Callable[[str], Index]] = {"latest": self._get_latest_version}
@property
def _azure_ai_assets(self) -> AzureAiAssetsClient042024:
"""Lazily instantiated client for azure ai assets api.
.. note::
Property is lazily instantiated since the api's base url depends on the user's workspace, and
must be fetched remotely.
"""
if self.__azure_ai_assets_client is None:
workspace_operations = self._all_operations.all_operations[AzureMLResourceType.WORKSPACE]
endpoint = _get_base_urls_from_discovery_service(
workspace_operations, self._operation_scope.workspace_name, self._requests_pipeline
)[WorkspaceDiscoveryUrlKey.API]
self.__azure_ai_assets_client = AzureAiAssetsClient042024(
endpoint=endpoint,
subscription_id=self._operation_scope.subscription_id,
resource_group_name=self._operation_scope.resource_group_name,
workspace_name=self._operation_scope.workspace_name,
credential=self._credential,
**self._service_client_kwargs,
)
return self.__azure_ai_assets_client
[docs]
@monitor_with_activity(ops_logger, "Index.CreateOrUpdate", ActivityType.PUBLICAPI)
def create_or_update(self, index: Index, **kwargs) -> Index:
"""Returns created or updated index asset.
If not already in storage, asset will be uploaded to the workspace's default datastore.
:param index: Index asset object.
:type index: Index
:return: Index asset object.
:rtype: ~azure.ai.ml.entities.Index
"""
if not index.name:
msg = "Must specify a name."
raise ValidationException(
message=msg,
target=ErrorTarget.INDEX,
no_personal_data_message=msg,
error_category=ErrorCategory.USER_ERROR,
error_type=ValidationErrorType.MISSING_FIELD,
)
if not index.version:
if not index._auto_increment_version:
msg = "Must specify a version."
raise ValidationException(
message=msg,
target=ErrorTarget.INDEX,
no_personal_data_message=msg,
error_category=ErrorCategory.USER_ERROR,
error_type=ValidationErrorType.MISSING_FIELD,
)
next_version = self._azure_ai_assets.indexes.get_next_version(index.name).next_version
if next_version is None:
msg = "Version not specified, could not automatically increment version. Set a version to resolve."
raise ValidationException(
message=msg,
target=ErrorTarget.INDEX,
no_personal_data_message=msg,
error_category=ErrorCategory.SYSTEM_ERROR,
error_type=ValidationErrorType.MISSING_FIELD,
)
index.version = str(next_version)
_ = _check_and_upload_path(
artifact=index,
asset_operations=self,
datastore_name=index.datastore,
artifact_type=ErrorTarget.INDEX,
show_progress=self._show_progress,
)
return Index._from_rest_object(
self._azure_ai_assets.indexes.create_or_update(
name=index.name, version=index.version, body=index._to_rest_object(), **kwargs
)
)
[docs]
@monitor_with_activity(ops_logger, "Index.Get", ActivityType.PUBLICAPI)
def get(self, name: str, *, version: Optional[str] = None, label: Optional[str] = None, **kwargs) -> Index:
"""Returns information about the specified index asset.
:param str name: Name of the index asset.
:keyword Optional[str] version: Version of the index asset. Mutually exclusive with `label`.
:keyword Optional[str] label: The label of the index asset. Mutually exclusive with `version`.
:raises ~azure.ai.ml.exceptions.ValidationException: Raised if Index cannot be successfully validated.
Details will be provided in the error message.
:return: Index asset object.
:rtype: ~azure.ai.ml.entities.Index
"""
if version and label:
msg = "Cannot specify both version and label."
raise ValidationException(
message=msg,
target=ErrorTarget.INDEX,
no_personal_data_message=msg,
error_category=ErrorCategory.USER_ERROR,
error_type=ValidationErrorType.INVALID_VALUE,
)
if label:
return _resolve_label_to_asset(self, name, label)
if not version:
msg = "Must provide either version or label."
raise ValidationException(
message=msg,
target=ErrorTarget.INDEX,
no_personal_data_message=msg,
error_category=ErrorCategory.USER_ERROR,
error_type=ValidationErrorType.MISSING_FIELD,
)
index_version_resource = self._azure_ai_assets.indexes.get(name=name, version=version, **kwargs)
return Index._from_rest_object(index_version_resource)
def _get_latest_version(self, name: str) -> Index:
return Index._from_rest_object(self._azure_ai_assets.indexes.get_latest(name))
[docs]
@monitor_with_activity(ops_logger, "Index.List", ActivityType.PUBLICAPI)
def list(
self, name: Optional[str] = None, *, list_view_type: ListViewType = ListViewType.ACTIVE_ONLY, **kwargs
) -> Iterable[Index]:
"""List all Index assets in workspace.
If an Index is specified by name, list all version of that Index.
:param name: Name of the model.
:type name: Optional[str]
:keyword list_view_type: View type for including/excluding (for example) archived models.
Defaults to :attr:`ListViewType.ACTIVE_ONLY`.
:paramtype list_view_type: ListViewType
:return: An iterator like instance of Index objects
:rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.Index]
"""
def cls(rest_indexes: Iterable[RestIndex]) -> List[Index]:
return [Index._from_rest_object(i) for i in rest_indexes]
if name is None:
return self._azure_ai_assets.indexes.list_latest(cls=cls, **kwargs)
return self._azure_ai_assets.indexes.list(name, list_view_type=list_view_type, cls=cls, **kwargs)
[docs]
def build_index(
self,
*,
######## required args ##########
name: str,
embeddings_model_config: ModelConfiguration,
######## chunking information ##########
data_source_citation_url: Optional[str] = None,
tokens_per_chunk: Optional[int] = None,
token_overlap_across_chunks: Optional[int] = None,
input_glob: Optional[str] = None,
######## other generic args ########
document_path_replacement_regex: Optional[str] = None,
######## Azure AI Search index info ########
index_config: Optional[AzureAISearchConfig] = None, # todo better name?
######## data source info ########
input_source: Union[IndexDataSource, str],
input_source_credential: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None,
) -> Union["Index", "Job"]: # type: ignore[name-defined]
"""Builds an index on the cloud using the Azure AI Resources service.
:keyword name: The name of the index to be created.
:paramtype name: str
:keyword embeddings_model_config: Model config for the embedding model.
:paramtype embeddings_model_config: ~azure.ai.ml.entities._indexes.ModelConfiguration
:keyword data_source_citation_url: The URL of the data source.
:paramtype data_source_citation_url: Optional[str]
:keyword tokens_per_chunk: The size of chunks to be used for indexing.
:paramtype tokens_per_chunk: Optional[int]
:keyword token_overlap_across_chunks: The amount of overlap between chunks.
:paramtype token_overlap_across_chunks: Optional[int]
:keyword input_glob: The glob pattern to be used for indexing.
:paramtype input_glob: Optional[str]
:keyword document_path_replacement_regex: The regex pattern for replacing document paths.
:paramtype document_path_replacement_regex: Optional[str]
:keyword index_config: The configuration for the ACS output.
:paramtype index_config: Optional[~azure.ai.ml.entities._indexes.AzureAISearchConfig]
:keyword input_source: The input source for the index.
:paramtype input_source: Union[~azure.ai.ml.entities._indexes.IndexDataSource, str]
:keyword input_source_credential: The identity to be used for the index.
:paramtype input_source_credential: Optional[Union[~azure.ai.ml.entities.ManagedIdentityConfiguration,
~azure.ai.ml.entities.UserIdentityConfiguration]]
:return: If the `source_input` is a GitSource, returns a created DataIndex Job object.
:rtype: Union[~azure.ai.ml.entities.Index, ~azure.ai.ml.entities.Job]
:raises ValueError: If the `source_input` is not type ~typing.Str or
~azure.ai.ml.entities._indexes.LocalSource.
"""
if document_path_replacement_regex:
document_path_replacement_regex = json.loads(document_path_replacement_regex)
data_index = DataIndex(
name=name,
source=IndexSource(
input_data=Data(
type="uri_folder",
path=".",
),
input_glob=input_glob,
chunk_size=tokens_per_chunk,
chunk_overlap=token_overlap_across_chunks,
citation_url=data_source_citation_url,
citation_url_replacement_regex=(
CitationRegex(
match_pattern=document_path_replacement_regex["match_pattern"], # type: ignore[index]
replacement_pattern=document_path_replacement_regex[
"replacement_pattern" # type: ignore[index]
],
)
if document_path_replacement_regex
else None
),
),
embedding=Embedding(
model=build_open_ai_protocol(
model=embeddings_model_config.model_name, deployment=embeddings_model_config.deployment_name
),
connection=build_connection_id(embeddings_model_config.connection_name, self._operation_scope),
),
index=(
IndexStore(
type="acs",
connection=build_connection_id(index_config.connection_id, self._operation_scope),
name=index_config.index_name,
)
if index_config is not None
else IndexStore(type="faiss")
),
# name is replaced with a unique value each time the job is run
path=f"azureml://datastores/workspaceblobstore/paths/indexes/{name}/{{name}}",
)
if isinstance(input_source, LocalSource):
data_index.source.input_data = Data(
type="uri_folder",
path=input_source.input_data.path,
)
return self._create_data_indexing_job(data_index=data_index, identity=input_source_credential)
if isinstance(input_source, str):
data_index.source.input_data = Data(
type="uri_folder",
path=input_source,
)
return self._create_data_indexing_job(data_index=data_index, identity=input_source_credential)
if isinstance(input_source, GitSource):
from azure.ai.ml import MLClient
ml_registry = MLClient(credential=self._credential, registry_name="azureml")
git_clone_component = ml_registry.components.get("llm_rag_git_clone", label="latest")
# Clone Git Repo and use as input to index_job
@pipeline(default_compute="serverless") # type: ignore[call-overload]
def git_to_index(
git_url,
branch_name="",
git_connection_id="",
):
git_clone = git_clone_component(git_repository=git_url, branch_name=branch_name)
git_clone.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_GIT"] = git_connection_id
index_job = index_data_func(
description=data_index.description,
data_index=data_index,
input_data_override=git_clone.outputs.output_data,
ml_client=MLClient(
subscription_id=self._subscription_id,
resource_group_name=self._resource_group_name,
workspace_name=self._workspace_name,
credential=self._credential,
),
)
# pylint: disable=no-member
return index_job.outputs
git_index_job = git_to_index(
git_url=input_source.url,
branch_name=input_source.branch_name,
git_connection_id=input_source.connection_id,
)
# Ensure repo cloned each run to get latest, comment out to have first clone reused.
git_index_job.settings.force_rerun = True
# Submit the DataIndex Job
return self._all_operations.all_operations[AzureMLResourceType.JOB].create_or_update(git_index_job)
raise ValueError(f"Unsupported input source type {type(input_source)}")
def _create_data_indexing_job(
self,
data_index: DataIndex,
identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None,
compute: str = "serverless",
serverless_instance_type: Optional[str] = None,
input_data_override: Optional[Input] = None,
submit_job: bool = True,
**kwargs,
) -> PipelineJob:
"""
Returns the data import job that is creating the data asset.
:param data_index: DataIndex object.
:type data_index: azure.ai.ml.entities._dataindex
:param identity: Identity configuration for the job.
:type identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]]
:param compute: The compute target to use for the job. Default: "serverless".
:type compute: str
:param serverless_instance_type: The instance type to use for serverless compute.
:type serverless_instance_type: Optional[str]
:param input_data_override: Input data override for the job.
Used to pipe output of step into DataIndex Job in a pipeline.
:type input_data_override: Optional[Input]
:param submit_job: Whether to submit the job to the service. Default: True.
:type submit_job: bool
:return: data import job object.
:rtype: ~azure.ai.ml.entities.PipelineJob.
"""
# pylint: disable=no-member
from azure.ai.ml import MLClient
default_name = "data_index_" + data_index.name if data_index.name is not None else ""
experiment_name = kwargs.pop("experiment_name", None) or default_name
data_index.type = AssetTypes.URI_FOLDER
# avoid specifying auto_delete_setting in job output now
_validate_auto_delete_setting_in_data_output(data_index.auto_delete_setting)
# block customer specified path on managed datastore
data_index.path = _validate_workspace_managed_datastore(data_index.path)
if "${{name}}" not in str(data_index.path) and "{name}" not in str(data_index.path):
data_index.path = str(data_index.path).rstrip("/") + "/${{name}}"
index_job = index_data_func(
description=data_index.description or kwargs.pop("description", None) or default_name,
name=data_index.name or kwargs.pop("name", None),
display_name=kwargs.pop("display_name", None) or default_name,
experiment_name=experiment_name,
compute=compute,
serverless_instance_type=serverless_instance_type,
data_index=data_index,
ml_client=MLClient(
subscription_id=self._subscription_id,
resource_group_name=self._resource_group_name,
workspace_name=self._workspace_name,
credential=self._credential,
),
identity=identity,
input_data_override=input_data_override,
**kwargs,
)
index_pipeline = PipelineJob(
description=index_job.description,
tags=index_job.tags,
name=index_job.name,
display_name=index_job.display_name,
experiment_name=experiment_name,
properties=index_job.properties or {},
settings=PipelineJobSettings(force_rerun=True, default_compute=compute),
jobs={default_name: index_job},
)
index_pipeline.properties["azureml.mlIndexAssetName"] = data_index.name
index_pipeline.properties["azureml.mlIndexAssetKind"] = data_index.index.type
index_pipeline.properties["azureml.mlIndexAssetSource"] = kwargs.pop("mlindex_asset_source", "Data Asset")
if submit_job:
return self._all_operations.all_operations[AzureMLResourceType.JOB].create_or_update(
job=index_pipeline, skip_validation=True, **kwargs
)
return index_pipeline