# Source code for azure.ai.projects.models._patch

# pylint: disable=too-many-lines
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
"""Customize generated code here.

Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
"""
import asyncio
import base64
import datetime
import inspect
import itertools
import json
import logging
import math
import re
from abc import ABC, abstractmethod
from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Dict,
    Generic,
    Iterator,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
    overload,
)

from azure.core.credentials import AccessToken, TokenCredential
from azure.core.credentials_async import AsyncTokenCredential

from ._enums import AgentStreamEvent, ConnectionType, MessageRole
from ._models import (
    AzureAISearchResource,
    AzureAISearchToolDefinition,
    AzureFunctionDefinition,
    AzureFunctionStorageQueue,
    AzureFunctionToolDefinition,
    AzureFunctionBinding,
    BingGroundingToolDefinition,
    CodeInterpreterToolDefinition,
    CodeInterpreterToolResource,
    FileSearchToolDefinition,
    FileSearchToolResource,
    FunctionDefinition,
    FunctionToolDefinition,
    GetConnectionResponse,
    IndexResource,
    MessageImageFileContent,
    MessageTextContent,
    MessageTextFileCitationAnnotation,
    MessageTextFilePathAnnotation,
    MicrosoftFabricToolDefinition,
    OpenApiAuthDetails,
    OpenApiToolDefinition,
    OpenApiFunctionDefinition,
    RequiredFunctionToolCall,
    RunStep,
    RunStepDeltaChunk,
    SharepointToolDefinition,
    SubmitToolOutputsAction,
    ThreadRun,
    ToolConnection,
    ToolConnectionList,
    ToolDefinition,
    ToolResources,
    MessageDeltaTextContent,
    VectorStoreDataSource,
)

from ._models import MessageDeltaChunk as MessageDeltaChunkGenerated
from ._models import ThreadMessage as ThreadMessageGenerated
from ._models import OpenAIPageableListOfThreadMessage as OpenAIPageableListOfThreadMessageGenerated
from ._models import MessageAttachment as MessageAttachmentGenerated

from .. import _types


logger = logging.getLogger(__name__)

StreamEventData = Union["MessageDeltaChunk", "ThreadMessage", ThreadRun, RunStep, str]


def _filter_parameters(model_class: Type, parameters: Dict[str, Any]) -> Dict[str, Any]:
    """
    Remove the parameters, non present in class public fields; return shallow copy of a dictionary.

    **Note:** Classes inherited from the model check that the parameters are present
    in the list of attributes and if they are not, the error is being raised. This check may not
    be relevant for classes, not inherited from azure.ai.projects._model_base.Model.
    :param Type model_class: The class of model to be used.
    :param parameters: The parsed dictionary with parameters.
    :type parameters: Union[str, Dict[str, Any]]
    :return: The dictionary with all invalid parameters removed.
    :rtype: Dict[str, Any]
    """
    new_params = {}
    valid_parameters = set(
        filter(
            lambda x: not x.startswith("_") and hasattr(model_class.__dict__[x], "_type"), model_class.__dict__.keys()
        )
    )
    for k in filter(lambda x: x in valid_parameters, parameters.keys()):
        new_params[k] = parameters[k]
    return new_params


def _safe_instantiate(
    model_class: Type, parameters: Union[str, Dict[str, Any]], *, generated_class: Optional[Type] = None
) -> StreamEventData:
    """
    Instantiate class with the set of parameters from the server.

    :param Type model_class: The class of model to be used.
    :param parameters: The parsed dictionary with parameters.
    :type parameters: Union[str, Dict[str, Any]]
    :keyword Optional[Type] generated_class: The optional generated type.
    :return: An instance of model_class if parameters is a dictionary; otherwise the parameters themselves.
    :rtype: Any
    """
    if not generated_class:
        generated_class = model_class
    if not isinstance(parameters, dict):
        return parameters
    return cast(StreamEventData, model_class(**_filter_parameters(generated_class, parameters)))


def _parse_event(event_data_str: str) -> Tuple[str, StreamEventData]:
    event_lines = event_data_str.strip().split("\n")
    event_type: Optional[str] = None
    event_data = ""
    event_obj: StreamEventData
    for line in event_lines:
        if line.startswith("event:"):
            event_type = line.split(":", 1)[1].strip()
        elif line.startswith("data:"):
            event_data = line.split(":", 1)[1].strip()

    if not event_type:
        raise ValueError("Event type not specified in the event data.")

    try:
        parsed_data: Union[str, Dict[str, StreamEventData]] = cast(Dict[str, StreamEventData], json.loads(event_data))
    except json.JSONDecodeError:
        parsed_data = event_data

    # Workaround for service bug: Rename 'expires_at' to 'expired_at'
    if event_type.startswith("thread.run.step") and isinstance(parsed_data, dict) and "expires_at" in parsed_data:
        parsed_data["expired_at"] = parsed_data.pop("expires_at")

    # Map to the appropriate class instance
    if event_type in {
        AgentStreamEvent.THREAD_RUN_CREATED.value,
        AgentStreamEvent.THREAD_RUN_QUEUED.value,
        AgentStreamEvent.THREAD_RUN_INCOMPLETE.value,
        AgentStreamEvent.THREAD_RUN_IN_PROGRESS.value,
        AgentStreamEvent.THREAD_RUN_REQUIRES_ACTION.value,
        AgentStreamEvent.THREAD_RUN_COMPLETED.value,
        AgentStreamEvent.THREAD_RUN_FAILED.value,
        AgentStreamEvent.THREAD_RUN_CANCELLING.value,
        AgentStreamEvent.THREAD_RUN_CANCELLED.value,
        AgentStreamEvent.THREAD_RUN_EXPIRED.value,
    }:
        event_obj = _safe_instantiate(ThreadRun, parsed_data)
    elif event_type in {
        AgentStreamEvent.THREAD_RUN_STEP_CREATED.value,
        AgentStreamEvent.THREAD_RUN_STEP_IN_PROGRESS.value,
        AgentStreamEvent.THREAD_RUN_STEP_COMPLETED.value,
        AgentStreamEvent.THREAD_RUN_STEP_FAILED.value,
        AgentStreamEvent.THREAD_RUN_STEP_CANCELLED.value,
        AgentStreamEvent.THREAD_RUN_STEP_EXPIRED.value,
    }:
        event_obj = _safe_instantiate(RunStep, parsed_data)
    elif event_type in {
        AgentStreamEvent.THREAD_MESSAGE_CREATED.value,
        AgentStreamEvent.THREAD_MESSAGE_IN_PROGRESS.value,
        AgentStreamEvent.THREAD_MESSAGE_COMPLETED.value,
        AgentStreamEvent.THREAD_MESSAGE_INCOMPLETE.value,
    }:
        event_obj = _safe_instantiate(ThreadMessage, parsed_data, generated_class=ThreadMessageGenerated)
    elif event_type == AgentStreamEvent.THREAD_MESSAGE_DELTA.value:
        event_obj = _safe_instantiate(MessageDeltaChunk, parsed_data, generated_class=MessageDeltaChunkGenerated)

    elif event_type == AgentStreamEvent.THREAD_RUN_STEP_DELTA.value:
        event_obj = _safe_instantiate(RunStepDeltaChunk, parsed_data)
    else:
        event_obj = str(parsed_data)

    return event_type, event_obj
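
# Illustrative sketch (not part of the SDK): _parse_event consumes one server-sent event block
# consisting of an "event:" line and a "data:" line. The event body below is hypothetical and
# trimmed to the fields relevant here.
#
#     event_type, event_obj = _parse_event(
#         'event: thread.message.delta\n'
#         'data: {"id": "msg_123", "object": "thread.message.delta", "delta": {"content": []}}'
#     )
#     # event_type == "thread.message.delta"; event_obj is a MessageDeltaChunk instance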


class ConnectionProperties:
    """The properties of a single connection.

    :ivar id: A unique identifier for the connection.
    :vartype id: str
    :ivar name: The friendly name of the connection.
    :vartype name: str
    :ivar authentication_type: The authentication type used by the connection.
    :vartype authentication_type: ~azure.ai.projects.models._models.AuthenticationType
    :ivar connection_type: The connection type.
    :vartype connection_type: ~azure.ai.projects.models._models.ConnectionType
    :ivar endpoint_url: The endpoint URL associated with this connection.
    :vartype endpoint_url: str
    :ivar key: The api-key to be used when accessing the connection.
    :vartype key: str
    :ivar token_credential: The TokenCredential to be used when accessing the connection.
    :vartype token_credential: ~azure.core.credentials.TokenCredential
    """

    def __init__(
        self,
        *,
        connection: GetConnectionResponse,
        token_credential: Union[TokenCredential, AsyncTokenCredential, None] = None,
    ) -> None:
        self.id = connection.id
        self.name = connection.name
        self.authentication_type = connection.properties.auth_type
        self.connection_type = cast(ConnectionType, connection.properties.category)
        self.endpoint_url = (
            connection.properties.target[:-1]
            if connection.properties.target.endswith("/")
            else connection.properties.target
        )
        self.key: Optional[str] = None
        if hasattr(connection.properties, "credentials"):
            if hasattr(connection.properties.credentials, "key"):  # type: ignore
                self.key = connection.properties.credentials.key  # type: ignore
        self.token_credential = token_credential

    def to_evaluator_model_config(
        self, deployment_name: str, api_version: str, *, include_credentials: bool = False
    ) -> Dict[str, str]:
        """Get a model configuration, suitable for use with evaluators, from this connection.

        :param deployment_name: Deployment name to build the model configuration.
        :type deployment_name: str
        :param api_version: API version used by the model deployment.
        :type api_version: str
        :keyword include_credentials: Include credentials in the model configuration. If set to True,
            the model configuration will have the key field set to the actual key value. If set to False,
            the model configuration will have the key field set to the connection id. To get the secret,
            the connection.get method should be called with include_credentials set to True.
        :paramtype include_credentials: bool
        :returns: Model configuration dictionary.
        :rtype: Dict[str, str]
        """
        connection_type = self.connection_type.value
        if self.connection_type.value == ConnectionType.AZURE_OPEN_AI:
            connection_type = "azure_openai"

        if self.authentication_type == "ApiKey":
            model_config = {
                "azure_deployment": deployment_name,
                "azure_endpoint": self.endpoint_url,
                "type": connection_type,
                "api_version": api_version,
                "api_key": self.key if include_credentials and self.key else f"{self.id}/credentials/key",
            }
        else:
            model_config = {
                "azure_deployment": deployment_name,
                "azure_endpoint": self.endpoint_url,
                "type": self.connection_type,
                "api_version": api_version,
            }
        return model_config

    def __str__(self):
        out = "{\n"
        out += f' "name": "{self.name}",\n'
        out += f' "id": "{self.id}",\n'
        out += f' "authentication_type": "{self.authentication_type}",\n'
        out += f' "connection_type": "{self.connection_type}",\n'
        out += f' "endpoint_url": "{self.endpoint_url}",\n'
        if self.key:
            out += ' "key": "REDACTED"\n'
        else:
            out += ' "key": null\n'
        if self.token_credential:
            out += ' "token_credential": "REDACTED"\n'
        else:
            out += ' "token_credential": null\n'
        out += "}\n"
        return out


# TODO: Look into adding an async version of this class
class SASTokenCredential(TokenCredential):
    def __init__(
        self,
        *,
        sas_token: str,
        credential: TokenCredential,
        subscription_id: str,
        resource_group_name: str,
        project_name: str,
        connection_name: str,
    ):
        self._sas_token = sas_token
        self._credential = credential
        self._subscription_id = subscription_id
        self._resource_group_name = resource_group_name
        self._project_name = project_name
        self._connection_name = connection_name
        self._expires_on = SASTokenCredential._get_expiration_date_from_token(self._sas_token)
        logger.debug("[SASTokenCredential.__init__] Exit. Given token expires on %s.", self._expires_on)

    @classmethod
    def _get_expiration_date_from_token(cls, jwt_token: str) -> datetime.datetime:
        payload = jwt_token.split(".")[1]
        padded_payload = payload + "=" * (4 - len(payload) % 4)  # Add padding if necessary
        decoded_bytes = base64.urlsafe_b64decode(padded_payload)
        decoded_str = decoded_bytes.decode("utf-8")
        decoded_payload = json.loads(decoded_str)
        expiration_date = decoded_payload.get("exp")
        return datetime.datetime.fromtimestamp(expiration_date, datetime.timezone.utc)

    def _refresh_token(self) -> None:
        logger.debug("[SASTokenCredential._refresh_token] Enter")
        from azure.ai.projects import AIProjectClient

        project_client = AIProjectClient(
            credential=self._credential,
            # Since we are only going to use the "connections" operations, we don't need to supply an endpoint.
            # http://management.azure.com is hard coded in the SDK.
            endpoint="not-needed",
            subscription_id=self._subscription_id,
            resource_group_name=self._resource_group_name,
            project_name=self._project_name,
        )

        connection = project_client.connections.get(connection_name=self._connection_name, include_credentials=True)

        self._sas_token = ""
        if connection is not None and connection.token_credential is not None:
            sas_credential = cast(SASTokenCredential, connection.token_credential)
            self._sas_token = sas_credential._sas_token  # pylint: disable=protected-access
        self._expires_on = SASTokenCredential._get_expiration_date_from_token(self._sas_token)
        logger.debug("[SASTokenCredential._refresh_token] Exit. New token expires on %s.", self._expires_on)

    def get_token(
        self,
        *scopes: str,
        claims: Optional[str] = None,
        tenant_id: Optional[str] = None,
        enable_cae: bool = False,
        **kwargs: Any,
    ) -> AccessToken:
        """Request an access token for `scopes`.

        :param str scopes: The type of access needed.
        :keyword str claims: Additional claims required in the token, such as those returned in a resource
            provider's claims challenge following an authorization failure.
        :keyword str tenant_id: Optional tenant to include in the token request.
        :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) for the
            requested token. Defaults to False.
        :rtype: AccessToken
        :return: An AccessToken instance containing the token string and its expiration time in Unix time.
        """
        logger.debug("[SASTokenCredential.get_token] Enter")
        if self._expires_on < datetime.datetime.now(datetime.timezone.utc):
            self._refresh_token()
        return AccessToken(self._sas_token, math.floor(self._expires_on.timestamp()))


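# Illustrative sketch (not part of the SDK): how _get_expiration_date_from_token reads the expiry.
# A SAS token here is JWT-shaped ("<header>.<payload>.<signature>"); the base64url payload segment
# carries an "exp" claim in Unix time. The values below are hypothetical.
#
#     payload = jwt_token.split(".")[1]                    # e.g. base64url of '{"exp": 1700000000}'
#     payload += "=" * (4 - len(payload) % 4)              # restore stripped base64 padding
#     exp = json.loads(base64.urlsafe_b64decode(payload))["exp"]
#     datetime.datetime.fromtimestamp(exp, datetime.timezone.utc)  # 2023-11-14 22:13:20+00:00

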
# Define type_map to translate Python type annotations to JSON Schema types
type_map = {
    "str": "string",
    "int": "integer",
    "float": "number",
    "bool": "boolean",
    "NoneType": "null",
    "list": "array",
    "dict": "object",
}


def _map_type(annotation) -> Dict[str, Any]:  # pylint: disable=too-many-return-statements
    if annotation == inspect.Parameter.empty:
        return {"type": "string"}  # Default type if annotation is missing

    origin = get_origin(annotation)

    if origin in {list, List}:
        args = get_args(annotation)
        item_type = args[0] if args else str
        return {"type": "array", "items": _map_type(item_type)}
    if origin in {dict, Dict}:
        return {"type": "object"}
    if origin is Union:
        args = get_args(annotation)
        # If Union contains None, it is an optional parameter
        if type(None) in args:
            # If Union contains only one non-None type, it is a nullable parameter
            non_none_args = [arg for arg in args if arg is not type(None)]
            if len(non_none_args) == 1:
                schema = _map_type(non_none_args[0])
                if "type" in schema:
                    if isinstance(schema["type"], str):
                        schema["type"] = [schema["type"], "null"]
                    elif "null" not in schema["type"]:
                        schema["type"].append("null")
                else:
                    schema["type"] = ["null"]
                return schema
        # If Union contains multiple types, it is a oneOf parameter
        return {"oneOf": [_map_type(arg) for arg in args]}
    if isinstance(annotation, type):
        schema_type = type_map.get(annotation.__name__, "string")
        return {"type": schema_type}

    return {"type": "string"}  # Fallback to "string" if type is unrecognized


def is_optional(annotation) -> bool:
    origin = get_origin(annotation)
    if origin is Union:
        args = get_args(annotation)
        return type(None) in args
    return False


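# Illustrative sketch (not part of the SDK): JSON Schema fragments produced by _map_type for a few
# common annotations, following the mapping above.
#
#     _map_type(int)              -> {"type": "integer"}
#     _map_type(List[str])        -> {"type": "array", "items": {"type": "string"}}
#     _map_type(Optional[float])  -> {"type": ["number", "null"]}
#     _map_type(Union[int, str])  -> {"oneOf": [{"type": "integer"}, {"type": "string"}]}

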
class MessageDeltaChunk(MessageDeltaChunkGenerated):
    @property
    def text(self) -> str:
        """Get the text content of the delta chunk.

        :rtype: str
        """
        if not self.delta or not self.delta.content:
            return ""
        return "".join(
            content_part.text.value or ""
            for content_part in self.delta.content
            if isinstance(content_part, MessageDeltaTextContent) and content_part.text
        )


class ThreadMessage(ThreadMessageGenerated):
    @property
    def text_messages(self) -> List[MessageTextContent]:
        """Returns all text message contents in the messages.

        :rtype: List[MessageTextContent]
        """
        if not self.content:
            return []
        return [content for content in self.content if isinstance(content, MessageTextContent)]

    @property
    def image_contents(self) -> List[MessageImageFileContent]:
        """Returns all image file contents from image message contents in the messages.

        :rtype: List[MessageImageFileContent]
        """
        if not self.content:
            return []
        return [content for content in self.content if isinstance(content, MessageImageFileContent)]

    @property
    def file_citation_annotations(self) -> List[MessageTextFileCitationAnnotation]:
        """Returns all file citation annotations from text message annotations in the messages.

        :rtype: List[MessageTextFileCitationAnnotation]
        """
        if not self.content:
            return []
        return [
            annotation
            for content in self.content
            if isinstance(content, MessageTextContent)
            for annotation in content.text.annotations
            if isinstance(annotation, MessageTextFileCitationAnnotation)
        ]

    @property
    def file_path_annotations(self) -> List[MessageTextFilePathAnnotation]:
        """Returns all file path annotations from text message annotations in the messages.

        :rtype: List[MessageTextFilePathAnnotation]
        """
        if not self.content:
            return []
        return [
            annotation
            for content in self.content
            if isinstance(content, MessageTextContent)
            for annotation in content.text.annotations
            if isinstance(annotation, MessageTextFilePathAnnotation)
        ]


class MessageAttachment(MessageAttachmentGenerated):
    @overload
    def __init__(
        self,
        *,
        tools: List["FileSearchToolDefinition"],
        file_id: Optional[str] = None,
        data_source: Optional["VectorStoreDataSource"] = None,
    ) -> None: ...

    @overload
    def __init__(
        self,
        *,
        tools: List["CodeInterpreterToolDefinition"],
        file_id: Optional[str] = None,
        data_source: Optional["VectorStoreDataSource"] = None,
    ) -> None: ...

    @overload
    def __init__(
        self,
        *,
        tools: List["_types.MessageAttachmentToolDefinition"],
        file_id: Optional[str] = None,
        data_source: Optional["VectorStoreDataSource"] = None,
    ) -> None: ...

    @overload
    def __init__(self, mapping: Mapping[str, Any]) -> None:
        """
        :param mapping: raw JSON to initialize the model.
        :type mapping: Mapping[str, Any]
        """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)


ToolDefinitionT = TypeVar("ToolDefinitionT", bound=ToolDefinition)
ToolT = TypeVar("ToolT", bound="Tool")


class Tool(ABC, Generic[ToolDefinitionT]):
    """
    An abstract class representing a tool that can be used by an agent.
    """

    @property
    @abstractmethod
    def definitions(self) -> List[ToolDefinitionT]:
        """Get the tool definitions."""

    @property
    @abstractmethod
    def resources(self) -> ToolResources:
        """Get the tool resources."""

    @abstractmethod
    def execute(self, tool_call: Any) -> Any:
        """
        Execute the tool with the provided tool call.

        :param Any tool_call: The tool call to execute.
        :return: The output of the tool operations.
        """


class BaseFunctionTool(Tool[FunctionToolDefinition]):
    """
    A tool that executes user-defined functions.
    """

    def __init__(self, functions: Set[Callable[..., Any]]):
        """
        Initialize FunctionTool with a set of functions.

        :param functions: A set of function objects.
        """
        self._functions = self._create_function_dict(functions)
        self._definitions = self._build_function_definitions(self._functions)

    def add_functions(self, extra_functions: Set[Callable[..., Any]]) -> None:
        """
        Add more functions into this FunctionTool's existing function set.
        If a function with the same name already exists, it is overwritten.

        :param extra_functions: A set of additional functions to be added to the existing
            function set. Functions are defined as callables and may have any number of
            arguments and return types.
        :type extra_functions: Set[Callable[..., Any]]
        """
        # Convert the existing dictionary of { name: function } back into a set
        existing_functions = set(self._functions.values())
        # Merge old + new
        combined = existing_functions.union(extra_functions)
        # Rebuild state
        self._functions = self._create_function_dict(combined)
        self._definitions = self._build_function_definitions(self._functions)

    def _create_function_dict(self, functions: Set[Callable[..., Any]]) -> Dict[str, Callable[..., Any]]:
        return {func.__name__: func for func in functions}

    def _build_function_definitions(self, functions: Dict[str, Any]) -> List[FunctionToolDefinition]:
        specs: List[FunctionToolDefinition] = []
        # Flexible regex to capture ':param <name>: <description>'
        param_pattern = re.compile(
            r"""
            ^\s*                                   # Optional leading whitespace
            :param                                 # Literal ':param'
            \s+                                    # At least one whitespace character
            (?P<name>[^:\s\(\)]+)                  # Parameter name (no spaces, colons, or parentheses)
            (?:\s*\(\s*(?P<type>[^)]+?)\s*\))?     # Optional type in parentheses, allowing internal spaces
            \s*:\s*                                # Colon ':' surrounded by optional whitespace
            (?P<description>.+)                    # Description (rest of the line)
            """,
            re.VERBOSE,
        )

        for name, func in functions.items():
            sig = inspect.signature(func)
            params = sig.parameters
            docstring = inspect.getdoc(func) or ""
            description = docstring.split("\n", maxsplit=1)[0] if docstring else "No description"

            param_descriptions = {}
            for line in docstring.splitlines():
                line = line.strip()
                match = param_pattern.match(line)
                if match:
                    groups = match.groupdict()
                    param_name = groups.get("name")
                    param_desc = groups.get("description")
                    param_desc = param_desc.strip() if param_desc else "No description"
                    param_descriptions[param_name] = param_desc.strip()

            properties = {}
            required = []
            for param_name, param in params.items():
                param_type_info = _map_type(param.annotation)
                param_description = param_descriptions.get(param_name, "No description")

                properties[param_name] = {**param_type_info, "description": param_description}

                # If the parameter has no default value and is not optional, add it to the required list
                if param.default is inspect.Parameter.empty and not is_optional(param.annotation):
                    required.append(param_name)

            function_def = FunctionDefinition(
                name=name,
                description=description,
                parameters={"type": "object", "properties": properties, "required": required},
            )
            tool_def = FunctionToolDefinition(function=function_def)
            specs.append(tool_def)

        return specs

    def _get_func_and_args(self, tool_call: RequiredFunctionToolCall) -> Tuple[Any, Dict[str, Any]]:
        function_name = tool_call.function.name
        arguments = tool_call.function.arguments

        if function_name not in self._functions:
            logging.error("Function '%s' not found.", function_name)
            raise ValueError(f"Function '{function_name}' not found.")

        function = self._functions[function_name]

        try:
            parsed_arguments = json.loads(arguments)
        except json.JSONDecodeError as e:
            logging.error("Invalid JSON arguments for function '%s': %s", function_name, e)
            raise ValueError(f"Invalid JSON arguments: {e}") from e

        if not isinstance(parsed_arguments, dict):
            logging.error("Arguments must be a JSON object for function '%s'.", function_name)
            raise TypeError("Arguments must be a JSON object.")

        return function, parsed_arguments

    @property
    def definitions(self) -> List[FunctionToolDefinition]:
        """
        Get the function definitions.

        :return: A list of function definitions.
        :rtype: List[ToolDefinition]
        """
        return self._definitions

    @property
    def resources(self) -> ToolResources:
        """
        Get the tool resources for the agent.

        :return: An empty ToolResources as FunctionTool doesn't have specific resources.
        :rtype: ToolResources
        """
        return ToolResources()


class FunctionTool(BaseFunctionTool):

    def execute(self, tool_call: RequiredFunctionToolCall) -> Any:
        function, parsed_arguments = self._get_func_and_args(tool_call)

        try:
            return function(**parsed_arguments) if parsed_arguments else function()
        except TypeError as e:
            error_message = f"Error executing function '{tool_call.function.name}': {e}"
            logging.error(error_message)
            # Return error message as JSON string back to agent in order to make possible self
            # correction to the function call
            return json.dumps({"error": error_message})


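# Illustrative sketch (not part of the SDK): defining a user function and wiring it into a
# FunctionTool. The function name, docstring, and annotations below are hypothetical; the
# definitions property then exposes the FunctionToolDefinition generated from them.
#
#     def fetch_weather(location: str) -> str:
#         """Fetch the weather for a location.
#
#         :param location: The city to look up.
#         """
#         return json.dumps({"location": location, "forecast": "sunny"})
#
#     functions = FunctionTool(functions={fetch_weather})
#     print(functions.definitions[0].function.name)  # "fetch_weather"

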
class AsyncFunctionTool(BaseFunctionTool):

    async def execute(self, tool_call: RequiredFunctionToolCall) -> Any:  # pylint: disable=invalid-overridden-method
        function, parsed_arguments = self._get_func_and_args(tool_call)

        try:
            if inspect.iscoroutinefunction(function):
                return await function(**parsed_arguments) if parsed_arguments else await function()
            return function(**parsed_arguments) if parsed_arguments else function()
        except TypeError as e:
            error_message = f"Error executing function '{tool_call.function.name}': {e}"
            logging.error(error_message)
            # Return error message as JSON string back to agent in order to make possible self correction
            # to the function call
            return json.dumps({"error": error_message})


class AzureAISearchTool(Tool[AzureAISearchToolDefinition]):
    """
    A tool that searches for information using Azure AI Search.
    """

    def __init__(self, index_connection_id: str, index_name: str):
        self.index_list = [IndexResource(index_connection_id=index_connection_id, index_name=index_name)]

    @property
    def definitions(self) -> List[AzureAISearchToolDefinition]:
        """
        Get the Azure AI search tool definitions.

        :return: A list of tool definitions.
        :rtype: List[ToolDefinition]
        """
        return [AzureAISearchToolDefinition()]

    @property
    def resources(self) -> ToolResources:
        """
        Get the Azure AI search resources.

        :return: ToolResources populated with azure_ai_search associated resources.
        :rtype: ToolResources
        """
        return ToolResources(azure_ai_search=AzureAISearchResource(index_list=self.index_list))

    def execute(self, tool_call: Any):
        """
        AI Search tool does not execute client-side.

        :param Any tool_call: The tool call to execute.
        """


class OpenApiTool(Tool[OpenApiToolDefinition]):
    """
    A tool that retrieves information using OpenAPI specs.
    Initialized with an initial API definition (name, description, spec, auth),
    this class also supports adding and removing additional API definitions dynamically.
    """

    def __init__(self, name: str, description: str, spec: Any, auth: OpenApiAuthDetails):
        """
        Constructor initializes the tool with a primary API definition.

        :param name: The name of the API.
        :param description: The API description.
        :param spec: The API specification.
        :param auth: Authentication details for the API.
        :type auth: OpenApiAuthDetails
        """
        self._default_auth = auth
        self._definitions: List[OpenApiToolDefinition] = [
            OpenApiToolDefinition(
                openapi=OpenApiFunctionDefinition(name=name, description=description, spec=spec, auth=auth)
            )
        ]

    @property
    def definitions(self) -> List[OpenApiToolDefinition]:
        """
        Get the list of all API definitions for the tool.

        :return: A list of OpenAPI tool definitions.
        :rtype: List[ToolDefinition]
        """
        return self._definitions

    def add_definition(
        self, name: str, description: str, spec: Any, auth: Optional[OpenApiAuthDetails] = None
    ) -> None:
        """
        Adds a new API definition dynamically.
        Raises a ValueError if a definition with the same name already exists.

        :param name: The name of the API.
        :type name: str
        :param description: The description of the API.
        :type description: str
        :param spec: The API specification.
        :type spec: Any
        :param auth: Optional authentication details for this particular API definition.
            If not provided, the tool's default authentication details will be used.
        :type auth: Optional[OpenApiAuthDetails]
        :raises ValueError: If a definition with the same name exists.
        """
        # Check if a definition with the same name exists.
        if any(definition.openapi.name == name for definition in self._definitions):
            raise ValueError(f"Definition '{name}' already exists and cannot be added again.")

        # Use provided auth if specified, otherwise use default
        auth_to_use = auth if auth is not None else self._default_auth

        new_definition = OpenApiToolDefinition(
            openapi=OpenApiFunctionDefinition(name=name, description=description, spec=spec, auth=auth_to_use)
        )
        self._definitions.append(new_definition)

    def remove_definition(self, name: str) -> None:
        """
        Removes an API definition based on its name.

        :param name: The name of the API definition to remove.
        :type name: str
        :raises ValueError: If the definition with the specified name does not exist.
        """
        for definition in self._definitions:
            if definition.openapi.name == name:
                self._definitions.remove(definition)
                logging.info("Definition '%s' removed. Total definitions: %d.", name, len(self._definitions))
                return
        raise ValueError(f"Definition with the name '{name}' does not exist.")

    @property
    def resources(self) -> ToolResources:
        """
        Get the tool resources for the agent.

        :return: An empty ToolResources as OpenApiTool doesn't have specific resources.
        :rtype: ToolResources
        """
        return ToolResources()

    def execute(self, tool_call: Any) -> None:
        """
        OpenApiTool does not execute client-side.

        :param Any tool_call: The tool call to execute.
        :type tool_call: Any
        """


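# Illustrative sketch (not part of the SDK): creating an OpenApiTool from a parsed OpenAPI spec and
# registering a second API definition. `weather_spec`, `countries_spec`, and `auth` are hypothetical;
# `auth` stands for any concrete OpenApiAuthDetails instance.
#
#     openapi_tool = OpenApiTool(name="get_weather", description="Retrieve weather", spec=weather_spec, auth=auth)
#     openapi_tool.add_definition(name="get_countries", description="Country info", spec=countries_spec)
#     print(len(openapi_tool.definitions))  # 2

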
class AzureFunctionTool(Tool[AzureFunctionToolDefinition]):
    """
    A tool used to inform the agent about an available Azure Function.

    :param name: The Azure Function name.
    :param description: The Azure Function description.
    :param parameters: The description of function parameters.
    :param input_queue: Input queue used by the Azure Function.
    :param output_queue: Output queue used by the Azure Function.
    """

    def __init__(
        self,
        name: str,
        description: str,
        parameters: Dict[str, Any],
        input_queue: AzureFunctionStorageQueue,
        output_queue: AzureFunctionStorageQueue,
    ) -> None:
        self._definitions = [
            AzureFunctionToolDefinition(
                azure_function=AzureFunctionDefinition(
                    function=FunctionDefinition(
                        name=name,
                        description=description,
                        parameters=parameters,
                    ),
                    input_binding=AzureFunctionBinding(storage_queue=input_queue),
                    output_binding=AzureFunctionBinding(storage_queue=output_queue),
                )
            )
        ]

    @property
    def definitions(self) -> List[AzureFunctionToolDefinition]:
        """
        Get the Azure Function tool definitions.

        :rtype: List[ToolDefinition]
        """
        return self._definitions

    @property
    def resources(self) -> ToolResources:
        """
        Get the Azure Function tool resources.

        :rtype: ToolResources
        """
        return ToolResources()

    def execute(self, tool_call: Any) -> Any:
        pass


class ConnectionTool(Tool[ToolDefinitionT]):
    """
    A tool that requires connection ids.
    Used as base class for Bing Grounding, Sharepoint, and Microsoft Fabric.
    """

    def __init__(self, connection_id: str):
        """
        Initialize ConnectionTool with a connection_id.

        :param connection_id: Connection ID used by tool. All connection tools allow only one connection.
        """
        self.connection_ids = [ToolConnection(connection_id=connection_id)]

    @property
    def resources(self) -> ToolResources:
        """
        Get the connection tool resources.

        :rtype: ToolResources
        """
        return ToolResources()

    def execute(self, tool_call: Any) -> Any:
        pass


class BingGroundingTool(ConnectionTool[BingGroundingToolDefinition]):
    """
    A tool that searches for information using Bing.
    """

    @property
    def definitions(self) -> List[BingGroundingToolDefinition]:
        """
        Get the Bing grounding tool definitions.

        :rtype: List[ToolDefinition]
        """
        return [BingGroundingToolDefinition(bing_grounding=ToolConnectionList(connection_list=self.connection_ids))]


class FabricTool(ConnectionTool[MicrosoftFabricToolDefinition]):
    """
    A tool that searches for information using Microsoft Fabric.
    """

    @property
    def definitions(self) -> List[MicrosoftFabricToolDefinition]:
        """
        Get the Microsoft Fabric tool definitions.

        :rtype: List[ToolDefinition]
        """
        return [MicrosoftFabricToolDefinition(fabric_aiskill=ToolConnectionList(connection_list=self.connection_ids))]


class SharepointTool(ConnectionTool[SharepointToolDefinition]):
    """
    A tool that searches for information using Sharepoint.
    """

    @property
    def definitions(self) -> List[SharepointToolDefinition]:
        """
        Get the Sharepoint tool definitions.

        :rtype: List[ToolDefinition]
        """
        return [SharepointToolDefinition(sharepoint_grounding=ToolConnectionList(connection_list=self.connection_ids))]


class FileSearchTool(Tool[FileSearchToolDefinition]):
    """
    A tool that searches for uploaded file information from the created vector stores.

    :param vector_store_ids: A list of vector store IDs to search for files.
    :type vector_store_ids: list[str]
    """

    def __init__(self, vector_store_ids: Optional[List[str]] = None):
        if vector_store_ids is None:
            self.vector_store_ids = set()
        else:
            self.vector_store_ids = set(vector_store_ids)

    def add_vector_store(self, store_id: str) -> None:
        """
        Add a vector store ID to the list of vector stores to search for files.

        :param store_id: The ID of the vector store to search for files.
        :type store_id: str
        """
        self.vector_store_ids.add(store_id)

    def remove_vector_store(self, store_id: str) -> None:
        """
        Remove a vector store ID from the list of vector stores to search for files.

        :param store_id: The ID of the vector store to remove.
        :type store_id: str
        """
        self.vector_store_ids.remove(store_id)

    @property
    def definitions(self) -> List[FileSearchToolDefinition]:
        """
        Get the file search tool definitions.

        :rtype: List[ToolDefinition]
        """
        return [FileSearchToolDefinition()]

    @property
    def resources(self) -> ToolResources:
        """
        Get the file search resources.

        :rtype: ToolResources
        """
        return ToolResources(file_search=FileSearchToolResource(vector_store_ids=list(self.vector_store_ids)))

    def execute(self, tool_call: Any) -> Any:
        pass


class CodeInterpreterTool(Tool[CodeInterpreterToolDefinition]):
    """
    A tool that interprets code files uploaded to the agent.

    :param file_ids: A list of file IDs to interpret.
    :type file_ids: list[str]
    """

    def __init__(self, file_ids: Optional[List[str]] = None):
        if file_ids is None:
            self.file_ids = set()
        else:
            self.file_ids = set(file_ids)

    def add_file(self, file_id: str) -> None:
        """
        Add a file ID to the list of files to interpret.

        :param file_id: The ID of the file to interpret.
        :type file_id: str
        """
        self.file_ids.add(file_id)

    def remove_file(self, file_id: str) -> None:
        """
        Remove a file ID from the list of files to interpret.

        :param file_id: The ID of the file to remove.
        :type file_id: str
        """
        self.file_ids.remove(file_id)

    @property
    def definitions(self) -> List[CodeInterpreterToolDefinition]:
        """
        Get the code interpreter tool definitions.

        :rtype: List[ToolDefinition]
        """
        return [CodeInterpreterToolDefinition()]

    @property
    def resources(self) -> ToolResources:
        """
        Get the code interpreter resources.

        :rtype: ToolResources
        """
        if not self.file_ids:
            return ToolResources()
        return ToolResources(code_interpreter=CodeInterpreterToolResource(file_ids=list(self.file_ids)))

    def execute(self, tool_call: Any) -> Any:
        pass


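# Illustrative sketch (not part of the SDK): the file ID below is hypothetical; resources stays
# empty until at least one file ID has been added.
#
#     code_interpreter = CodeInterpreterTool()
#     code_interpreter.add_file("assistant-file-123")
#     print(code_interpreter.resources.code_interpreter.file_ids)  # ["assistant-file-123"]

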
class BaseToolSet:
    """
    Abstract class for a collection of tools that can be used by an agent.
    """

    def __init__(self) -> None:
        self._tools: List[Tool] = []

    def validate_tool_type(self, tool: Tool) -> None:
        pass

    def add(self, tool: Tool):
        """
        Add a tool to the tool set.

        :param Tool tool: The tool to add.
        :raises ValueError: If a tool of the same type already exists.
        """
        self.validate_tool_type(tool)

        if any(isinstance(existing_tool, type(tool)) for existing_tool in self._tools):
            raise ValueError(f"Tool of type {type(tool).__name__} already exists in the ToolSet.")
        self._tools.append(tool)

    def remove(self, tool_type: Type[Tool]) -> None:
        """
        Remove a tool of the specified type from the tool set.

        :param Type[Tool] tool_type: The type of tool to remove.
        :raises ValueError: If a tool of the specified type is not found.
        """
        for i, tool in enumerate(self._tools):
            if isinstance(tool, tool_type):
                del self._tools[i]
                logging.info("Tool of type %s removed from the ToolSet.", tool_type.__name__)
                return
        raise ValueError(f"Tool of type {tool_type.__name__} not found in the ToolSet.")

    @property
    def definitions(self) -> List[ToolDefinition]:
        """
        Get the definitions for all tools in the tool set.

        :rtype: List[ToolDefinition]
        """
        tools = []
        for tool in self._tools:
            tools.extend(tool.definitions)
        return tools

    @property
    def resources(self) -> ToolResources:
        """
        Get the resources for all tools in the tool set.

        :rtype: ToolResources
        """
        tool_resources: Dict[str, Any] = {}
        for tool in self._tools:
            resources = tool.resources
            for key, value in resources.items():
                if key in tool_resources:
                    if isinstance(tool_resources[key], dict) and isinstance(value, dict):
                        tool_resources[key].update(value)
                else:
                    tool_resources[key] = value
        return self._create_tool_resources_from_dict(tool_resources)

    def _create_tool_resources_from_dict(self, resources: Dict[str, Any]) -> ToolResources:
        """
        Safely converts a dictionary into a ToolResources instance.

        :param resources: A dictionary of tool resources. Should be a mapping accepted by
            ~azure.ai.projects.models.ToolResources.
        :type resources: Dict[str, Any]
        :return: A ToolResources instance.
        :rtype: ToolResources
        """
        try:
            return ToolResources(**resources)
        except TypeError as e:
            logging.error("Error creating ToolResources: %s", e)
            raise ValueError("Invalid resources for ToolResources.") from e

    def get_definitions_and_resources(self) -> Dict[str, Any]:
        """
        Get the definitions and resources for all tools in the tool set.

        :return: A dictionary containing the tool resources and definitions.
        :rtype: Dict[str, Any]
        """
        return {
            "tool_resources": self.resources,
            "tools": self.definitions,
        }

    def get_tool(self, tool_type: Type[ToolT]) -> ToolT:
        """
        Get a tool of the specified type from the tool set.

        :param Type[Tool] tool_type: The type of tool to get.
        :return: The tool of the specified type.
        :rtype: Tool
        :raises ValueError: If a tool of the specified type is not found.
        """
        for tool in self._tools:
            if isinstance(tool, tool_type):
                return cast(ToolT, tool)
        raise ValueError(f"Tool of type {tool_type.__name__} not found in the ToolSet.")


class ToolSet(BaseToolSet):
    """
    A collection of tools that can be used by a synchronous agent.
    """

    def validate_tool_type(self, tool: Tool) -> None:
        """
        Validate the type of the tool.

        :param Tool tool: The tool to validate.
        :raises ValueError: If the tool type is not supported by this tool set.
        """
        if isinstance(tool, AsyncFunctionTool):
            raise ValueError(
                "AsyncFunctionTool is not supported in ToolSet. "
                + "To use async functions, use AsyncToolSet and agents operations in azure.ai.projects.aio."
            )

    def execute_tool_calls(self, tool_calls: List[Any]) -> Any:
        """
        Execute a tool of the specified type with the provided tool calls.

        :param List[Any] tool_calls: A list of tool calls to execute.
        :return: The output of the tool operations.
        :rtype: Any
        """
        tool_outputs = []

        for tool_call in tool_calls:
            try:
                if tool_call.type == "function":
                    tool = self.get_tool(FunctionTool)
                    output = tool.execute(tool_call)
                    tool_output = {
                        "tool_call_id": tool_call.id,
                        "output": output,
                    }
                    tool_outputs.append(tool_output)
            except Exception as e:  # pylint: disable=broad-exception-caught
                logging.error("Failed to execute tool call %s: %s", tool_call, e)

        return tool_outputs


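# Illustrative sketch (not part of the SDK): assembling a ToolSet and reading back the merged
# definitions and resources. fetch_weather is the hypothetical function from the FunctionTool
# sketch above; the file ID is likewise hypothetical.
#
#     toolset = ToolSet()
#     toolset.add(FunctionTool(functions={fetch_weather}))
#     toolset.add(CodeInterpreterTool(file_ids=["assistant-file-123"]))
#     payload = toolset.get_definitions_and_resources()
#     # payload["tools"]          -> [FunctionToolDefinition, CodeInterpreterToolDefinition]
#     # payload["tool_resources"] -> ToolResources with code_interpreter.file_ids == ["assistant-file-123"]

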
class AsyncToolSet(BaseToolSet):
    """
    A collection of tools that can be used by an asynchronous agent.
    """

    def validate_tool_type(self, tool: Tool) -> None:
        """
        Validate the type of the tool.

        :param Tool tool: The tool to validate.
        :raises ValueError: If the tool type is not supported by this tool set.
        """
        if isinstance(tool, FunctionTool):
            raise ValueError(
                "FunctionTool is not supported in AsyncToolSet. "
                + "Please use AsyncFunctionTool instead and provide sync and/or async function(s)."
            )

    async def execute_tool_calls(self, tool_calls: List[Any]) -> Any:
        """
        Execute a tool of the specified type with the provided tool calls.

        :param List[Any] tool_calls: A list of tool calls to execute.
        :return: The output of the tool operations.
        :rtype: Any
        """
        tool_outputs = []

        for tool_call in tool_calls:
            try:
                if tool_call.type == "function":
                    tool = self.get_tool(AsyncFunctionTool)
                    output = await tool.execute(tool_call)
                    tool_output = {
                        "tool_call_id": tool_call.id,
                        "output": output,
                    }
                    tool_outputs.append(tool_output)
            except Exception as e:  # pylint: disable=broad-exception-caught
                logging.error("Failed to execute tool call %s: %s", tool_call, e)

        return tool_outputs


EventFunctionReturnT = TypeVar("EventFunctionReturnT")
T = TypeVar("T")
BaseAsyncAgentEventHandlerT = TypeVar("BaseAsyncAgentEventHandlerT", bound="BaseAsyncAgentEventHandler")
BaseAgentEventHandlerT = TypeVar("BaseAgentEventHandlerT", bound="BaseAgentEventHandler")


async def async_chain(*iterators: AsyncIterator[T]) -> AsyncIterator[T]:
    for iterator in iterators:
        async for item in iterator:
            yield item


class BaseAsyncAgentEventHandler(AsyncIterator[T]):

    def __init__(self) -> None:
        self.response_iterator: Optional[AsyncIterator[bytes]] = None
        self.submit_tool_outputs: Optional[Callable[[ThreadRun, "BaseAsyncAgentEventHandler[T]"], Awaitable[None]]] = (
            None
        )
        self.buffer: Optional[str] = None

    def initialize(
        self,
        response_iterator: AsyncIterator[bytes],
        submit_tool_outputs: Callable[[ThreadRun, "BaseAsyncAgentEventHandler[T]"], Awaitable[None]],
    ):
        self.response_iterator = (
            async_chain(self.response_iterator, response_iterator) if self.response_iterator else response_iterator
        )
        self.submit_tool_outputs = submit_tool_outputs

    async def __anext__(self) -> T:
        self.buffer = "" if self.buffer is None else self.buffer
        if self.response_iterator is None:
            raise ValueError("The response handler was not initialized.")

        if "\n\n" not in self.buffer:
            async for chunk in self.response_iterator:
                self.buffer += chunk.decode("utf-8")
                if "\n\n" in self.buffer:
                    break

        if self.buffer == "":
            raise StopAsyncIteration()

        event_str = ""
        if "\n\n" in self.buffer:
            event_end_index = self.buffer.index("\n\n")
            event_str = self.buffer[:event_end_index]
            self.buffer = self.buffer[event_end_index:].lstrip()
        else:
            event_str = self.buffer
            self.buffer = ""

        return await self._process_event(event_str)

    async def _process_event(self, event_data_str: str) -> T:
        raise NotImplementedError("This method needs to be implemented.")

    async def until_done(self) -> None:
        """
        Iterates through all events until the stream is marked as done.
        Calls the provided callback function with each event data.
        """
        try:
            async for _ in self:
                pass
        except StopAsyncIteration:
            pass


class BaseAgentEventHandler(Iterator[T]):

    def __init__(self) -> None:
        self.response_iterator: Optional[Iterator[bytes]] = None
        self.submit_tool_outputs: Optional[Callable[[ThreadRun, "BaseAgentEventHandler[T]"], None]] = None
        self.buffer: Optional[str] = None

    def initialize(
        self,
        response_iterator: Iterator[bytes],
        submit_tool_outputs: Callable[[ThreadRun, "BaseAgentEventHandler[T]"], None],
    ) -> None:
        self.response_iterator = (
            itertools.chain(self.response_iterator, response_iterator) if self.response_iterator else response_iterator
        )
        self.submit_tool_outputs = submit_tool_outputs

    def __next__(self) -> T:
        self.buffer = "" if self.buffer is None else self.buffer
        if self.response_iterator is None:
            raise ValueError("The response handler was not initialized.")

        if "\n\n" not in self.buffer:
            for chunk in self.response_iterator:
                self.buffer += chunk.decode("utf-8")
                if "\n\n" in self.buffer:
                    break

        if self.buffer == "":
            raise StopIteration()

        event_str = ""
        if "\n\n" in self.buffer:
            event_end_index = self.buffer.index("\n\n")
            event_str = self.buffer[:event_end_index]
            self.buffer = self.buffer[event_end_index:].lstrip()
        else:
            event_str = self.buffer
            self.buffer = ""

        return self._process_event(event_str)

    def _process_event(self, event_data_str: str) -> T:
        raise NotImplementedError("This method needs to be implemented.")

    def until_done(self) -> None:
        """
        Iterates through all events until the stream is marked as done.
        Calls the provided callback function with each event data.
        """
        try:
            for _ in self:
                pass
        except StopIteration:
            pass


class AsyncAgentEventHandler(BaseAsyncAgentEventHandler[Tuple[str, StreamEventData, Optional[EventFunctionReturnT]]]):

    async def _process_event(self, event_data_str: str) -> Tuple[str, StreamEventData, Optional[EventFunctionReturnT]]:
        event_type, event_data_obj = _parse_event(event_data_str)
        if (
            isinstance(event_data_obj, ThreadRun)
            and event_data_obj.status == "requires_action"
            and isinstance(event_data_obj.required_action, SubmitToolOutputsAction)
        ):
            await cast(Callable[[ThreadRun, "BaseAsyncAgentEventHandler"], Awaitable[None]], self.submit_tool_outputs)(
                event_data_obj, self
            )

        func_rt: Optional[EventFunctionReturnT] = None
        try:
            if isinstance(event_data_obj, MessageDeltaChunk):
                func_rt = await self.on_message_delta(event_data_obj)
            elif isinstance(event_data_obj, ThreadMessage):
                func_rt = await self.on_thread_message(event_data_obj)
            elif isinstance(event_data_obj, ThreadRun):
                func_rt = await self.on_thread_run(event_data_obj)
            elif isinstance(event_data_obj, RunStep):
                func_rt = await self.on_run_step(event_data_obj)
            elif isinstance(event_data_obj, RunStepDeltaChunk):
                func_rt = await self.on_run_step_delta(event_data_obj)
            elif event_type == AgentStreamEvent.ERROR:
                func_rt = await self.on_error(event_data_obj)
            elif event_type == AgentStreamEvent.DONE:
                func_rt = await self.on_done()
            else:
                func_rt = await self.on_unhandled_event(
                    event_type, event_data_obj
                )  # pylint: disable=assignment-from-none
        except Exception as e:  # pylint: disable=broad-exception-caught
            logging.error("Error in event handler for event '%s': %s", event_type, e)
        return event_type, event_data_obj, func_rt

    async def on_message_delta(
        self, delta: "MessageDeltaChunk"  # pylint: disable=unused-argument
    ) -> Optional[EventFunctionReturnT]:
        """Handle message delta events.

        :param MessageDeltaChunk delta: The message delta.
        :rtype: Optional[EventFunctionReturnT]
        """
        return None

    async def on_thread_message(
        self, message: "ThreadMessage"  # pylint: disable=unused-argument
    ) -> Optional[EventFunctionReturnT]:
        """Handle thread message events.

        :param ThreadMessage message: The thread message.
        :rtype: Optional[EventFunctionReturnT]
        """
        return None

    async def on_thread_run(
        self, run: "ThreadRun"  # pylint: disable=unused-argument
    ) -> Optional[EventFunctionReturnT]:
        """Handle thread run events.

        :param ThreadRun run: The thread run.
        :rtype: Optional[EventFunctionReturnT]
        """
        return None

    async def on_run_step(self, step: "RunStep") -> Optional[EventFunctionReturnT]:  # pylint: disable=unused-argument
        """Handle run step events.

        :param RunStep step: The run step.
        :rtype: Optional[EventFunctionReturnT]
        """
        return None

    async def on_run_step_delta(
        self, delta: "RunStepDeltaChunk"  # pylint: disable=unused-argument
    ) -> Optional[EventFunctionReturnT]:
        """Handle run step delta events.

        :param RunStepDeltaChunk delta: The run step delta.
        :rtype: Optional[EventFunctionReturnT]
        """
        return None

    async def on_error(self, data: str) -> Optional[EventFunctionReturnT]:  # pylint: disable=unused-argument
        """Handle error events.

        :param str data: The error event's data.
        :rtype: Optional[EventFunctionReturnT]
        """
        return None

    async def on_done(
        self,
    ) -> Optional[EventFunctionReturnT]:
        """Handle the completion of the stream.

        :rtype: Optional[EventFunctionReturnT]
        """
        return None

    async def on_unhandled_event(
        self, event_type: str, event_data: str  # pylint: disable=unused-argument
    ) -> Optional[EventFunctionReturnT]:
        """Handle any unhandled event types.

        :param str event_type: The event type.
        :param Any event_data: The event's data.
        :rtype: Optional[EventFunctionReturnT]
        """
        return None


class AgentEventHandler(BaseAgentEventHandler[Tuple[str, StreamEventData, Optional[EventFunctionReturnT]]]):

    def _process_event(self, event_data_str: str) -> Tuple[str, StreamEventData, Optional[EventFunctionReturnT]]:
        event_type, event_data_obj = _parse_event(event_data_str)
        if (
            isinstance(event_data_obj, ThreadRun)
            and event_data_obj.status == "requires_action"
            and isinstance(event_data_obj.required_action, SubmitToolOutputsAction)
        ):
            cast(Callable[[ThreadRun, "BaseAgentEventHandler"], None], self.submit_tool_outputs)(
                event_data_obj, self
            )

        func_rt: Optional[EventFunctionReturnT] = None
        try:
            if isinstance(event_data_obj, MessageDeltaChunk):
                func_rt = self.on_message_delta(event_data_obj)  # pylint: disable=assignment-from-none
            elif isinstance(event_data_obj, ThreadMessage):
                func_rt = self.on_thread_message(event_data_obj)  # pylint: disable=assignment-from-none
            elif isinstance(event_data_obj, ThreadRun):
                func_rt = self.on_thread_run(event_data_obj)  # pylint: disable=assignment-from-none
            elif isinstance(event_data_obj, RunStep):
                func_rt = self.on_run_step(event_data_obj)  # pylint: disable=assignment-from-none
            elif isinstance(event_data_obj, RunStepDeltaChunk):
                func_rt = self.on_run_step_delta(event_data_obj)  # pylint: disable=assignment-from-none
            elif event_type == AgentStreamEvent.ERROR:
                func_rt = self.on_error(event_data_obj)  # pylint: disable=assignment-from-none
            elif event_type == AgentStreamEvent.DONE:
                func_rt = self.on_done()  # pylint: disable=assignment-from-none
            else:
                func_rt = self.on_unhandled_event(event_type, event_data_obj)  # pylint: disable=assignment-from-none
        except Exception as e:  # pylint: disable=broad-exception-caught
            logging.error("Error in event handler for event '%s': %s", event_type, e)
        return event_type, event_data_obj, func_rt

    def on_message_delta(
        self, delta: "MessageDeltaChunk"  # pylint: disable=unused-argument
    ) -> Optional[EventFunctionReturnT]:
        """Handle message delta events.

        :param MessageDeltaChunk delta: The message delta.
        :rtype: Optional[EventFunctionReturnT]
        """
        return None

    def on_thread_message(
        self, message: "ThreadMessage"  # pylint: disable=unused-argument
    ) -> Optional[EventFunctionReturnT]:
        """Handle thread message events.

        :param ThreadMessage message: The thread message.
        :rtype: Optional[EventFunctionReturnT]
        """
        return None

    def on_thread_run(self, run: "ThreadRun") -> Optional[EventFunctionReturnT]:  # pylint: disable=unused-argument
        """Handle thread run events.

        :param ThreadRun run: The thread run.
        :rtype: Optional[EventFunctionReturnT]
        """
        return None

    def on_run_step(self, step: "RunStep") -> Optional[EventFunctionReturnT]:  # pylint: disable=unused-argument
        """Handle run step events.

        :param RunStep step: The run step.
        :rtype: Optional[EventFunctionReturnT]
        """
        return None

    def on_run_step_delta(
        self, delta: "RunStepDeltaChunk"  # pylint: disable=unused-argument
    ) -> Optional[EventFunctionReturnT]:
        """Handle run step delta events.

        :param RunStepDeltaChunk delta: The run step delta.
        :rtype: Optional[EventFunctionReturnT]
        """
        return None

    def on_error(self, data: str) -> Optional[EventFunctionReturnT]:  # pylint: disable=unused-argument
        """Handle error events.

        :param str data: The error event's data.
        :rtype: Optional[EventFunctionReturnT]
        """
        return None

    def on_done(
        self,
    ) -> Optional[EventFunctionReturnT]:
        """Handle the completion of the stream."""
        return None

    def on_unhandled_event(
        self, event_type: str, event_data: str  # pylint: disable=unused-argument
    ) -> Optional[EventFunctionReturnT]:
        """Handle any unhandled event types.

        :param str event_type: The event type.
        :param Any event_data: The event's data.
        """
        return None


class AsyncAgentRunStream(Generic[BaseAsyncAgentEventHandlerT]):
    def __init__(
        self,
        response_iterator: AsyncIterator[bytes],
        submit_tool_outputs: Callable[[ThreadRun, BaseAsyncAgentEventHandlerT], Awaitable[None]],
        event_handler: BaseAsyncAgentEventHandlerT,
    ):
        self.response_iterator = response_iterator
        self.event_handler = event_handler
        self.submit_tool_outputs = submit_tool_outputs
        self.event_handler.initialize(
            self.response_iterator,
            cast(Callable[[ThreadRun, BaseAsyncAgentEventHandler], Awaitable[None]], submit_tool_outputs),
        )

    async def __aenter__(self):
        return self.event_handler

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        close_method = getattr(self.response_iterator, "close", None)
        if callable(close_method):
            result = close_method()
            if asyncio.iscoroutine(result):
                await result


class AgentRunStream(Generic[BaseAgentEventHandlerT]):
    def __init__(
        self,
        response_iterator: Iterator[bytes],
        submit_tool_outputs: Callable[[ThreadRun, BaseAgentEventHandlerT], None],
        event_handler: BaseAgentEventHandlerT,
    ):
        self.response_iterator = response_iterator
        self.event_handler = event_handler
        self.submit_tool_outputs = submit_tool_outputs
        self.event_handler.initialize(
            self.response_iterator,
            cast(Callable[[ThreadRun, BaseAgentEventHandler], None], submit_tool_outputs),
        )

    def __enter__(self):
        return self.event_handler

    def __exit__(self, exc_type, exc_val, exc_tb):
        close_method = getattr(self.response_iterator, "close", None)
        if callable(close_method):
            close_method()


class OpenAIPageableListOfThreadMessage(OpenAIPageableListOfThreadMessageGenerated):

    @property
    def text_messages(self) -> List[MessageTextContent]:
        """Returns all text message contents in the messages.

        :rtype: List[MessageTextContent]
        """
        texts = [content for msg in self.data for content in msg.text_messages]
        return texts

    @property
    def image_contents(self) -> List[MessageImageFileContent]:
        """Returns all image file contents from image message contents in the messages.

        :rtype: List[MessageImageFileContent]
        """
        return [content for msg in self.data for content in msg.image_contents]

    @property
    def file_citation_annotations(self) -> List[MessageTextFileCitationAnnotation]:
        """Returns all file citation annotations from text message annotations in the messages.

        :rtype: List[MessageTextFileCitationAnnotation]
        """
        annotations = [annotation for msg in self.data for annotation in msg.file_citation_annotations]
        return annotations

    @property
    def file_path_annotations(self) -> List[MessageTextFilePathAnnotation]:
        """Returns all file path annotations from text message annotations in the messages.

        :rtype: List[MessageTextFilePathAnnotation]
        """
        annotations = [annotation for msg in self.data for annotation in msg.file_path_annotations]
        return annotations

    def get_last_message_by_role(self, role: MessageRole) -> Optional[ThreadMessage]:
        """Returns the last message from a sender in the specified role.

        :param role: The role of the sender.
        :type role: MessageRole
        :return: The last message from a sender in the specified role.
        :rtype: ~azure.ai.projects.models.ThreadMessage
        """
        for msg in self.data:
            if msg.role == role:
                return msg
        return None

    def get_last_text_message_by_role(self, role: MessageRole) -> Optional[MessageTextContent]:
        """Returns the last text message from a sender in the specified role.

        :param role: The role of the sender.
        :type role: MessageRole
        :return: The last text message from a sender in the specified role.
        :rtype: ~azure.ai.projects.models.MessageTextContent
        """
        for msg in self.data:
            if msg.role == role:
                for content in msg.content:
                    if isinstance(content, MessageTextContent):
                        return content
        return None


__all__: List[str] = [
    "AgentEventHandler",
    "AgentRunStream",
    "AsyncAgentRunStream",
    "AsyncFunctionTool",
    "AsyncToolSet",
    "AzureAISearchTool",
    "AzureFunctionTool",
    "BaseAsyncAgentEventHandler",
    "BaseAgentEventHandler",
    "CodeInterpreterTool",
    "ConnectionProperties",
    "AsyncAgentEventHandler",
    "OpenAIPageableListOfThreadMessage",
    "FileSearchTool",
    "FunctionTool",
    "OpenApiTool",
    "BingGroundingTool",
    "StreamEventData",
    "SharepointTool",
    "FabricTool",
    "SASTokenCredential",
    "Tool",
    "ToolSet",
    "BaseAsyncAgentEventHandlerT",
    "BaseAgentEventHandlerT",
    "ThreadMessage",
    "MessageTextFileCitationAnnotation",
    "MessageDeltaChunk",
    "MessageAttachment",
]  # Add all objects you want publicly available to users at this package level


def patch_sdk():
    """Do not remove from this file.

    `patch_sdk` is a last resort escape hatch that allows you to do customizations
    you can't accomplish using the techniques described in
    https://aka.ms/azsdk/python/dpcodegen/python/customize
    """