# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
"""Customize generated code here.
Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
"""
import base64
import json
import logging
import queue
import re
import sys
from typing import Mapping, Literal, Any, List, AsyncIterator, Iterator, Optional, Union, overload
from azure.core.rest import HttpResponse, AsyncHttpResponse
from ._enums import ChatRole
from .._model_base import rest_discriminator, rest_field
from ._models import ChatRequestMessage
from ._models import ImageUrl as ImageUrlGenerated
from ._models import ChatCompletions as ChatCompletionsGenerated
from ._models import EmbeddingsResult as EmbeddingsResultGenerated
from ._models import ImageEmbeddingInput as EmbeddingInputGenerated
from ._models import InputAudio as InputAudioGenerated
from .. import models as _models
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
logger = logging.getLogger(__name__)
class UserMessage(ChatRequestMessage, discriminator="user"):
    """A request chat message representing user input to the assistant.

    :ivar role: The chat role associated with this message, which is always 'user' for user
     messages. Required. The role that provides input for chat completions.
    :vartype role: str or ~azure.ai.inference.models.USER
    :ivar content: The contents of the user message, with available input types varying by selected
     model. Required. Is either a str type or a [ContentItem] type.
    :vartype content: str or list[~azure.ai.inference.models.ContentItem]
    """

    role: Literal[ChatRole.USER] = rest_discriminator(name="role")  # type: ignore
    """The chat role associated with this message, which is always 'user' for user messages. Required.
    The role that provides input for chat completions."""
    content: Union[str, List["_models.ContentItem"]] = rest_field()
    """The contents of the user message, with available input types varying by selected model.
    Required. Is either a str type or a [ContentItem] type."""

    @overload
    def __init__(
        self,
        content: Union[str, List["_models.ContentItem"]],
    ) -> None: ...

    @overload
    def __init__(self, mapping: Mapping[str, Any]) -> None:
        """
        :param mapping: raw JSON to initialize the model.
        :type mapping: Mapping[str, Any]
        """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Convenience: allow `content` to be given as a single positional argument
        # (a plain string, or a list of ContentItem), in addition to the keyword form.
        # Note: use the runtime builtin `list` here, not `typing.List` —
        # passing typing aliases to isinstance() is deprecated.
        if len(args) == 1 and isinstance(args[0], (list, str)):
            if kwargs.get("content") is not None:
                raise ValueError("content cannot be provided as positional and keyword arguments")
            kwargs["content"] = args[0]
            args = tuple()
        super().__init__(*args, role=ChatRole.USER, **kwargs)
class SystemMessage(ChatRequestMessage, discriminator="system"):
    """A request chat message holding system instructions that steer how the model
    generates a chat completions response.

    :ivar role: The chat role associated with this message, which is always 'system' for system
     messages. Required.
    :vartype role: str or ~azure.ai.inference.models.SYSTEM
    :ivar content: The contents of the system message. Required.
    :vartype content: str
    """

    role: Literal[ChatRole.SYSTEM] = rest_discriminator(name="role")  # type: ignore
    """The chat role associated with this message, which is always 'system' for system messages.
    Required."""
    content: str = rest_field()
    """The contents of the system message. Required."""

    @overload
    def __init__(self, content: str) -> None: ...

    @overload
    def __init__(self, mapping: Mapping[str, Any]) -> None:
        """
        :param mapping: raw JSON to initialize the model.
        :type mapping: Mapping[str, Any]
        """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Convenience: a single positional string is treated as the `content` keyword.
        given_positionally = len(args) == 1 and isinstance(args[0], str)
        if given_positionally:
            if kwargs.get("content") is not None:
                raise ValueError("content cannot be provided as positional and keyword arguments")
            kwargs["content"] = args[0]
            args = ()
        super().__init__(*args, role=ChatRole.SYSTEM, **kwargs)
class DeveloperMessage(ChatRequestMessage, discriminator="developer"):
    """A request chat message holding developer instructions that steer how the model
    generates a chat completions response. Some AI models accept developer messages
    in place of system messages.

    :ivar role: The chat role associated with this message, which is always 'developer' for
     developer messages. Required.
    :vartype role: str or ~azure.ai.inference.models.DEVELOPER
    :ivar content: The contents of the developer message. Required.
    :vartype content: str
    """

    role: Literal[ChatRole.DEVELOPER] = rest_discriminator(name="role")  # type: ignore
    """The chat role associated with this message, which is always 'developer' for developer
    messages. Required."""
    content: str = rest_field()
    """The contents of the developer message. Required."""

    @overload
    def __init__(self, content: str) -> None: ...

    @overload
    def __init__(self, mapping: Mapping[str, Any]) -> None:
        """
        :param mapping: raw JSON to initialize the model.
        :type mapping: Mapping[str, Any]
        """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Convenience: a single positional string is treated as the `content` keyword.
        if len(args) == 1 and isinstance(args[0], str):
            if kwargs.get("content") is not None:
                raise ValueError("content cannot be provided as positional and keyword arguments")
            kwargs = {**kwargs, "content": args[0]}
            args = ()
        super().__init__(*args, role=ChatRole.DEVELOPER, **kwargs)
class AssistantMessage(ChatRequestMessage, discriminator="assistant"):
    """A request chat message representing a response or action from the assistant.

    :ivar role: The chat role associated with this message, which is always 'assistant' for
     assistant messages. Required. The role that provides responses to system-instructed,
     user-prompted input.
    :vartype role: str or ~azure.ai.inference.models.ASSISTANT
    :ivar content: The content of the message.
    :vartype content: str
    :ivar tool_calls: The tool calls that must be resolved and have their outputs appended to
     subsequent input messages for the chat completions request to resolve as configured.
    :vartype tool_calls: list[~azure.ai.inference.models.ChatCompletionsToolCall]
    """

    role: Literal[ChatRole.ASSISTANT] = rest_discriminator(name="role")  # type: ignore
    """The chat role associated with this message, which is always 'assistant' for assistant
    messages. Required. The role that provides responses to system-instructed, user-prompted
    input."""
    content: Optional[str] = rest_field()
    """The content of the message."""
    tool_calls: Optional[List["_models.ChatCompletionsToolCall"]] = rest_field()
    """The tool calls that must be resolved and have their outputs appended to subsequent input
    messages for the chat completions request to resolve as configured."""

    @overload
    def __init__(
        self,
        content: Optional[str] = None,
        *,
        tool_calls: Optional[List["_models.ChatCompletionsToolCall"]] = None,
    ) -> None: ...

    @overload
    def __init__(self, mapping: Mapping[str, Any]) -> None:
        """
        :param mapping: raw JSON to initialize the model.
        :type mapping: Mapping[str, Any]
        """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Convenience: a single positional string is treated as the `content` keyword.
        positional_content = len(args) == 1 and isinstance(args[0], str)
        if positional_content:
            if kwargs.get("content") is not None:
                raise ValueError("content cannot be provided as positional and keyword arguments")
            kwargs["content"] = args[0]
            args = ()
        super().__init__(*args, role=ChatRole.ASSISTANT, **kwargs)
class ChatCompletions(ChatCompletionsGenerated):
    """Representation of the response data from a chat completions request.
    Completions support a wide variety of tasks and generate text that continues from or
    "completes" provided prompt data.

    :ivar id: A unique identifier associated with this chat completions response. Required.
    :vartype id: str
    :ivar created: The first timestamp associated with generation activity for this completions
     response, represented as seconds since the beginning of the Unix epoch of 00:00 on
     1 Jan 1970. Required.
    :vartype created: ~datetime.datetime
    :ivar model: The model used for the chat completion. Required.
    :vartype model: str
    :ivar usage: Usage information for tokens processed and generated as part of this completions
     operation. Required.
    :vartype usage: ~azure.ai.inference.models.CompletionsUsage
    :ivar choices: The collection of completions choices associated with this completions response.
     Generally, ``n`` choices are generated per provided prompt with a default value of 1.
     Token limits and other settings may limit the number of choices generated. Required.
    :vartype choices: list[~azure.ai.inference.models.ChatChoice]
    """

    def __str__(self) -> str:
        # pylint: disable=client-method-name-no-double-underscore
        # Render the whole response as human-readable, indented JSON.
        as_json = self.as_dict()
        return json.dumps(as_json, indent=2)
class EmbeddingsResult(EmbeddingsResultGenerated):
    """Representation of the response data from an embeddings request.
    Embeddings measure the relatedness of text strings and are commonly used for search,
    clustering, recommendations, and other similar scenarios.

    :ivar data: Embedding values for the prompts submitted in the request. Required.
    :vartype data: list[~azure.ai.inference.models.EmbeddingItem]
    :ivar usage: Usage counts for tokens input using the embeddings API. Required.
    :vartype usage: ~azure.ai.inference.models.EmbeddingsUsage
    :ivar model: The model ID used to generate this result. Required.
    :vartype model: str
    """

    def __str__(self) -> str:
        # pylint: disable=client-method-name-no-double-underscore
        # Render the whole result as human-readable, indented JSON.
        as_json = self.as_dict()
        return json.dumps(as_json, indent=2)
class ImageUrl(ImageUrlGenerated):

    @classmethod
    def load(
        cls, *, image_file: str, image_format: str, detail: Optional[Union[str, "_models.ImageDetailLevel"]] = None
    ) -> Self:
        """
        Create an ImageUrl object from a local image file. The method reads the image
        file and encodes it as a base64 string, which together with the image format
        is then used to format the JSON `url` value passed in the request payload.

        :keyword image_file: The name of the local image file to load. Required.
        :paramtype image_file: str
        :keyword image_format: The MIME type format of the image. For example: "jpeg", "png". Required.
        :paramtype image_format: str
        :keyword detail: The evaluation quality setting to use, which controls relative prioritization of
         speed, token consumption, and accuracy. Known values are: "auto", "low", and "high".
        :paramtype detail: str or ~azure.ai.inference.models.ImageDetailLevel
        :return: An ImageUrl object with the image data encoded as a base64 string.
        :rtype: ~azure.ai.inference.models.ImageUrl
        :raises FileNotFoundError: when the image file could not be opened.
        """
        # Read the raw bytes and embed them in a standard data URL.
        with open(image_file, "rb") as image:
            encoded = base64.b64encode(image.read()).decode("utf-8")
        data_url = f"data:image/{image_format};base64,{encoded}"
        return cls(url=data_url, detail=detail)
class BaseStreamingChatCompletions:
    """A base class for the sync and async streaming chat completions responses, holding any common code
    to deserialize the Server Sent Events (SSE) response stream into chat completions updates, each one
    represented by a StreamingChatCompletionsUpdate object.
    """

    # Enable detailed logs of SSE parsing. For development only, should be `False` by default.
    _ENABLE_CLASS_LOGS = False

    # The prefix of each line in the SSE stream that contains a JSON string
    # to deserialize into a StreamingChatCompletionsUpdate object
    _SSE_DATA_EVENT_PREFIX = b"data: "

    # The line indicating the end of the SSE stream
    _SSE_DATA_EVENT_DONE = b"data: [DONE]"

    def __init__(self):
        # Parsed updates waiting to be handed out by the subclass's __next__/__anext__.
        self._queue: "queue.Queue[_models.StreamingChatCompletionsUpdate]" = queue.Queue()
        # Trailing partial SSE line (no terminating newline) carried over from the
        # previous network block, to be prepended to the first line of the next block.
        self._incomplete_line = b""
        self._done = False  # Will be set to True when reading 'data: [DONE]' line

    # See https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream
    def _deserialize_and_add_to_queue(self, element: bytes) -> bool:
        """Parse one block of SSE response bytes and queue any complete updates found.

        :param element: One raw block of bytes read from the response stream. May begin
         and/or end mid-line; partial trailing lines are buffered for the next call.
        :return: True if the 'data: [DONE]' end-of-stream event was seen, False otherwise.
        :raises ValueError: if a complete line does not start with the supported 'data: ' prefix.
        """
        if self._ENABLE_CLASS_LOGS:
            logger.debug("[Original element] %s", repr(element))
        # Clear the queue of StreamingChatCompletionsUpdate before processing the next block
        self._queue.queue.clear()
        # Split the single input bytes object at new line characters, and get a list of bytes objects, each
        # representing a single "line". The bytes object at the end of the list may be a partial "line" that
        # does not contain a new line character at the end.
        # Note 1: DO NOT try to use something like this here:
        # line_list: List[str] = re.split(r"(?<=\n)", element.decode("utf-8"))
        # to do full UTF8 decoding of the whole input bytes object, as the last line in the list may be partial, and
        # as such may contain a partial UTF8 Chinese character (for example). `decode("utf-8")` will raise an
        # exception for such a case. See GitHub issue https://github.com/Azure/azure-sdk-for-python/issues/39565
        # Note 2: Consider future re-write and simplifications of this code by using:
        # `codecs.getincrementaldecoder("utf-8")`
        line_list: List[bytes] = re.split(re.compile(b"(?<=\n)"), element)
        for index, line in enumerate(line_list):
            if self._ENABLE_CLASS_LOGS:
                logger.debug("[Original line] %s", repr(line))
            # The first line of this block may be the continuation of the partial
            # line left over from the previous block.
            if index == 0:
                line = self._incomplete_line + line
                self._incomplete_line = b""
            # A partial last line is saved for the next call; stop processing this block.
            if index == len(line_list) - 1 and not line.endswith(b"\n"):
                self._incomplete_line = line
                return False
            if self._ENABLE_CLASS_LOGS:
                logger.debug("[Modified line] %s", repr(line))
            if line == b"\n":  # Empty line, indicating flush output to client
                continue
            if not line.startswith(self._SSE_DATA_EVENT_PREFIX):
                raise ValueError(f"SSE event not supported (line `{repr(line)}`)")
            if line.startswith(self._SSE_DATA_EVENT_DONE):
                if self._ENABLE_CLASS_LOGS:
                    logger.debug("[Done]")
                return True
            # If you reached here, the line should contain `data: {...}\n`
            # where the curly braces contain a valid JSON object.
            # It is now safe to do UTF8 decoding of the line.
            line_str = line.decode("utf-8")
            # Deserialize it into a StreamingChatCompletionsUpdate object
            # and add it to the queue.
            # pylint: disable=W0212 # Access to a protected member _deserialize of a client class
            update = _models.StreamingChatCompletionsUpdate._deserialize(
                json.loads(line_str[len(self._SSE_DATA_EVENT_PREFIX) : -1]), []
            )
            # We skip any update that has a None or empty choices list, and does not have token usage info.
            # (this is what OpenAI Python SDK does)
            if update.choices or update.usage:
                self._queue.put(update)
                if self._ENABLE_CLASS_LOGS:
                    logger.debug("[Added to queue]")
        return False
class StreamingChatCompletions(BaseStreamingChatCompletions):
    """Represents a synchronous iterator over StreamingChatCompletionsUpdate objects.
    The class deserializes the Server Sent Events (SSE) response stream into chat
    completions updates, each one represented by a StreamingChatCompletionsUpdate object.
    """

    def __init__(self, response: HttpResponse):
        super().__init__()
        self._response = response
        self._bytes_iterator: Iterator[bytes] = response.iter_bytes()

    def __iter__(self) -> Any:
        return self

    def __next__(self) -> "_models.StreamingChatCompletionsUpdate":
        # Keep reading network blocks until at least one update is queued,
        # or the end of the stream is reached.
        while not self._done and self._queue.empty():
            self._done = self._read_next_block()
        if self._queue.empty():
            raise StopIteration
        return self._queue.get()

    def _read_next_block(self) -> bool:
        # Returns True when the stream is exhausted or '[DONE]' was seen.
        if self._ENABLE_CLASS_LOGS:
            logger.debug("[Reading next block]")
        try:
            element = next(self._bytes_iterator)
        except StopIteration:
            self.close()
            return True
        return self._deserialize_and_add_to_queue(element)

    def __enter__(self):
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:  # type: ignore
        self.close()

    def close(self) -> None:
        """Close the underlying HTTP response stream."""
        self._response.close()
class AsyncStreamingChatCompletions(BaseStreamingChatCompletions):
    """Represents an asynchronous iterator over StreamingChatCompletionsUpdate objects.
    The class deserializes the Server Sent Events (SSE) response stream into chat
    completions updates, each one represented by a StreamingChatCompletionsUpdate object.
    """

    def __init__(self, response: AsyncHttpResponse):
        super().__init__()
        self._response = response
        self._bytes_iterator: AsyncIterator[bytes] = response.iter_bytes()

    def __aiter__(self) -> Any:
        return self

    async def __anext__(self) -> "_models.StreamingChatCompletionsUpdate":
        # Keep reading network blocks until at least one update is queued,
        # or the end of the stream is reached.
        while not self._done and self._queue.empty():
            self._done = await self._read_next_block_async()
        if self._queue.empty():
            raise StopAsyncIteration
        return self._queue.get()

    async def _read_next_block_async(self) -> bool:
        # Returns True when the stream is exhausted or '[DONE]' was seen.
        if self._ENABLE_CLASS_LOGS:
            logger.debug("[Reading next block]")
        try:
            element = await self._bytes_iterator.__anext__()
        except StopAsyncIteration:
            await self.aclose()
            return True
        return self._deserialize_and_add_to_queue(element)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:  # type: ignore
        await self.aclose()

    async def aclose(self) -> None:
        """Close the underlying asynchronous HTTP response stream."""
        await self._response.close()
__all__: List[str] = [
"AssistantMessage",
"AsyncStreamingChatCompletions",
"ChatCompletions",
"ChatRequestMessage",
"EmbeddingsResult",
"ImageEmbeddingInput",
"ImageUrl",
"InputAudio",
"StreamingChatCompletions",
"SystemMessage",
"ToolMessage",
"UserMessage",
"DeveloperMessage",
] # Add all objects you want publicly available to users at this package level
def patch_sdk():
    """Last-resort hook for hand-written customizations of the generated code.

    Do not remove this function from the file. It exists as an escape hatch for
    customizations that cannot be accomplished with the techniques described in
    https://aka.ms/azsdk/python/dpcodegen/python/customize
    """