mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-05 18:20:35 -08:00
Support audio model in Azure provider
This commit is contained in:
parent
ebff7d51ab
commit
c4b18df769
7 changed files with 39 additions and 19 deletions
|
|
@ -4,7 +4,7 @@ import os
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from ...typing import Messages, AsyncResult
|
from ...typing import Messages, AsyncResult
|
||||||
from ...errors import MissingAuthError
|
from ...errors import MissingAuthError, ModelNotFoundError
|
||||||
from ..template import OpenaiTemplate
|
from ..template import OpenaiTemplate
|
||||||
|
|
||||||
class Azure(OpenaiTemplate):
|
class Azure(OpenaiTemplate):
|
||||||
|
|
@ -15,9 +15,20 @@ class Azure(OpenaiTemplate):
|
||||||
active_by_default = True
|
active_by_default = True
|
||||||
login_url = "https://discord.gg/qXA4Wf4Fsm"
|
login_url = "https://discord.gg/qXA4Wf4Fsm"
|
||||||
routes: dict[str, str] = {}
|
routes: dict[str, str] = {}
|
||||||
|
audio_models = ["gpt-4o-mini-audio-preview"]
|
||||||
|
model_extra_body = {
|
||||||
|
"gpt-4o-mini-audio-preview": {
|
||||||
|
"audio": {
|
||||||
|
"voice": "alloy",
|
||||||
|
"format": "mp3"
|
||||||
|
},
|
||||||
|
"modalities": ["text", "audio"],
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_models(cls, **kwargs) -> list[str]:
|
def get_models(cls, api_key: str = None, **kwargs) -> list[str]:
|
||||||
routes = os.environ.get("AZURE_ROUTES")
|
routes = os.environ.get("AZURE_ROUTES")
|
||||||
if routes:
|
if routes:
|
||||||
try:
|
try:
|
||||||
|
|
@ -27,7 +38,7 @@ class Azure(OpenaiTemplate):
|
||||||
cls.routes = routes
|
cls.routes = routes
|
||||||
if cls.routes:
|
if cls.routes:
|
||||||
return list(cls.routes.keys())
|
return list(cls.routes.keys())
|
||||||
return super().get_models(**kwargs)
|
return super().get_models(api_key=api_key, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create_async_generator(
|
async def create_async_generator(
|
||||||
|
|
@ -40,6 +51,9 @@ class Azure(OpenaiTemplate):
|
||||||
) -> AsyncResult:
|
) -> AsyncResult:
|
||||||
if not model:
|
if not model:
|
||||||
model = os.environ.get("AZURE_DEFAULT_MODEL", cls.default_model)
|
model = os.environ.get("AZURE_DEFAULT_MODEL", cls.default_model)
|
||||||
|
if model in cls.model_extra_body:
|
||||||
|
for key, value in cls.model_extra_body[model].items():
|
||||||
|
kwargs.setdefault(key, value)
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise ValueError(f"API key is required for Azure provider. Ask for API key in the {cls.login_url} Discord server.")
|
raise ValueError(f"API key is required for Azure provider. Ask for API key in the {cls.login_url} Discord server.")
|
||||||
if not api_endpoint:
|
if not api_endpoint:
|
||||||
|
|
@ -47,7 +61,7 @@ class Azure(OpenaiTemplate):
|
||||||
cls.get_models()
|
cls.get_models()
|
||||||
api_endpoint = cls.routes.get(model)
|
api_endpoint = cls.routes.get(model)
|
||||||
if cls.routes and not api_endpoint:
|
if cls.routes and not api_endpoint:
|
||||||
raise ValueError(f"No API endpoint found for model: {model}")
|
raise ModelNotFoundError(f"No API endpoint found for model: {model}")
|
||||||
if not api_endpoint:
|
if not api_endpoint:
|
||||||
api_endpoint = os.environ.get("AZURE_API_ENDPOINT")
|
api_endpoint = os.environ.get("AZURE_API_ENDPOINT")
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErr
|
||||||
from ...typing import Union, AsyncResult, Messages, MediaListType
|
from ...typing import Union, AsyncResult, Messages, MediaListType
|
||||||
from ...requests import StreamSession, raise_for_status
|
from ...requests import StreamSession, raise_for_status
|
||||||
from ...image import use_aspect_ratio
|
from ...image import use_aspect_ratio
|
||||||
|
from ...image.copy_images import save_response_media
|
||||||
from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse, ProviderInfo
|
from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse, ProviderInfo
|
||||||
from ...tools.media import render_messages
|
from ...tools.media import render_messages
|
||||||
from ...errors import MissingAuthError, ResponseError
|
from ...errors import MissingAuthError, ResponseError
|
||||||
|
|
@ -62,7 +63,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||||
max_tokens: int = None,
|
max_tokens: int = None,
|
||||||
top_p: float = None,
|
top_p: float = None,
|
||||||
stop: Union[str, list[str]] = None,
|
stop: Union[str, list[str]] = None,
|
||||||
stream: bool = False,
|
stream: bool = None,
|
||||||
prompt: str = None,
|
prompt: str = None,
|
||||||
headers: dict = None,
|
headers: dict = None,
|
||||||
impersonate: str = None,
|
impersonate: str = None,
|
||||||
|
|
@ -115,7 +116,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
stream=stream,
|
stream="audio" not in extra_parameters if stream is None else stream,
|
||||||
**extra_parameters,
|
**extra_parameters,
|
||||||
**extra_body
|
**extra_body
|
||||||
)
|
)
|
||||||
|
|
@ -136,10 +137,18 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||||
yield Usage(**data["usage"])
|
yield Usage(**data["usage"])
|
||||||
if "choices" in data:
|
if "choices" in data:
|
||||||
choice = next(iter(data["choices"]), None)
|
choice = next(iter(data["choices"]), None)
|
||||||
if choice and "content" in choice["message"] and choice["message"]["content"]:
|
message = choice.get("message", {})
|
||||||
yield choice["message"]["content"].strip()
|
if choice and "content" in message and message["content"]:
|
||||||
if "tool_calls" in choice["message"]:
|
yield message["content"].strip()
|
||||||
yield ToolCalls(choice["message"]["tool_calls"])
|
if "tool_calls" in message:
|
||||||
|
yield ToolCalls(message["tool_calls"])
|
||||||
|
audio = message.get("audio", {})
|
||||||
|
if "data" in audio:
|
||||||
|
async for chunk in save_response_media(audio["data"], prompt, [model, extra_body.get("audio", {}).get("voice")]):
|
||||||
|
yield chunk
|
||||||
|
if "transcript" in audio:
|
||||||
|
yield "\n\n"
|
||||||
|
yield audio["transcript"]
|
||||||
if choice and "finish_reason" in choice and choice["finish_reason"] is not None:
|
if choice and "finish_reason" in choice and choice["finish_reason"] is not None:
|
||||||
yield FinishReason(choice["finish_reason"])
|
yield FinishReason(choice["finish_reason"])
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ def convert_to_provider(provider: str) -> ProviderType:
|
||||||
|
|
||||||
def get_model_and_provider(model : Union[Model, str],
|
def get_model_and_provider(model : Union[Model, str],
|
||||||
provider : Union[ProviderType, str, None],
|
provider : Union[ProviderType, str, None],
|
||||||
stream : bool,
|
stream : bool = False,
|
||||||
ignore_working: bool = False,
|
ignore_working: bool = False,
|
||||||
ignore_stream: bool = False,
|
ignore_stream: bool = False,
|
||||||
logging: bool = True,
|
logging: bool = True,
|
||||||
|
|
|
||||||
|
|
@ -149,7 +149,6 @@ class Api:
|
||||||
"model": model,
|
"model": model,
|
||||||
"provider": provider,
|
"provider": provider,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"stream": True,
|
|
||||||
"ignore_stream": True,
|
"ignore_stream": True,
|
||||||
**kwargs
|
**kwargs
|
||||||
}
|
}
|
||||||
|
|
@ -166,8 +165,6 @@ class Api:
|
||||||
try:
|
try:
|
||||||
model, provider_handler = get_model_and_provider(
|
model, provider_handler = get_model_and_provider(
|
||||||
kwargs.get("model"), provider,
|
kwargs.get("model"), provider,
|
||||||
stream=True,
|
|
||||||
ignore_stream=True,
|
|
||||||
has_images="media" in kwargs,
|
has_images="media" in kwargs,
|
||||||
)
|
)
|
||||||
if "user" in kwargs:
|
if "user" in kwargs:
|
||||||
|
|
|
||||||
|
|
@ -112,7 +112,7 @@ def get_filename(tags: list[str], alt: str, extension: str, image: str) -> str:
|
||||||
return "".join((
|
return "".join((
|
||||||
f"{int(time.time())}_",
|
f"{int(time.time())}_",
|
||||||
f"{secure_filename(tags + alt)}_" if alt else secure_filename(tags),
|
f"{secure_filename(tags + alt)}_" if alt else secure_filename(tags),
|
||||||
hashlib.sha256(image.encode()).hexdigest()[:16],
|
hashlib.sha256(str(time.time()).encode() if image is None else image.encode()).hexdigest()[:16],
|
||||||
extension
|
extension
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -292,7 +292,7 @@ class AsyncGeneratorProvider(AbstractProvider):
|
||||||
cls,
|
cls,
|
||||||
model: str,
|
model: str,
|
||||||
messages: Messages,
|
messages: Messages,
|
||||||
stream: bool = True,
|
stream: bool = None,
|
||||||
timeout: int = None,
|
timeout: int = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> CreateResult:
|
) -> CreateResult:
|
||||||
|
|
@ -312,7 +312,7 @@ class AsyncGeneratorProvider(AbstractProvider):
|
||||||
"""
|
"""
|
||||||
return to_sync_generator(
|
return to_sync_generator(
|
||||||
cls.create_async_generator(model, messages, stream=stream, **kwargs),
|
cls.create_async_generator(model, messages, stream=stream, **kwargs),
|
||||||
stream=stream,
|
stream=stream is not False,
|
||||||
timeout=timeout
|
timeout=timeout
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -321,7 +321,6 @@ class AsyncGeneratorProvider(AbstractProvider):
|
||||||
async def create_async_generator(
|
async def create_async_generator(
|
||||||
model: str,
|
model: str,
|
||||||
messages: Messages,
|
messages: Messages,
|
||||||
stream: bool = True,
|
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> AsyncResult:
|
) -> AsyncResult:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -17,3 +17,4 @@ services:
|
||||||
restart: on-failure
|
restart: on-failure
|
||||||
volumes:
|
volumes:
|
||||||
- /var/win:/storage
|
- /var/win:/storage
|
||||||
|
- ./:/data
|
||||||
Loading…
Add table
Add a link
Reference in a new issue