Source code for azure.ai.inference.models._patch

# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
"""Customize generated code here.

Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
"""
import base64
import json
import logging
import queue
import re
import sys

from typing import Mapping, Literal, Any, List, AsyncIterator, Iterator, Optional, Union, overload
from azure.core.rest import HttpResponse, AsyncHttpResponse
from ._enums import ChatRole
from .._model_base import rest_discriminator, rest_field
from ._models import ChatRequestMessage
from ._models import ImageUrl as ImageUrlGenerated
from ._models import ChatCompletions as ChatCompletionsGenerated
from ._models import EmbeddingsResult as EmbeddingsResultGenerated
from ._models import ImageEmbeddingInput as EmbeddingInputGenerated
from ._models import InputAudio as InputAudioGenerated
from .. import models as _models

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self

logger = logging.getLogger(__name__)


class UserMessage(ChatRequestMessage, discriminator="user"):
    """A request chat message representing user input to the assistant.

    :ivar role: The chat role associated with this message, which is always 'user' for user
     messages. Required. The role that provides input for chat completions.
    :vartype role: str or ~azure.ai.inference.models.USER
    :ivar content: The contents of the user message, with available input types varying by
     selected model. Required. Is either a str type or a [ContentItem] type.
    :vartype content: str or list[~azure.ai.inference.models.ContentItem]
    """

    role: Literal[ChatRole.USER] = rest_discriminator(name="role")  # type: ignore
    """The chat role associated with this message, which is always 'user' for user messages.
     Required. The role that provides input for chat completions."""
    content: Union["str", List["_models.ContentItem"]] = rest_field()
    """The contents of the user message, with available input types varying by selected model.
     Required. Is either a str type or a [ContentItem] type."""

    @overload
    def __init__(
        self,
        content: Union[str, List["_models.ContentItem"]],
    ) -> None: ...

    @overload
    def __init__(self, mapping: Mapping[str, Any]) -> None:
        """
        :param mapping: raw JSON to initialize the model.
        :type mapping: Mapping[str, Any]
        """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        if len(args) == 1 and isinstance(args[0], (List, str)):
            if kwargs.get("content") is not None:
                raise ValueError("content cannot be provided as positional and keyword arguments")
            kwargs["content"] = args[0]
            args = tuple()
        super().__init__(*args, role=ChatRole.USER, **kwargs)

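# Illustrative usage sketch (an editorial addition, not part of the shipped module):
# a UserMessage can be built from a plain string or from a list of content items.
# `client` is assumed to be an already-constructed azure.ai.inference.ChatCompletionsClient.
#
#     from azure.ai.inference.models import TextContentItem, UserMessage
#
#     response = client.complete(messages=[UserMessage("How many feet are in a mile?")])
#     # Equivalent keyword form, using a list of content items:
#     response = client.complete(
#         messages=[UserMessage(content=[TextContentItem(text="How many feet are in a mile?")])]
#     )
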
class SystemMessage(ChatRequestMessage, discriminator="system"):
    """A request chat message containing system instructions that influence how the model will
    generate a chat completions response.

    :ivar role: The chat role associated with this message, which is always 'system' for system
     messages. Required.
    :vartype role: str or ~azure.ai.inference.models.SYSTEM
    :ivar content: The contents of the system message. Required.
    :vartype content: str
    """

    role: Literal[ChatRole.SYSTEM] = rest_discriminator(name="role")  # type: ignore
    """The chat role associated with this message, which is always 'system' for system messages.
     Required."""
    content: str = rest_field()
    """The contents of the system message. Required."""

    @overload
    def __init__(
        self,
        content: str,
    ) -> None: ...

    @overload
    def __init__(self, mapping: Mapping[str, Any]) -> None:
        """
        :param mapping: raw JSON to initialize the model.
        :type mapping: Mapping[str, Any]
        """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        if len(args) == 1 and isinstance(args[0], str):
            if kwargs.get("content") is not None:
                raise ValueError("content cannot be provided as positional and keyword arguments")
            kwargs["content"] = args[0]
            args = tuple()
        super().__init__(*args, role=ChatRole.SYSTEM, **kwargs)

class DeveloperMessage(ChatRequestMessage, discriminator="developer"):
    """A request chat message containing developer instructions that influence how the model
    will generate a chat completions response. Some AI models support developer messages instead
    of system messages.

    :ivar role: The chat role associated with this message, which is always 'developer' for
     developer messages. Required.
    :vartype role: str or ~azure.ai.inference.models.DEVELOPER
    :ivar content: The contents of the developer message. Required.
    :vartype content: str
    """

    role: Literal[ChatRole.DEVELOPER] = rest_discriminator(name="role")  # type: ignore
    """The chat role associated with this message, which is always 'developer' for developer
     messages. Required."""
    content: str = rest_field()
    """The contents of the developer message. Required."""

    @overload
    def __init__(
        self,
        content: str,
    ) -> None: ...

    @overload
    def __init__(self, mapping: Mapping[str, Any]) -> None:
        """
        :param mapping: raw JSON to initialize the model.
        :type mapping: Mapping[str, Any]
        """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        if len(args) == 1 and isinstance(args[0], str):
            if kwargs.get("content") is not None:
                raise ValueError("content cannot be provided as positional and keyword arguments")
            kwargs["content"] = args[0]
            args = tuple()
        super().__init__(*args, role=ChatRole.DEVELOPER, **kwargs)

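# Illustrative sketch: SystemMessage and DeveloperMessage are built the same way from a
# single instruction string; which of the two a given model accepts varies, as noted in
# the DeveloperMessage docstring above.
#
#     messages = [
#         SystemMessage("You are a helpful assistant."),  # or DeveloperMessage(...) where supported
#         UserMessage("How many feet are in a mile?"),
#     ]
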
class AssistantMessage(ChatRequestMessage, discriminator="assistant"):
    """A request chat message representing response or action from the assistant.

    :ivar role: The chat role associated with this message, which is always 'assistant' for
     assistant messages. Required. The role that provides responses to system-instructed,
     user-prompted input.
    :vartype role: str or ~azure.ai.inference.models.ASSISTANT
    :ivar content: The content of the message.
    :vartype content: str
    :ivar tool_calls: The tool calls that must be resolved and have their outputs appended to
     subsequent input messages for the chat completions request to resolve as configured.
    :vartype tool_calls: list[~azure.ai.inference.models.ChatCompletionsToolCall]
    """

    role: Literal[ChatRole.ASSISTANT] = rest_discriminator(name="role")  # type: ignore
    """The chat role associated with this message, which is always 'assistant' for assistant
     messages. Required. The role that provides responses to system-instructed, user-prompted
     input."""
    content: Optional[str] = rest_field()
    """The content of the message."""
    tool_calls: Optional[List["_models.ChatCompletionsToolCall"]] = rest_field()
    """The tool calls that must be resolved and have their outputs appended to subsequent input
     messages for the chat completions request to resolve as configured."""

    @overload
    def __init__(
        self,
        content: Optional[str] = None,
        *,
        tool_calls: Optional[List["_models.ChatCompletionsToolCall"]] = None,
    ) -> None: ...

    @overload
    def __init__(self, mapping: Mapping[str, Any]) -> None:
        """
        :param mapping: raw JSON to initialize the model.
        :type mapping: Mapping[str, Any]
        """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        if len(args) == 1 and isinstance(args[0], str):
            if kwargs.get("content") is not None:
                raise ValueError("content cannot be provided as positional and keyword arguments")
            kwargs["content"] = args[0]
            args = tuple()
        super().__init__(*args, role=ChatRole.ASSISTANT, **kwargs)

class ToolMessage(ChatRequestMessage, discriminator="tool"):
    """A request chat message representing requested output from a configured tool.

    :ivar role: The chat role associated with this message, which is always 'tool' for tool
     messages. Required. The role that represents extension tool activity within a chat
     completions operation.
    :vartype role: str or ~azure.ai.inference.models.TOOL
    :ivar content: The content of the message.
    :vartype content: str
    :ivar tool_call_id: The ID of the tool call resolved by the provided content. Required.
    :vartype tool_call_id: str
    """

    role: Literal[ChatRole.TOOL] = rest_discriminator(name="role")  # type: ignore
    """The chat role associated with this message, which is always 'tool' for tool messages.
     Required. The role that represents extension tool activity within a chat completions
     operation."""
    content: Optional[str] = rest_field()
    """The content of the message."""
    tool_call_id: str = rest_field()
    """The ID of the tool call resolved by the provided content. Required."""

    @overload
    def __init__(
        self,
        content: Optional[str] = None,
        *,
        tool_call_id: str,
    ) -> None: ...

    @overload
    def __init__(self, mapping: Mapping[str, Any]) -> None:
        """
        :param mapping: raw JSON to initialize the model.
        :type mapping: Mapping[str, Any]
        """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        if len(args) == 1 and isinstance(args[0], str):
            if kwargs.get("content") is not None:
                raise ValueError("content cannot be provided as positional and keyword arguments")
            kwargs["content"] = args[0]
            args = tuple()
        super().__init__(*args, role=ChatRole.TOOL, **kwargs)

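# Illustrative sketch of the tool-calling round trip these two classes support: the
# assistant's tool calls are echoed back in an AssistantMessage, and each call's output is
# returned in a ToolMessage keyed by tool_call_id. `client`, `messages`, and the initial
# `response` are assumed to come from an earlier ChatCompletionsClient.complete() call made
# with a (hypothetical) `tool_definitions` list.
#
#     tool_calls = response.choices[0].message.tool_calls
#     messages.append(AssistantMessage(tool_calls=tool_calls))
#     messages.append(ToolMessage("62 degrees F", tool_call_id=tool_calls[0].id))
#     response = client.complete(messages=messages, tools=tool_definitions)
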
class ChatCompletions(ChatCompletionsGenerated):
    """Representation of the response data from a chat completions request.

    Completions support a wide variety of tasks and generate text that continues from or
    "completes" provided prompt data.

    :ivar id: A unique identifier associated with this chat completions response. Required.
    :vartype id: str
    :ivar created: The first timestamp associated with generation activity for this completions
     response, represented as seconds since the beginning of the Unix epoch of 00:00 on 1 Jan
     1970. Required.
    :vartype created: ~datetime.datetime
    :ivar model: The model used for the chat completion. Required.
    :vartype model: str
    :ivar usage: Usage information for tokens processed and generated as part of this completions
     operation. Required.
    :vartype usage: ~azure.ai.inference.models.CompletionsUsage
    :ivar choices: The collection of completions choices associated with this completions
     response. Generally, ``n`` choices are generated per provided prompt with a default value
     of 1. Token limits and other settings may limit the number of choices generated. Required.
    :vartype choices: list[~azure.ai.inference.models.ChatChoice]
    """

    def __str__(self) -> str:
        # pylint: disable=client-method-name-no-double-underscore
        return json.dumps(self.as_dict(), indent=2)

class EmbeddingsResult(EmbeddingsResultGenerated):
    """Representation of the response data from an embeddings request.

    Embeddings measure the relatedness of text strings and are commonly used for search,
    clustering, recommendations, and other similar scenarios.

    :ivar data: Embedding values for the prompts submitted in the request. Required.
    :vartype data: list[~azure.ai.inference.models.EmbeddingItem]
    :ivar usage: Usage counts for tokens input using the embeddings API. Required.
    :vartype usage: ~azure.ai.inference.models.EmbeddingsUsage
    :ivar model: The model ID used to generate this result. Required.
    :vartype model: str
    """

    def __str__(self) -> str:
        # pylint: disable=client-method-name-no-double-underscore
        return json.dumps(self.as_dict(), indent=2)

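# Illustrative sketch: the overridden __str__ methods above make both response types print
# as indented JSON, so a bare `print(response)` shows the full payload. `client` and
# `messages` are assumed as in the earlier sketches.
#
#     response = client.complete(messages=messages)
#     print(response)  # pretty-printed JSON via ChatCompletions.__str__
#     print(response.choices[0].message.content)  # or access individual fields
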
class ImageUrl(ImageUrlGenerated):

    @classmethod
    def load(
        cls, *, image_file: str, image_format: str, detail: Optional[Union[str, "_models.ImageDetailLevel"]] = None
    ) -> Self:
        """
        Create an ImageUrl object from a local image file. The method reads the image
        file and encodes it as a base64 string, which together with the image format
        is then used to format the JSON `url` value passed in the request payload.

        :keyword image_file: The name of the local image file to load. Required.
        :paramtype image_file: str
        :keyword image_format: The MIME type format of the image. For example: "jpeg", "png". Required.
        :paramtype image_format: str
        :keyword detail: The evaluation quality setting to use, which controls relative prioritization
         of speed, token consumption, and accuracy. Known values are: "auto", "low", and "high".
        :paramtype detail: str or ~azure.ai.inference.models.ImageDetailLevel
        :return: An ImageUrl object with the image data encoded as a base64 string.
        :rtype: ~azure.ai.inference.models.ImageUrl
        :raises FileNotFoundError: when the image file could not be opened.
        """
        with open(image_file, "rb") as f:
            image_data = base64.b64encode(f.read()).decode("utf-8")
        url = f"data:image/{image_format};base64,{image_data}"
        return cls(url=url, detail=detail)

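# Illustrative sketch: ImageUrl.load turns a local image file into a data URL that can be
# wrapped in an ImageContentItem; "sample.jpg" is a placeholder file name.
#
#     from azure.ai.inference.models import ImageContentItem, ImageDetailLevel, TextContentItem
#
#     image_url = ImageUrl.load(image_file="sample.jpg", image_format="jpeg", detail=ImageDetailLevel.HIGH)
#     message = UserMessage(content=[
#         TextContentItem(text="What is shown in this image?"),
#         ImageContentItem(image_url=image_url),
#     ])
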
class ImageEmbeddingInput(EmbeddingInputGenerated):

    @classmethod
    def load(cls, *, image_file: str, image_format: str, text: Optional[str] = None) -> Self:
        """
        Create an ImageEmbeddingInput object from a local image file. The method reads the image
        file and encodes it as a base64 string, which together with the image format
        is then used to format the JSON `url` value passed in the request payload.

        :keyword image_file: The name of the local image file to load. Required.
        :paramtype image_file: str
        :keyword image_format: The MIME type format of the image. For example: "jpeg", "png". Required.
        :paramtype image_format: str
        :keyword text: Optional. The text input to feed into the model (like DINO, CLIP).
         Returns a 422 error if the model doesn't support the value or parameter.
        :paramtype text: str
        :return: An ImageEmbeddingInput object with the image data encoded as a base64 string.
        :rtype: ~azure.ai.inference.models.ImageEmbeddingInput
        :raises FileNotFoundError: when the image file could not be opened.
        """
        with open(image_file, "rb") as f:
            image_data = base64.b64encode(f.read()).decode("utf-8")
        image_uri = f"data:image/{image_format};base64,{image_data}"
        return cls(image=image_uri, text=text)

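# Illustrative sketch, assuming an already-constructed azure.ai.inference.ImageEmbeddingsClient
# named `client`; "sample.png" is a placeholder file name.
#
#     item = ImageEmbeddingInput.load(image_file="sample.png", image_format="png")
#     result = client.embed(input=[item])
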
class BaseStreamingChatCompletions:
    """A base class for the sync and async streaming chat completions responses, holding
    any common code to deserialize the Server Sent Events (SSE) response stream into chat
    completions updates, each one represented by a StreamingChatCompletionsUpdate object.
    """

    # Enable detailed logs of SSE parsing. For development only, should be `False` by default.
    _ENABLE_CLASS_LOGS = False

    # The prefix of each line in the SSE stream that contains a JSON string
    # to deserialize into a StreamingChatCompletionsUpdate object
    _SSE_DATA_EVENT_PREFIX = b"data: "

    # The line indicating the end of the SSE stream
    _SSE_DATA_EVENT_DONE = b"data: [DONE]"

    def __init__(self):
        self._queue: "queue.Queue[_models.StreamingChatCompletionsUpdate]" = queue.Queue()
        self._incomplete_line = b""
        self._done = False  # Will be set to True when reading 'data: [DONE]' line

    # See https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream
    def _deserialize_and_add_to_queue(self, element: bytes) -> bool:

        if self._ENABLE_CLASS_LOGS:
            logger.debug("[Original element] %s", repr(element))

        # Clear the queue of StreamingChatCompletionsUpdate before processing the next block
        self._queue.queue.clear()

        # Split the single input bytes object at new line characters, and get a list of bytes
        # objects, each representing a single "line". The bytes object at the end of the list
        # may be a partial "line" that does not contain a new line character at the end.
        # Note 1: DO NOT try to use something like this here:
        #     line_list: List[str] = re.split(r"(?<=\n)", element.decode("utf-8"))
        # to do full UTF8 decoding of the whole input bytes object, as the last line in the list
        # may be partial, and as such may contain a partial UTF8 Chinese character (for example).
        # `decode("utf-8")` will raise an exception for such a case. See GitHub issue
        # https://github.com/Azure/azure-sdk-for-python/issues/39565
        # Note 2: Consider future re-write and simplifications of this code by using
        #     `codecs.getincrementaldecoder("utf-8")`
        line_list: List[bytes] = re.split(re.compile(b"(?<=\n)"), element)
        for index, line in enumerate(line_list):

            if self._ENABLE_CLASS_LOGS:
                logger.debug("[Original line] %s", repr(line))

            if index == 0:
                line = self._incomplete_line + line
                self._incomplete_line = b""

            if index == len(line_list) - 1 and not line.endswith(b"\n"):
                self._incomplete_line = line
                return False

            if self._ENABLE_CLASS_LOGS:
                logger.debug("[Modified line] %s", repr(line))

            if line == b"\n":  # Empty line, indicating flush output to client
                continue

            if not line.startswith(self._SSE_DATA_EVENT_PREFIX):
                raise ValueError(f"SSE event not supported (line `{repr(line)}`)")

            if line.startswith(self._SSE_DATA_EVENT_DONE):
                if self._ENABLE_CLASS_LOGS:
                    logger.debug("[Done]")
                return True

            # If you reached here, the line should contain `data: {...}\n`
            # where the curly braces contain a valid JSON object.
            # It is now safe to do UTF8 decoding of the line.
            line_str = line.decode("utf-8")

            # Deserialize it into a StreamingChatCompletionsUpdate object and add it to the queue.
            # pylint: disable=W0212 # Access to a protected member _deserialize of a client class
            update = _models.StreamingChatCompletionsUpdate._deserialize(
                json.loads(line_str[len(self._SSE_DATA_EVENT_PREFIX) : -1]), []
            )

            # We skip any update that has a None or empty choices list, and does not have
            # token usage info (this is what the OpenAI Python SDK does).
            if update.choices or update.usage:
                self._queue.put(update)
                if self._ENABLE_CLASS_LOGS:
                    logger.debug("[Added to queue]")

        return False

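# Illustrative standalone sketch of the partial-line buffering implemented above: a
# multi-byte UTF-8 character split across two reads decodes cleanly, because only
# complete lines (ending in b"\n") are ever decoded.
#
#     chunks = [b'data: {"text": "\xe4\xb8', b'\xad"}\n']  # '中' split across two reads
#     incomplete = b""
#     for chunk in chunks:
#         lines = re.split(b"(?<=\n)", incomplete + chunk)
#         incomplete = b"" if lines[-1].endswith(b"\n") else lines.pop()
#         for line in lines:
#             print(line.decode("utf-8"))  # safe: the line is complete
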
class StreamingChatCompletions(BaseStreamingChatCompletions):
    """Represents an iterator over StreamingChatCompletionsUpdate objects, for use in
    synchronous code. The class deserializes the Server Sent Events (SSE) response stream
    into chat completions updates, each one represented by a StreamingChatCompletionsUpdate
    object.
    """

    def __init__(self, response: HttpResponse):
        super().__init__()
        self._response = response
        self._bytes_iterator: Iterator[bytes] = response.iter_bytes()

    def __iter__(self) -> Any:
        return self

    def __next__(self) -> "_models.StreamingChatCompletionsUpdate":
        while self._queue.empty() and not self._done:
            self._done = self._read_next_block()
        if self._queue.empty():
            raise StopIteration
        return self._queue.get()

    def _read_next_block(self) -> bool:
        if self._ENABLE_CLASS_LOGS:
            logger.debug("[Reading next block]")
        try:
            element = self._bytes_iterator.__next__()
        except StopIteration:
            self.close()
            return True
        return self._deserialize_and_add_to_queue(element)

    def __enter__(self):
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:  # type: ignore
        self.close()

    def close(self) -> None:
        self._response.close()

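# Illustrative sketch, assuming `client` is an azure.ai.inference.ChatCompletionsClient:
# complete(stream=True, ...) returns a StreamingChatCompletions, which the `with` block
# closes when iteration ends.
#
#     response = client.complete(stream=True, messages=[UserMessage("Tell me a story")])
#     with response:
#         for update in response:
#             if update.choices and update.choices[0].delta.content:
#                 print(update.choices[0].delta.content, end="")
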
class AsyncStreamingChatCompletions(BaseStreamingChatCompletions):
    """Represents an async iterator over StreamingChatCompletionsUpdate objects, for use in
    asynchronous code. The class deserializes the Server Sent Events (SSE) response stream
    into chat completions updates, each one represented by a StreamingChatCompletionsUpdate
    object.
    """

    def __init__(self, response: AsyncHttpResponse):
        super().__init__()
        self._response = response
        self._bytes_iterator: AsyncIterator[bytes] = response.iter_bytes()

    def __aiter__(self) -> Any:
        return self

    async def __anext__(self) -> "_models.StreamingChatCompletionsUpdate":
        while self._queue.empty() and not self._done:
            self._done = await self._read_next_block_async()
        if self._queue.empty():
            raise StopAsyncIteration
        return self._queue.get()

    async def _read_next_block_async(self) -> bool:
        if self._ENABLE_CLASS_LOGS:
            logger.debug("[Reading next block]")
        try:
            element = await self._bytes_iterator.__anext__()
        except StopAsyncIteration:
            await self.aclose()
            return True
        return self._deserialize_and_add_to_queue(element)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:  # type: ignore
        await self.aclose()

    async def aclose(self) -> None:
        await self._response.close()

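# Illustrative sketch of the asynchronous counterpart, assuming `client` is an
# azure.ai.inference.aio.ChatCompletionsClient and the code runs inside an async function.
#
#     response = await client.complete(stream=True, messages=[UserMessage("Tell me a story")])
#     async with response:
#         async for update in response:
#             if update.choices and update.choices[0].delta.content:
#                 print(update.choices[0].delta.content, end="")
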
class InputAudio(InputAudioGenerated):

    @classmethod
    def load(
        cls,
        *,
        audio_file: str,
        audio_format: str,
    ) -> Self:
        """
        Create an InputAudio object from a local audio file. The method reads the audio
        file and encodes it as a base64 string, which together with the audio format
        is then used to create the InputAudio object passed to the request payload.

        :keyword audio_file: The name of the local audio file to load. Required.
        :paramtype audio_file: str
        :keyword audio_format: The MIME type format of the audio. For example: "wav", "mp3". Required.
        :paramtype audio_format: str
        :return: An InputAudio object with the audio data encoded as a base64 string.
        :rtype: ~azure.ai.inference.models.InputAudio
        :raises FileNotFoundError: when the audio file could not be opened.
        """
        with open(audio_file, "rb") as f:
            audio_data = base64.b64encode(f.read()).decode("utf-8")
        return cls(data=audio_data, format=audio_format)

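# Illustrative sketch: an InputAudio loaded from a local file is wrapped in an
# AudioContentItem inside a UserMessage; "sample.mp3" is a placeholder file name.
#
#     from azure.ai.inference.models import AudioContentItem, TextContentItem
#
#     audio = InputAudio.load(audio_file="sample.mp3", audio_format="mp3")
#     message = UserMessage(content=[
#         TextContentItem(text="Transcribe this audio clip"),
#         AudioContentItem(input_audio=audio),
#     ])
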
__all__: List[str] = [
    "AssistantMessage",
    "AsyncStreamingChatCompletions",
    "ChatCompletions",
    "ChatRequestMessage",
    "DeveloperMessage",
    "EmbeddingsResult",
    "ImageEmbeddingInput",
    "ImageUrl",
    "InputAudio",
    "StreamingChatCompletions",
    "SystemMessage",
    "ToolMessage",
    "UserMessage",
]  # Add all objects you want publicly available to users at this package level


def patch_sdk():
    """Do not remove from this file.

    `patch_sdk` is a last resort escape hatch that allows you to do customizations
    you can't accomplish using the techniques described in
    https://aka.ms/azsdk/python/dpcodegen/python/customize
    """