Support audio model in Azure provider

This commit is contained in:
hlohaus 2025-07-12 19:41:23 +02:00
parent ebff7d51ab
commit c4b18df769
7 changed files with 39 additions and 19 deletions

View file

@@ -4,7 +4,7 @@ import os
import json import json
from ...typing import Messages, AsyncResult from ...typing import Messages, AsyncResult
from ...errors import MissingAuthError from ...errors import MissingAuthError, ModelNotFoundError
from ..template import OpenaiTemplate from ..template import OpenaiTemplate
class Azure(OpenaiTemplate): class Azure(OpenaiTemplate):
@@ -15,9 +15,20 @@ class Azure(OpenaiTemplate):
active_by_default = True active_by_default = True
login_url = "https://discord.gg/qXA4Wf4Fsm" login_url = "https://discord.gg/qXA4Wf4Fsm"
routes: dict[str, str] = {} routes: dict[str, str] = {}
audio_models = ["gpt-4o-mini-audio-preview"]
model_extra_body = {
"gpt-4o-mini-audio-preview": {
"audio": {
"voice": "alloy",
"format": "mp3"
},
"modalities": ["text", "audio"],
"stream": False
}
}
@classmethod @classmethod
def get_models(cls, **kwargs) -> list[str]: def get_models(cls, api_key: str = None, **kwargs) -> list[str]:
routes = os.environ.get("AZURE_ROUTES") routes = os.environ.get("AZURE_ROUTES")
if routes: if routes:
try: try:
@@ -27,7 +38,7 @@ class Azure(OpenaiTemplate):
cls.routes = routes cls.routes = routes
if cls.routes: if cls.routes:
return list(cls.routes.keys()) return list(cls.routes.keys())
return super().get_models(**kwargs) return super().get_models(api_key=api_key, **kwargs)
@classmethod @classmethod
async def create_async_generator( async def create_async_generator(
@@ -40,6 +51,9 @@ class Azure(OpenaiTemplate):
) -> AsyncResult: ) -> AsyncResult:
if not model: if not model:
model = os.environ.get("AZURE_DEFAULT_MODEL", cls.default_model) model = os.environ.get("AZURE_DEFAULT_MODEL", cls.default_model)
if model in cls.model_extra_body:
for key, value in cls.model_extra_body[model].items():
kwargs.setdefault(key, value)
if not api_key: if not api_key:
raise ValueError(f"API key is required for Azure provider. Ask for API key in the {cls.login_url} Discord server.") raise ValueError(f"API key is required for Azure provider. Ask for API key in the {cls.login_url} Discord server.")
if not api_endpoint: if not api_endpoint:
@@ -47,7 +61,7 @@ class Azure(OpenaiTemplate):
cls.get_models() cls.get_models()
api_endpoint = cls.routes.get(model) api_endpoint = cls.routes.get(model)
if cls.routes and not api_endpoint: if cls.routes and not api_endpoint:
raise ValueError(f"No API endpoint found for model: {model}") raise ModelNotFoundError(f"No API endpoint found for model: {model}")
if not api_endpoint: if not api_endpoint:
api_endpoint = os.environ.get("AZURE_API_ENDPOINT") api_endpoint = os.environ.get("AZURE_API_ENDPOINT")
try: try:

View file

@@ -7,6 +7,7 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErr
from ...typing import Union, AsyncResult, Messages, MediaListType from ...typing import Union, AsyncResult, Messages, MediaListType
from ...requests import StreamSession, raise_for_status from ...requests import StreamSession, raise_for_status
from ...image import use_aspect_ratio from ...image import use_aspect_ratio
from ...image.copy_images import save_response_media
from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse, ProviderInfo from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse, ProviderInfo
from ...tools.media import render_messages from ...tools.media import render_messages
from ...errors import MissingAuthError, ResponseError from ...errors import MissingAuthError, ResponseError
@@ -62,7 +63,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
max_tokens: int = None, max_tokens: int = None,
top_p: float = None, top_p: float = None,
stop: Union[str, list[str]] = None, stop: Union[str, list[str]] = None,
stream: bool = False, stream: bool = None,
prompt: str = None, prompt: str = None,
headers: dict = None, headers: dict = None,
impersonate: str = None, impersonate: str = None,
@@ -115,7 +116,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
max_tokens=max_tokens, max_tokens=max_tokens,
top_p=top_p, top_p=top_p,
stop=stop, stop=stop,
stream=stream, stream="audio" not in extra_parameters if stream is None else stream,
**extra_parameters, **extra_parameters,
**extra_body **extra_body
) )
@@ -136,10 +137,18 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
yield Usage(**data["usage"]) yield Usage(**data["usage"])
if "choices" in data: if "choices" in data:
choice = next(iter(data["choices"]), None) choice = next(iter(data["choices"]), None)
if choice and "content" in choice["message"] and choice["message"]["content"]: message = choice.get("message", {})
yield choice["message"]["content"].strip() if choice and "content" in message and message["content"]:
if "tool_calls" in choice["message"]: yield message["content"].strip()
yield ToolCalls(choice["message"]["tool_calls"]) if "tool_calls" in message:
yield ToolCalls(message["tool_calls"])
audio = message.get("audio", {})
if "data" in audio:
async for chunk in save_response_media(audio["data"], prompt, [model, extra_body.get("audio", {}).get("voice")]):
yield chunk
if "transcript" in audio:
yield "\n\n"
yield audio["transcript"]
if choice and "finish_reason" in choice and choice["finish_reason"] is not None: if choice and "finish_reason" in choice and choice["finish_reason"] is not None:
yield FinishReason(choice["finish_reason"]) yield FinishReason(choice["finish_reason"])
return return

View file

@@ -23,7 +23,7 @@ def convert_to_provider(provider: str) -> ProviderType:
def get_model_and_provider(model : Union[Model, str], def get_model_and_provider(model : Union[Model, str],
provider : Union[ProviderType, str, None], provider : Union[ProviderType, str, None],
stream : bool, stream : bool = False,
ignore_working: bool = False, ignore_working: bool = False,
ignore_stream: bool = False, ignore_stream: bool = False,
logging: bool = True, logging: bool = True,

View file

@@ -149,7 +149,6 @@ class Api:
"model": model, "model": model,
"provider": provider, "provider": provider,
"messages": messages, "messages": messages,
"stream": True,
"ignore_stream": True, "ignore_stream": True,
**kwargs **kwargs
} }
@@ -166,8 +165,6 @@ class Api:
try: try:
model, provider_handler = get_model_and_provider( model, provider_handler = get_model_and_provider(
kwargs.get("model"), provider, kwargs.get("model"), provider,
stream=True,
ignore_stream=True,
has_images="media" in kwargs, has_images="media" in kwargs,
) )
if "user" in kwargs: if "user" in kwargs:

View file

@@ -112,7 +112,7 @@ def get_filename(tags: list[str], alt: str, extension: str, image: str) -> str:
return "".join(( return "".join((
f"{int(time.time())}_", f"{int(time.time())}_",
f"{secure_filename(tags + alt)}_" if alt else secure_filename(tags), f"{secure_filename(tags + alt)}_" if alt else secure_filename(tags),
hashlib.sha256(image.encode()).hexdigest()[:16], hashlib.sha256(str(time.time()).encode() if image is None else image.encode()).hexdigest()[:16],
extension extension
)) ))

View file

@@ -292,7 +292,7 @@ class AsyncGeneratorProvider(AbstractProvider):
cls, cls,
model: str, model: str,
messages: Messages, messages: Messages,
stream: bool = True, stream: bool = None,
timeout: int = None, timeout: int = None,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
@@ -312,7 +312,7 @@ class AsyncGeneratorProvider(AbstractProvider):
""" """
return to_sync_generator( return to_sync_generator(
cls.create_async_generator(model, messages, stream=stream, **kwargs), cls.create_async_generator(model, messages, stream=stream, **kwargs),
stream=stream, stream=stream is not False,
timeout=timeout timeout=timeout
) )
@@ -321,7 +321,6 @@ class AsyncGeneratorProvider(AbstractProvider):
async def create_async_generator( async def create_async_generator(
model: str, model: str,
messages: Messages, messages: Messages,
stream: bool = True,
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
""" """

View file

@@ -17,3 +17,4 @@ services:
restart: on-failure restart: on-failure
volumes: volumes:
- /var/win:/storage - /var/win:/storage
- ./:/data