feat: Refactor extra_body handling and update model error handling

- Changed the default value of `extra_body` from an empty dictionary to `None` in `ImageLabs` and `PollinationsAI` classes.
- Added a check to initialize `extra_body` to an empty dictionary if it is `None` in the `ImageLabs` class.
- Removed the `extra_image_models` list from the `PollinationsAI` class.
- Updated the way image models are combined in the `PollinationsAI` class to avoid duplicates.
- Changed the error handling for unsupported models from `ModelNotSupportedError` to `ModelNotFoundError` in multiple classes including `OpenaiChat`, `HuggingFaceAPI`, and `HuggingFaceInference`.
- Updated the `save_response_media` function to handle both string and bytes responses.
- Adjusted the handling of audio data in the `PollinationsAI` class to ensure proper processing of audio responses.
This commit is contained in:
hlohaus 2025-06-12 02:29:41 +02:00
parent 15d8318ab7
commit bf4ed09ab9
20 changed files with 166 additions and 152 deletions

View file

@ -36,9 +36,11 @@ class ImageLabs(AsyncGeneratorProvider, ProviderModelMixin):
aspect_ratio: str = "1:1",
width: int = None,
height: int = None,
extra_body: dict = {},
extra_body: dict = None,
**kwargs
) -> AsyncResult:
if extra_body is None:
extra_body = {}
extra_body = use_aspect_ratio({
"width": width,
"height": height,

View file

@ -3,7 +3,6 @@ from __future__ import annotations
import random
import json
import uuid
import sys
import asyncio
@ -14,7 +13,7 @@ from ..tools.media import merge_media
from ..image import to_bytes, is_accepted_format
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import get_last_user_message
from ..errors import ModelNotFoundError
from ..errors import ModelNotFoundError, ResponseError
from .. import debug
class LegacyLMArena(AsyncGeneratorProvider, ProviderModelMixin):
@ -460,6 +459,8 @@ class LegacyLMArena(AsyncGeneratorProvider, ProviderModelMixin):
content = data
if content:
if "**NETWORK ERROR DUE TO HIGH TRAFFIC." in content:
raise ResponseError(data)
# Clean up content
if isinstance(content, str):
if content.endswith(""):

View file

@ -19,7 +19,7 @@ from ..requests.raise_for_status import raise_for_status
from ..requests.aiohttp import get_connector
from ..image.copy_images import save_response_media
from ..image import use_aspect_ratio
from ..providers.response import FinishReason, Usage, ToolCalls, ImageResponse, Reasoning, TitleGeneration, SuggestedFollowups
from ..providers.response import FinishReason, Usage, ToolCalls, ImageResponse, Reasoning, TitleGeneration, SuggestedFollowups, ProviderInfo
from ..tools.media import render_messages
from ..constants import STATIC_URL
from .. import debug
@ -84,7 +84,6 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
text_models = [default_model, "evil"]
image_models = [default_image_model]
audio_models = {default_audio_model: []}
extra_image_models = ["flux-pro", "flux-dev", "flux-schnell", "dall-e-3", "turbo"]
vision_models = [default_vision_model, "gpt-4o-mini", "openai", "openai-large", "openai-reasoning", "searchgpt"]
_models_loaded = False
# https://github.com/pollinations/pollinations/blob/master/text.pollinations.ai/generateTextPortkey.js#L15
@ -133,6 +132,9 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
### Image Models ###
"sdxl-turbo": "turbo",
"gpt-image": "gptimage",
"flux-pro": "flux",
"flux-dev": "flux",
"flux-schnell": "flux"
}
@classmethod
@ -164,14 +166,14 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
new_image_models = []
# Combine image models without duplicates
all_image_models = [cls.default_image_model] # Start with default model
image_models = [cls.default_image_model] # Start with default model
# Add extra image models if not already in the list
for model in cls.extra_image_models + new_image_models:
if model not in all_image_models:
all_image_models.append(model)
for model in new_image_models:
if model not in image_models:
image_models.append(model)
cls.image_models = all_image_models
cls.image_models = image_models
text_response = requests.get("https://text.pollinations.ai/models")
text_response.raise_for_status()
@ -194,19 +196,19 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
cls.vision_models.append(alias)
# Create a set of unique text models starting with default model
unique_text_models = cls.text_models.copy()
text_models = cls.text_models.copy()
# Add models from vision_models
unique_text_models.extend(cls.vision_models)
text_models.extend(cls.vision_models)
# Add models from the API response
for model in models:
model_name = model.get("name")
if model_name and "input_modalities" in model and "text" in model["input_modalities"]:
unique_text_models.append(model_name)
text_models.append(model_name)
# Convert to list and update text_models
cls.text_models = list(dict.fromkeys(unique_text_models))
cls.text_models = list(dict.fromkeys(text_models))
cls._models_loaded = True
@ -243,10 +245,10 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
messages: Messages,
stream: bool = True,
proxy: str = None,
cache: bool = False,
cache: bool = None,
referrer: str = STATIC_URL,
api_key: str = None,
extra_body: dict = {},
extra_body: dict = None,
# Image generation parameters
prompt: str = None,
aspect_ratio: str = "1:1",
@ -268,6 +270,10 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias", "voice", "modalities", "audio"],
**kwargs
) -> AsyncResult:
if cache is None:
cache = kwargs.get("action") == "next"
if extra_body is None:
extra_body = {}
# Load model list
cls.get_models()
if not model:
@ -363,8 +369,11 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
"safe": str(safe).lower(),
}, aspect_ratio)
query = "&".join(f"{k}={quote_plus(str(v))}" for k, v in params.items() if v is not None)
prompt = quote_plus(prompt)[:2048-len(cls.image_api_endpoint)-len(query)-8]
url = f"{cls.image_api_endpoint}prompt/{prompt}?{query}"
encoded_prompt = prompt
if model == "gptimage" and aspect_ratio != "1:1":
encoded_prompt = f"{encoded_prompt} aspect-ratio: {aspect_ratio}"
encoded_prompt = quote_plus(encoded_prompt)[:2048-len(cls.image_api_endpoint)-len(query)-8]
url = f"{cls.image_api_endpoint}prompt/{encoded_prompt}?{query}"
def get_image_url(i: int, seed: Optional[int] = None):
if i == 0:
if not cache and seed is None:
@ -374,10 +383,10 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
return f"{url}&seed={seed}" if seed else url
headers = {"referer": referrer}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
headers["authorization"] = f"Bearer {api_key}"
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
responses = set()
responses.add(Reasoning(status=f"Generating {n} {'image' if n == 1 else 'images'}..."))
responses.add(Reasoning(status=f"Generate {n} {'image' if n == 1 else 'images'}..."))
finished = 0
start = time.time()
async def get_image(responses: set, i: int, seed: Optional[int] = None):
@ -386,8 +395,11 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
try:
await raise_for_status(response)
except Exception as e:
if response.status == 500:
responses.add(e)
return
debug.error(f"Error fetching image: {e}")
responses.add(ImageResponse(str(response.url), prompt))
responses.add(ImageResponse(str(response.url), prompt, {"headers": headers}))
finished += 1
responses.add(Reasoning(status=f"Image {finished}/{n} generated in {time.time() - start:.2f}s"))
tasks = []
@ -395,7 +407,12 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
tasks.append(asyncio.create_task(get_image(responses, i, seed)))
while finished < n or len(responses) > 0:
while len(responses) > 0:
yield responses.pop()
item = responses.pop()
if isinstance(item, Exception):
for task in tasks:
task.cancel()
raise item
yield item
await asyncio.sleep(0.1)
await asyncio.gather(*tasks)
@ -424,14 +441,17 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
seed = random.randint(0, 2**32)
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
if model in cls.audio_models:
if "audio" in kwargs and kwargs.get("audio", {}).get("voice") is None:
kwargs["audio"]["voice"] = cls.audio_models[model][0]
url = cls.text_api_endpoint
stream = False
else:
url = cls.openai_endpoint
extra_body.update({param: kwargs[param] for param in extra_parameters if param in kwargs})
if model in cls.audio_models:
if "audio" in extra_body and extra_body.get("audio", {}).get("voice") is None:
kwargs["audio"]["voice"] = cls.audio_models[model][0]
elif "audio" not in extra_body:
extra_body["audio"] = {"voice": cls.audio_models[model][0]}
if extra_body.get("audio", {}).get("format") is None:
extra_body["audio"]["format"] = "mp3"
if "modalities" not in extra_body:
extra_body["modalities"] = ["text", "audio"]
stream = False
data = filter_none(
messages=list(render_messages(messages, media)),
model=model,
@ -447,19 +467,23 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
)
headers = {"referer": referrer}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
async with session.post(url, json=data, headers=headers) as response:
if response.status == 400:
debug.error(f"Error: 400 - Bad Request: {data}")
headers["authorization"] = f"Bearer {api_key}"
async with session.post(cls.openai_endpoint, json=data, headers=headers) as response:
if response.status in (400, 500):
debug.error(f"Error: {response.status} - Bad Request: {data}")
await raise_for_status(response)
if response.headers["content-type"].startswith("text/plain"):
yield await response.text()
return
elif response.headers["content-type"].startswith("text/event-stream"):
reasoning = False
model_returned = False
async for result in see_stream(response.content):
if "error" in result:
raise ResponseError(result["error"].get("message", result["error"]))
if not model_returned and result.get("model"):
yield ProviderInfo(**cls.get_dict(), model=result.get("model"))
model_returned = True
if result.get("usage") is not None:
yield Usage(**result["usage"])
choices = result.get("choices", [{}])
@ -478,15 +502,15 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
if finish_reason:
yield FinishReason(finish_reason)
if reasoning:
yield Reasoning(status="Done")
yield Reasoning(status="")
if kwargs.get("action") == "next":
data = {
"model": "openai",
"messages": messages + FOLLOWUPS_DEVELOPER_MESSAGE,
"messages": [m for m in messages if m.get("role") == "user"] + FOLLOWUPS_DEVELOPER_MESSAGE,
"tool_choice": "required",
"tools": FOLLOWUPS_TOOLS
}
async with session.post(url, json=data, headers=headers) as response:
async with session.post(cls.openai_endpoint, json=data, headers=headers) as response:
try:
await raise_for_status(response)
tool_calls = (await response.json()).get("choices", [{}])[0].get("message", {}).get("tool_calls", [])
@ -500,7 +524,10 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
debug.error("Error generating title and followups")
debug.error(e)
elif response.headers["content-type"].startswith("application/json"):
prompt = format_image_prompt(messages)
result = await response.json()
if result.get("model"):
yield ProviderInfo(**cls.get_dict(), model=result.get("model"))
if "choices" in result:
choice = result["choices"][0]
message = choice.get("message", {})
@ -509,6 +536,13 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
yield content
if "tool_calls" in message:
yield ToolCalls(message["tool_calls"])
audio = message.get("audio", {})
if "data" in audio:
async for chunk in save_response_media(audio["data"], prompt, [model, extra_body.get("audio", {}).get("voice")]):
yield chunk
if "transcript" in audio:
yield "\n\n"
yield audio["transcript"]
else:
raise ResponseError(result)
if result.get("usage") is not None:
@ -517,6 +551,5 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
if finish_reason:
yield FinishReason(finish_reason)
else:
async for chunk in save_response_media(response, format_image_prompt(messages), [model, extra_body.get("audio", {}).get("voice")]):
async for chunk in save_response_media(response, prompt, [model, extra_body.get("audio", {}).get("voice")]):
yield chunk
return

View file

@ -14,23 +14,20 @@ class PollinationsImage(PollinationsAI):
default_vision_model = None
default_image_model = default_model
audio_models = {}
image_models = [default_image_model] # Default models
_models_loaded = False # Add a checkbox for synchronization
@classmethod
def get_models(cls, **kwargs):
if not cls._models_loaded:
# Calling the parent method to load models
super().get_models()
# Combine models from the parent class and additional ones
all_image_models = list(dict.fromkeys(
cls.image_models +
PollinationsAI.image_models +
cls.extra_image_models
))
cls.image_models = all_image_models
cls._models_loaded = True
return cls.image_models
PollinationsAI.get_models()
cls.image_models = PollinationsAI.image_models
cls.models = cls.image_models
return cls.models
@classmethod
def get_grouped_models(cls) -> dict[str, list[str]]:
PollinationsAI.get_models()
return [
{"group": "Image Generation", "models": PollinationsAI.image_models},
]
@classmethod
async def create_async_generator(

View file

@ -1,34 +1,27 @@
from __future__ import annotations
try:
has_openaifm = True
except ImportError:
has_openaifm = False
from aiohttp import ClientSession
from urllib.parse import urlencode
import json
from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import get_last_message
from ..helper import get_last_user_message, get_system_prompt
from ...image.copy_images import save_response_media
from ...requests.raise_for_status import raise_for_status
from ...requests.aiohttp import get_connector
from ...requests import DEFAULT_HEADERS
class OpenAIFM(AsyncGeneratorProvider, ProviderModelMixin):
label = "OpenAI.fm"
url = "https://www.openai.fm"
api_endpoint = "https://www.openai.fm/api/generate"
working = has_openaifm
working = True
default_model = 'gpt-4o-mini-tts'
default_audio_model = default_model
default_voice = 'coral'
voices = ['alloy', 'ash', 'ballad', default_voice, 'echo', 'fable', 'onyx', 'nova', 'sage', 'shimmer', 'verse']
audio_models = {default_audio_model: voices}
models = [default_audio_model]
models = voices
friendly = """Affect/personality: A cheerful guide
@ -106,44 +99,24 @@ Emotion: Restrained enthusiasm for discoveries and findings, conveying intellect
audio: dict = {},
**kwargs
) -> AsyncResult:
# Retrieve parameters from the audio dictionary
voice = audio.get("voice", kwargs.get("voice", cls.default_voice))
instructions = audio.get("instructions", kwargs.get("instructions", cls.friendly))
instructions = audio.get("instructions", kwargs.get("instructions", get_system_prompt(messages) or cls.friendly))
headers = {
"accept": "*/*",
"accept-language": "en-US,en;q=0.9",
"cache-control": "no-cache",
"pragma": "no-cache",
"sec-fetch-dest": "audio",
"sec-fetch-mode": "no-cors",
"sec-fetch-site": "same-origin",
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36",
"referer": cls.url
**DEFAULT_HEADERS,
"referer": f"{cls.url}/"
}
# Using prompts or formatting messages
text = get_last_message(messages, prompt)
text = get_last_user_message(messages, prompt)
params = {
"input": text,
"prompt": instructions,
"voice": voice
}
async with ClientSession(headers=headers) as session:
# Print the full URL with parameters
full_url = f"{cls.api_endpoint}?{urlencode(params)}"
async with ClientSession(headers=headers, connector=get_connector(proxy=proxy)) as session:
async with session.get(
cls.api_endpoint,
params=params,
proxy=proxy
params=params
) as response:
response.raise_for_status()
await raise_for_status(response)
async for chunk in save_response_media(response, text, [model, voice]):
yield chunk

View file

@ -7,7 +7,7 @@ import uuid
from ...typing import AsyncResult, Messages
from ...providers.response import Reasoning, JsonConversation
from ...requests.raise_for_status import raise_for_status
from ...errors import ModelNotSupportedError
from ...errors import ModelNotFoundError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import get_last_user_message
from ... import debug
@ -55,7 +55,7 @@ class Qwen_Qwen_3(AsyncGeneratorProvider, ProviderModelMixin):
) -> AsyncResult:
try:
model = cls.get_model(model)
except ModelNotSupportedError:
except ModelNotFoundError:
pass
if conversation is None:
conversation = JsonConversation(session_hash=str(uuid.uuid4()).replace('-', ''))

View file

@ -23,7 +23,7 @@ from ...requests.raise_for_status import raise_for_status
from ...requests import StreamSession
from ...requests import get_nodriver
from ...image import ImageRequest, to_image, to_bytes, is_accepted_format
from ...errors import MissingAuthError, NoValidHarFileError, ModelNotSupportedError
from ...errors import MissingAuthError, NoValidHarFileError, ModelNotFoundError
from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult, ImageResponse, ImagePreview
from ...providers.response import Sources, TitleGeneration, RequestLogin, Reasoning
from ...tools.media import merge_media
@ -358,7 +358,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
debug.error(e)
try:
model = cls.get_model(model)
except ModelNotSupportedError:
except ModelNotFoundError:
pass
if conversation is None:
conversation = Conversation(None, str(uuid.uuid4()), getattr(auth_result, "cookies", {}).get("oai-did"))

View file

@ -5,7 +5,7 @@ import requests
from ....providers.types import Messages
from ....typing import MediaListType
from ....requests import StreamSession, raise_for_status
from ....errors import ModelNotSupportedError, PaymentRequiredError
from ....errors import ModelNotFoundError, PaymentRequiredError
from ....providers.response import ProviderInfo
from ...template.OpenaiTemplate import OpenaiTemplate
from .models import model_aliases, vision_models, default_llama_model, default_vision_model, text_models
@ -34,7 +34,7 @@ class HuggingFaceAPI(OpenaiTemplate):
def get_model(cls, model: str, **kwargs) -> str:
try:
return super().get_model(model, **kwargs)
except ModelNotSupportedError:
except ModelNotFoundError:
return model
@classmethod
@ -87,14 +87,14 @@ class HuggingFaceAPI(OpenaiTemplate):
model = cls.get_model(model)
provider_mapping = await cls.get_mapping(model, api_key)
if not provider_mapping:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
raise ModelNotFoundError(f"Model is not supported: {model} in: {cls.__name__}")
error = None
for provider_key in provider_mapping:
api_path = provider_key if provider_key == "novita" else f"{provider_key}/v1"
api_base = f"https://router.huggingface.co/{api_path}"
task = provider_mapping[provider_key]["task"]
if task != "conversational":
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} task: {task}")
raise ModelNotFoundError(f"Model is not supported: {model} in: {cls.__name__} task: {task}")
model = provider_mapping[provider_key]["providerId"]
yield ProviderInfo(**{**cls.get_dict(), "label": f"HuggingFace ({provider_key})"})
# start = calculate_lenght(messages)

View file

@ -7,7 +7,7 @@ import requests
from ....typing import AsyncResult, Messages
from ...base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_prompt
from ....errors import ModelNotSupportedError, ResponseError
from ....errors import ModelNotFoundError, ResponseError
from ....requests import StreamSession, raise_for_status
from ....providers.response import FinishReason, ImageResponse
from ....image.copy_images import save_response_media
@ -58,7 +58,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
return cls.model_data[model]
async with session.get(f"https://huggingface.co/api/models/{model}") as response:
if response.status == 404:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
raise ModelNotFoundError(f"Model not found: {model} in: {cls.__name__}")
await raise_for_status(response)
cls.model_data[model] = await response.json()
return cls.model_data[model]
@ -77,7 +77,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
temperature: float = None,
prompt: str = None,
action: str = None,
extra_body: dict = {},
extra_body: dict = None,
seed: int = None,
aspect_ratio: str = None,
width: int = None,
@ -86,7 +86,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
) -> AsyncResult:
try:
model = cls.get_model(model)
except ModelNotSupportedError:
except ModelNotFoundError:
pass
headers = {
'Accept-Encoding': 'gzip, deflate',
@ -94,6 +94,8 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
}
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
if extra_body is None:
extra_body = {}
image_extra_body = use_aspect_ratio({
"width": width,
"height": height,
@ -114,12 +116,12 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
}
async with session.post(provider_together_urls[model], json=data) as response:
if response.status == 404:
raise ModelNotSupportedError(f"Model is not supported: {model}")
raise ModelNotFoundError(f"Model not found: {model}")
await raise_for_status(response)
result = await response.json()
yield ImageResponse([item["url"] for item in result["data"]], data["prompt"])
return
except ModelNotSupportedError:
except ModelNotFoundError:
pass
payload = None
params = {
@ -156,11 +158,11 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
params["seed"] = seed
payload = {"inputs": inputs, "parameters": params, "stream": stream}
else:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}")
raise ModelNotFoundError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}")
async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response:
if response.status == 404:
raise ModelNotSupportedError(f"Model is not supported: {model}")
raise ModelNotFoundError(f"Model not found: {model}")
await raise_for_status(response)
if stream:
first = True

View file

@ -7,7 +7,7 @@ import requests
from ....providers.types import Messages
from ....requests import StreamSession, raise_for_status
from ....errors import ModelNotSupportedError
from ....errors import ModelNotFoundError
from ....providers.helper import format_image_prompt
from ....providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ....providers.response import ProviderInfo, ImageResponse, VideoResponse, Reasoning
@ -98,7 +98,7 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
model: str,
messages: Messages,
api_key: str = None,
extra_body: dict = {},
extra_body: dict = None,
prompt: str = None,
proxy: str = None,
timeout: int = 0,
@ -112,6 +112,8 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
resolution: str = "480p",
**kwargs
):
if extra_body is None:
extra_body = {}
selected_provider = None
if model and ":" in model:
model, selected_provider = model.split(":", 1)
@ -130,7 +132,7 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
}
provider_mapping = {**new_mapping, **provider_mapping}
if not provider_mapping:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
raise ModelNotFoundError(f"Model is not supported: {model} in: {cls.__name__}")
async def generate(extra_body: dict, aspect_ratio: str = None):
last_response = None
for provider_key, provider in provider_mapping.items():
@ -142,7 +144,7 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
task = provider["task"]
provider_id = provider["providerId"]
if task not in cls.tasks:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} task: {task}")
raise ModelNotFoundError(f"Model is not supported: {model} in: {cls.__name__} task: {task}")
if aspect_ratio is None:
aspect_ratio = "1:1" if task == "text-to-image" else "16:9"
@ -209,7 +211,7 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
debug.error(f"{cls.__name__}: Error {response.status} with {provider_key} and {provider_id}")
continue
if response.status == 404:
raise ModelNotSupportedError(f"Model is not supported: {model}")
raise ModelNotFoundError(f"Model not found: {model}")
await raise_for_status(response)
if response.headers.get("Content-Type", "").startswith("application/json"):
result = await response.json()

View file

@ -4,7 +4,7 @@ import random
from ....typing import AsyncResult, Messages
from ....providers.response import ImageResponse
from ....errors import ModelNotSupportedError, MissingAuthError
from ....errors import ModelNotFoundError, MissingAuthError
from ...base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .HuggingChat import HuggingChat
from .HuggingFaceAPI import HuggingFaceAPI
@ -58,7 +58,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
async for chunk in HuggingFaceMedia.create_async_generator(model, messages, **kwargs):
yield chunk
return
except ModelNotSupportedError:
except ModelNotFoundError:
pass
if model in cls.image_models:
if "api_key" not in kwargs:
@ -71,6 +71,6 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
try:
async for chunk in HuggingFaceAPI.create_async_generator(model, messages, **kwargs):
yield chunk
except (ModelNotSupportedError, MissingAuthError):
except (ModelNotFoundError, MissingAuthError):
async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):
yield chunk

View file

@ -66,7 +66,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
headers: dict = None,
impersonate: str = None,
extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias", "modalities", "audio"],
extra_body: dict = {},
extra_body: dict = None,
**kwargs
) -> AsyncResult:
if api_key is None and cls.api_key is not None:
@ -98,6 +98,8 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
return
extra_parameters = {key: kwargs[key] for key in extra_parameters if key in kwargs}
if extra_body is None:
extra_body = {}
data = filter_none(
messages=list(render_messages(messages, media)),
model=model,

View file

@ -1,5 +1,6 @@
PACKAGE_NAME = "g4f"
GITHUB_REPOSITORY = "xtekky/gpt4free"
STATIC_DOMAIN = "gpt4free.github.io"
ORGANIZATION = "gpt4free"
GITHUB_REPOSITORY = f"xtekky/{ORGANIZATION}"
STATIC_DOMAIN = f"{ORGANIZATION}.github.io"
STATIC_URL = f"https://{STATIC_DOMAIN}/"
DIST_DIR = f"./{STATIC_DOMAIN}/dist"

View file

@ -22,9 +22,6 @@ class RetryNoProviderError(Exception):
class VersionNotFoundError(Exception):
...
class ModelNotSupportedError(Exception):
...
class MissingRequirementsError(Exception):
...

View file

@ -257,7 +257,7 @@ class Api:
}
def handle_provider(self, provider_handler, model):
if model:
if not getattr(provider_handler, "model", False):
return self._format_json("provider", {**provider_handler.get_dict(), "model": model})
return self._format_json("provider", provider_handler.get_dict())

View file

@ -25,7 +25,7 @@ EXTENSIONS_MAP: dict[str, str] = {
"gif": "image/gif",
"webp": "image/webp",
# Audio
"wav": "audio/x-wav",
"wav": "audio/wav",
"mp3": "audio/mpeg",
"flac": "audio/flac",
"opus": "audio/opus",

View file

@ -4,7 +4,7 @@ import os
import time
import asyncio
import hashlib
import re
import base64
from typing import AsyncIterator
from urllib.parse import quote, unquote
from aiohttp import ClientSession, ClientError
@ -54,24 +54,36 @@ def get_source_url(image: str, default: str = None) -> str:
return decoded_url
return default
async def save_response_media(response: StreamResponse, prompt: str, tags: list[str]) -> AsyncIterator:
async def save_response_media(response, prompt: str, tags: list[str]) -> AsyncIterator:
"""Save media from response to local file and return URL"""
content_type = response.headers["content-type"]
if isinstance(response, str):
response = base64.b64decode(response)
content_type = response.headers["content-type"] if hasattr(response, "headers") else "audio/mpeg"
extension = MEDIA_TYPE_MAP.get(content_type)
if extension is None:
raise ValueError(f"Unsupported media type: {content_type}")
filename = get_filename(tags, prompt, f".{extension}", prompt)
target_path = os.path.join(get_media_dir(), filename)
ensure_media_dir()
with open(target_path, 'wb') as f:
async for chunk in response.iter_content() if hasattr(response, "iter_content") else response.content.iter_any():
if isinstance(response, bytes):
f.write(response)
else:
if hasattr(response, "iter_content"):
iter_response = response.iter_content()
else:
iter_response = response.content.iter_any()
async for chunk in iter_response:
f.write(chunk)
# Base URL without request parameters
media_url = f"/media/{filename}"
# Save the original URL in the metadata, but not in the file path itself
source_url = str(response.url) if response.method == "GET" else None
source_url = None
if hasattr(response, "url") and response.method == "GET":
source_url = str(response.url)
if content_type.startswith("audio/"):
yield AudioResponse(media_url, text=prompt, source_url=source_url)

View file

@ -20,7 +20,7 @@ from .asyncio import get_running_loop, to_sync_generator, to_async_iterator
from .response import BaseConversation, AuthResult
from .helper import concat_chunks
from ..cookies import get_cookies_dir
from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError, NoValidHarFileError, PaymentRequiredError
from ..errors import ModelNotFoundError, ResponseError, MissingAuthError, NoValidHarFileError, PaymentRequiredError
from .. import debug
SAFE_PARAMETERS = [
@ -363,7 +363,7 @@ class ProviderModelMixin:
model = cls.model_aliases[model]
else:
if model not in cls.get_models(**kwargs) and cls.models:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} Valid models: {cls.models}")
raise ModelNotFoundError(f"Model is not supported: {model} in: {cls.__name__} Valid models: {cls.models}")
cls.last_model = model
debug.last_model = model
return model

View file

@ -28,9 +28,12 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
content_type = response.headers.get("content-type", "")
if content_type.startswith("application/json"):
message = await response.json()
message = message.get("error", message)
if isinstance(message, dict):
error = message.get("error")
if isinstance(error, dict):
message = error.get("message")
message = message.get("message", message)
if isinstance(error, str):
message = f"{error}: {message}"
else:
message = (await response.text()).strip()
is_html = content_type.startswith("text/html") or message.startswith("<!DOCTYPE")

View file

@ -1,18 +1,13 @@
import codecs
import os
from setuptools import find_packages, setup
STATIC_HOST = "gpt4free.github.io"
current_dir = os.path.abspath(os.path.dirname(__file__))
here = os.path.abspath(os.path.dirname(__file__))
with codecs.open(os.path.join(here, 'README.md'), encoding='utf-8') as fh:
long_description = '\n' + fh.read()
with open(os.path.join(current_dir, 'README.md')) as f:
long_description = f.read()
long_description = long_description.replace("[!NOTE]", "")
long_description = long_description.replace("(docs/images/", f"(https://{STATIC_HOST}/docs/images/")
long_description = long_description.replace("(docs/", f"(https://github.com/gpt4free/{STATIC_HOST}/blob/main/docs/")
INSTALL_REQUIRE = [
"requests",
@ -41,9 +36,6 @@ EXTRA_REQUIRE = {
"pywebview",
"plyer",
"setuptools",
"odfpy", # files
"ebooklib",
"openpyxl",
"markitdown[all]"
],
'slim': [
@ -58,7 +50,7 @@ EXTRA_REQUIRE = {
"fastapi", # api
"uvicorn", # api
"python-multipart",
"markitdown[pdf, docx, pptx]"
"markitdown[all]"
],
"image": [
"pillow",
@ -91,9 +83,6 @@ EXTRA_REQUIRE = {
],
"files": [
"beautifulsoup4",
"odfpy",
"ebooklib",
"openpyxl",
"markitdown[all]"
]
}