feat: Refactor extra_body handling and update model error handling
- Changed the default value of `extra_body` from an empty dictionary to `None` in the `ImageLabs` and `PollinationsAI` classes.
- Added a check to initialize `extra_body` to an empty dictionary if it is `None` in the `ImageLabs` class.
- Removed the `extra_image_models` list from the `PollinationsAI` class.
- Updated the way image models are combined in the `PollinationsAI` class to avoid duplicates.
- Changed the error handling for unsupported models from `ModelNotSupportedError` to `ModelNotFoundError` in multiple classes, including `OpenaiChat`, `HuggingFaceAPI`, and `HuggingFaceInference`.
- Updated the `save_response_media` function to handle both string and bytes responses.
- Adjusted the handling of audio data in the `PollinationsAI` class to ensure proper processing of audio responses.
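The `extra_body` change is the usual Python guard against mutable default arguments: a `{}` default is created once at function-definition time and shared across every call, so per-request updates can leak between requests. A minimal sketch of the pattern (illustrative names, not the provider code):

```python
# Anti-pattern: one shared dict for every call of the function.
def build_payload_shared(prompt: str, extra_body: dict = {}) -> dict:
    extra_body["prompt"] = prompt  # mutates the shared default dict
    return extra_body

# Pattern adopted in this commit: default to None, initialize per call.
def build_payload(prompt: str, extra_body: dict = None) -> dict:
    if extra_body is None:
        extra_body = {}  # fresh dict on every call
    extra_body["prompt"] = prompt
    return extra_body

assert build_payload("a") is not build_payload("b")  # no shared state
```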
This commit is contained in:
parent: 15d8318ab7
commit: bf4ed09ab9
20 changed files with 166 additions and 152 deletions
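The duplicate-free model lists lean on `dict.fromkeys`, which drops repeats while preserving insertion order (a plain `set` would not). A small illustration with hypothetical model names:

```python
image_models = ["flux"]  # start with the default model
fetched = ["flux", "turbo", "gptimage", "turbo"]  # hypothetical API response

for model in fetched:
    if model not in image_models:
        image_models.append(model)

# The same order-preserving dedup in one step:
assert image_models == list(dict.fromkeys(["flux"] + fetched))
```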
```diff
@@ -36,9 +36,11 @@ class ImageLabs(AsyncGeneratorProvider, ProviderModelMixin):
         aspect_ratio: str = "1:1",
         width: int = None,
         height: int = None,
-        extra_body: dict = {},
+        extra_body: dict = None,
         **kwargs
     ) -> AsyncResult:
+        if extra_body is None:
+            extra_body = {}
         extra_body = use_aspect_ratio({
             "width": width,
             "height": height,
```
```diff
@@ -3,7 +3,6 @@ from __future__ import annotations
 import random
 import json
 import uuid
-import sys
 import asyncio

```
```diff
@@ -14,7 +13,7 @@ from ..tools.media import merge_media
 from ..image import to_bytes, is_accepted_format
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from .helper import get_last_user_message
-from ..errors import ModelNotFoundError
+from ..errors import ModelNotFoundError, ResponseError
 from .. import debug

 class LegacyLMArena(AsyncGeneratorProvider, ProviderModelMixin):
```
```diff
@@ -460,6 +459,8 @@ class LegacyLMArena(AsyncGeneratorProvider, ProviderModelMixin):
                         content = data

                     if content:
+                        if "**NETWORK ERROR DUE TO HIGH TRAFFIC." in content:
+                            raise ResponseError(data)
                         # Clean up content
                         if isinstance(content, str):
                             if content.endswith("▌"):
```
```diff
@@ -19,7 +19,7 @@ from ..requests.raise_for_status import raise_for_status
 from ..requests.aiohttp import get_connector
 from ..image.copy_images import save_response_media
 from ..image import use_aspect_ratio
-from ..providers.response import FinishReason, Usage, ToolCalls, ImageResponse, Reasoning, TitleGeneration, SuggestedFollowups
+from ..providers.response import FinishReason, Usage, ToolCalls, ImageResponse, Reasoning, TitleGeneration, SuggestedFollowups, ProviderInfo
 from ..tools.media import render_messages
 from ..constants import STATIC_URL
 from .. import debug
```
```diff
@@ -84,7 +84,6 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
     text_models = [default_model, "evil"]
     image_models = [default_image_model]
     audio_models = {default_audio_model: []}
-    extra_image_models = ["flux-pro", "flux-dev", "flux-schnell", "dall-e-3", "turbo"]
     vision_models = [default_vision_model, "gpt-4o-mini", "openai", "openai-large", "openai-reasoning", "searchgpt"]
     _models_loaded = False
     # https://github.com/pollinations/pollinations/blob/master/text.pollinations.ai/generateTextPortkey.js#L15
```
```diff
@@ -133,6 +132,9 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         ### Image Models ###
         "sdxl-turbo": "turbo",
         "gpt-image": "gptimage",
+        "flux-pro": "flux",
+        "flux-dev": "flux",
+        "flux-schnell": "flux"
     }

     @classmethod
```
```diff
@@ -164,14 +166,14 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
             new_image_models = []

         # Combine image models without duplicates
-        all_image_models = [cls.default_image_model] # Start with default model
+        image_models = [cls.default_image_model] # Start with default model

         # Add extra image models if not already in the list
-        for model in cls.extra_image_models + new_image_models:
-            if model not in all_image_models:
-                all_image_models.append(model)
+        for model in new_image_models:
+            if model not in image_models:
+                image_models.append(model)

-        cls.image_models = all_image_models
+        cls.image_models = image_models

         text_response = requests.get("https://text.pollinations.ai/models")
         text_response.raise_for_status()
```
```diff
@@ -194,19 +196,19 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                 cls.vision_models.append(alias)

         # Create a set of unique text models starting with default model
-        unique_text_models = cls.text_models.copy()
+        text_models = cls.text_models.copy()

         # Add models from vision_models
-        unique_text_models.extend(cls.vision_models)
+        text_models.extend(cls.vision_models)

         # Add models from the API response
         for model in models:
             model_name = model.get("name")
             if model_name and "input_modalities" in model and "text" in model["input_modalities"]:
-                unique_text_models.append(model_name)
+                text_models.append(model_name)

         # Convert to list and update text_models
-        cls.text_models = list(dict.fromkeys(unique_text_models))
+        cls.text_models = list(dict.fromkeys(text_models))

         cls._models_loaded = True
```
```diff
@@ -243,10 +245,10 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         messages: Messages,
         stream: bool = True,
         proxy: str = None,
-        cache: bool = False,
+        cache: bool = None,
         referrer: str = STATIC_URL,
         api_key: str = None,
-        extra_body: dict = {},
+        extra_body: dict = None,
         # Image generation parameters
         prompt: str = None,
         aspect_ratio: str = "1:1",
```
```diff
@@ -268,6 +270,10 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias", "voice", "modalities", "audio"],
         **kwargs
     ) -> AsyncResult:
+        if cache is None:
+            cache = kwargs.get("action") == "next"
+        if extra_body is None:
+            extra_body = {}
         # Load model list
         cls.get_models()
         if not model:
```
```diff
@@ -363,8 +369,11 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
             "safe": str(safe).lower(),
         }, aspect_ratio)
         query = "&".join(f"{k}={quote_plus(str(v))}" for k, v in params.items() if v is not None)
-        prompt = quote_plus(prompt)[:2048-len(cls.image_api_endpoint)-len(query)-8]
-        url = f"{cls.image_api_endpoint}prompt/{prompt}?{query}"
+        encoded_prompt = prompt
+        if model == "gptimage" and aspect_ratio != "1:1":
+            encoded_prompt = f"{encoded_prompt} aspect-ratio: {aspect_ratio}"
+        encoded_prompt = quote_plus(encoded_prompt)[:2048-len(cls.image_api_endpoint)-len(query)-8]
+        url = f"{cls.image_api_endpoint}prompt/{encoded_prompt}?{query}"
         def get_image_url(i: int, seed: Optional[int] = None):
             if i == 0:
                 if not cache and seed is None:
```
```diff
@@ -374,10 +383,10 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
             return f"{url}&seed={seed}" if seed else url
         headers = {"referer": referrer}
         if api_key:
-            headers["Authorization"] = f"Bearer {api_key}"
+            headers["authorization"] = f"Bearer {api_key}"
         async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
             responses = set()
-            responses.add(Reasoning(status=f"Generating {n} {'image' if n == 1 else 'images'}..."))
+            responses.add(Reasoning(status=f"Generate {n} {'image' if n == 1 else 'images'}..."))
             finished = 0
             start = time.time()
             async def get_image(responses: set, i: int, seed: Optional[int] = None):
```
```diff
@@ -386,8 +395,11 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                 try:
                     await raise_for_status(response)
                 except Exception as e:
+                    if response.status == 500:
+                        responses.add(e)
+                        return
                     debug.error(f"Error fetching image: {e}")
-                responses.add(ImageResponse(str(response.url), prompt))
+                responses.add(ImageResponse(str(response.url), prompt, {"headers": headers}))
                 finished += 1
                 responses.add(Reasoning(status=f"Image {finished}/{n} generated in {time.time() - start:.2f}s"))
             tasks = []
```
```diff
@@ -395,7 +407,12 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                 tasks.append(asyncio.create_task(get_image(responses, i, seed)))
             while finished < n or len(responses) > 0:
                 while len(responses) > 0:
-                    yield responses.pop()
+                    item = responses.pop()
+                    if isinstance(item, Exception):
+                        for task in tasks:
+                            task.cancel()
+                        raise item
+                    yield item
                 await asyncio.sleep(0.1)
             await asyncio.gather(*tasks)
```
```diff
@@ -424,14 +441,17 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
            seed = random.randint(0, 2**32)

        async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
-            if model in cls.audio_models:
-                if "audio" in kwargs and kwargs.get("audio", {}).get("voice") is None:
-                    kwargs["audio"]["voice"] = cls.audio_models[model][0]
-                url = cls.text_api_endpoint
-                stream = False
-            else:
-                url = cls.openai_endpoint
            extra_body.update({param: kwargs[param] for param in extra_parameters if param in kwargs})
+            if model in cls.audio_models:
+                if "audio" in extra_body and extra_body.get("audio", {}).get("voice") is None:
+                    kwargs["audio"]["voice"] = cls.audio_models[model][0]
+                elif "audio" not in extra_body:
+                    extra_body["audio"] = {"voice": cls.audio_models[model][0]}
+                if extra_body.get("audio", {}).get("format") is None:
+                    extra_body["audio"]["format"] = "mp3"
+                if "modalities" not in extra_body:
+                    extra_body["modalities"] = ["text", "audio"]
+                stream = False
            data = filter_none(
                messages=list(render_messages(messages, media)),
                model=model,
```
```diff
@@ -447,19 +467,23 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
            )
            headers = {"referer": referrer}
            if api_key:
-                headers["Authorization"] = f"Bearer {api_key}"
-            async with session.post(url, json=data, headers=headers) as response:
-                if response.status == 400:
-                    debug.error(f"Error: 400 - Bad Request: {data}")
+                headers["authorization"] = f"Bearer {api_key}"
+            async with session.post(cls.openai_endpoint, json=data, headers=headers) as response:
+                if response.status in (400, 500):
+                    debug.error(f"Error: {response.status} - Bad Request: {data}")
                await raise_for_status(response)
                if response.headers["content-type"].startswith("text/plain"):
                    yield await response.text()
                    return
                elif response.headers["content-type"].startswith("text/event-stream"):
                    reasoning = False
+                    model_returned = False
                    async for result in see_stream(response.content):
                        if "error" in result:
                            raise ResponseError(result["error"].get("message", result["error"]))
+                        if not model_returned and result.get("model"):
+                            yield ProviderInfo(**cls.get_dict(), model=result.get("model"))
+                            model_returned = True
                        if result.get("usage") is not None:
                            yield Usage(**result["usage"])
                        choices = result.get("choices", [{}])
```
|
@ -478,15 +502,15 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
if finish_reason:
|
||||
yield FinishReason(finish_reason)
|
||||
if reasoning:
|
||||
yield Reasoning(status="Done")
|
||||
yield Reasoning(status="")
|
||||
if kwargs.get("action") == "next":
|
||||
data = {
|
||||
"model": "openai",
|
||||
"messages": messages + FOLLOWUPS_DEVELOPER_MESSAGE,
|
||||
"messages": [m for m in messages if m.get("role") == "user"] + FOLLOWUPS_DEVELOPER_MESSAGE,
|
||||
"tool_choice": "required",
|
||||
"tools": FOLLOWUPS_TOOLS
|
||||
}
|
||||
async with session.post(url, json=data, headers=headers) as response:
|
||||
async with session.post(cls.openai_endpoint, json=data, headers=headers) as response:
|
||||
try:
|
||||
await raise_for_status(response)
|
||||
tool_calls = (await response.json()).get("choices", [{}])[0].get("message", {}).get("tool_calls", [])
|
||||
|
|
```diff
@@ -500,7 +524,10 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                                        debug.error("Error generating title and followups")
                                        debug.error(e)
                elif response.headers["content-type"].startswith("application/json"):
+                    prompt = format_image_prompt(messages)
                    result = await response.json()
+                    if result.get("model"):
+                        yield ProviderInfo(**cls.get_dict(), model=result.get("model"))
                    if "choices" in result:
                        choice = result["choices"][0]
                        message = choice.get("message", {})
```
```diff
@@ -509,6 +536,13 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                            yield content
                        if "tool_calls" in message:
                            yield ToolCalls(message["tool_calls"])
+                        audio = message.get("audio", {})
+                        if "data" in audio:
+                            async for chunk in save_response_media(audio["data"], prompt, [model, extra_body.get("audio", {}).get("voice")]):
+                                yield chunk
+                        if "transcript" in audio:
+                            yield "\n\n"
+                            yield audio["transcript"]
                    else:
                        raise ResponseError(result)
                    if result.get("usage") is not None:
```
```diff
@@ -517,6 +551,5 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                    if finish_reason:
                        yield FinishReason(finish_reason)
                else:
-                    async for chunk in save_response_media(response, format_image_prompt(messages), [model, extra_body.get("audio", {}).get("voice")]):
+                    async for chunk in save_response_media(response, prompt, [model, extra_body.get("audio", {}).get("voice")]):
                        yield chunk
                    return
```
```diff
@@ -14,23 +14,20 @@ class PollinationsImage(PollinationsAI):
     default_vision_model = None
     default_image_model = default_model
     audio_models = {}
     image_models = [default_image_model] # Default models
     _models_loaded = False # Add a checkbox for synchronization

     @classmethod
     def get_models(cls, **kwargs):
-        if not cls._models_loaded:
-            # Calling the parent method to load models
-            super().get_models()
-            # Combine models from the parent class and additional ones
-            all_image_models = list(dict.fromkeys(
-                cls.image_models +
-                PollinationsAI.image_models +
-                cls.extra_image_models
-            ))
-            cls.image_models = all_image_models
-            cls._models_loaded = True
-        return cls.image_models
+        PollinationsAI.get_models()
+        cls.image_models = PollinationsAI.image_models
+        cls.models = cls.image_models
+        return cls.models
+
+    @classmethod
+    def get_grouped_models(cls) -> dict[str, list[str]]:
+        PollinationsAI.get_models()
+        return [
+            {"group": "Image Generation", "models": PollinationsAI.image_models},
+        ]

     @classmethod
     async def create_async_generator(
```
```diff
@@ -1,34 +1,27 @@
 from __future__ import annotations

-try:
-    has_openaifm = True
-except ImportError:
-    has_openaifm = False
-
 from aiohttp import ClientSession
-from urllib.parse import urlencode
-import json

 from ...typing import AsyncResult, Messages
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ..helper import get_last_message
+from ..helper import get_last_user_message, get_system_prompt
 from ...image.copy_images import save_response_media
+from ...requests.raise_for_status import raise_for_status
+from ...requests.aiohttp import get_connector
+from ...requests import DEFAULT_HEADERS

 class OpenAIFM(AsyncGeneratorProvider, ProviderModelMixin):
     label = "OpenAI.fm"
     url = "https://www.openai.fm"
     api_endpoint = "https://www.openai.fm/api/generate"

-    working = has_openaifm
+    working = True

     default_model = 'gpt-4o-mini-tts'
     default_audio_model = default_model
     default_voice = 'coral'
     voices = ['alloy', 'ash', 'ballad', default_voice, 'echo', 'fable', 'onyx', 'nova', 'sage', 'shimmer', 'verse']
     audio_models = {default_audio_model: voices}
-    models = [default_audio_model]
+    models = voices

     friendly = """Affect/personality: A cheerful guide
```
```diff
@@ -106,44 +99,24 @@ Emotion: Restrained enthusiasm for discoveries and findings, conveying intellect
         audio: dict = {},
         **kwargs
     ) -> AsyncResult:
-        # Retrieve parameters from the audio dictionary
         voice = audio.get("voice", kwargs.get("voice", cls.default_voice))
-        instructions = audio.get("instructions", kwargs.get("instructions", cls.friendly))
-
+        instructions = audio.get("instructions", kwargs.get("instructions", get_system_prompt(messages) or cls.friendly))
         headers = {
-            "accept": "*/*",
-            "accept-language": "en-US,en;q=0.9",
-            "cache-control": "no-cache",
-            "pragma": "no-cache",
-            "sec-fetch-dest": "audio",
-            "sec-fetch-mode": "no-cors",
-            "sec-fetch-site": "same-origin",
-            "user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36",
-            "referer": cls.url
+            **DEFAULT_HEADERS,
+            "referer": f"{cls.url}/"
         }

-        # Using prompts or formatting messages
-        text = get_last_message(messages, prompt)
-
+        text = get_last_user_message(messages, prompt)
         params = {
             "input": text,
             "prompt": instructions,
             "voice": voice
         }

-        async with ClientSession(headers=headers) as session:
-
-            # Print the full URL with parameters
-            full_url = f"{cls.api_endpoint}?{urlencode(params)}"
-
+        async with ClientSession(headers=headers, connector=get_connector(proxy=proxy)) as session:
             async with session.get(
                 cls.api_endpoint,
-                params=params,
-                proxy=proxy
+                params=params
             ) as response:
-                response.raise_for_status()
+                await raise_for_status(response)
                 async for chunk in save_response_media(response, text, [model, voice]):
                     yield chunk
```
```diff
@@ -7,7 +7,7 @@ import uuid
 from ...typing import AsyncResult, Messages
 from ...providers.response import Reasoning, JsonConversation
 from ...requests.raise_for_status import raise_for_status
-from ...errors import ModelNotSupportedError
+from ...errors import ModelNotFoundError
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..helper import get_last_user_message
 from ... import debug
```
```diff
@@ -55,7 +55,7 @@ class Qwen_Qwen_3(AsyncGeneratorProvider, ProviderModelMixin):
     ) -> AsyncResult:
         try:
             model = cls.get_model(model)
-        except ModelNotSupportedError:
+        except ModelNotFoundError:
             pass
         if conversation is None:
             conversation = JsonConversation(session_hash=str(uuid.uuid4()).replace('-', ''))
```
```diff
@@ -23,7 +23,7 @@ from ...requests.raise_for_status import raise_for_status
 from ...requests import StreamSession
 from ...requests import get_nodriver
 from ...image import ImageRequest, to_image, to_bytes, is_accepted_format
-from ...errors import MissingAuthError, NoValidHarFileError, ModelNotSupportedError
+from ...errors import MissingAuthError, NoValidHarFileError, ModelNotFoundError
 from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult, ImageResponse, ImagePreview
 from ...providers.response import Sources, TitleGeneration, RequestLogin, Reasoning
 from ...tools.media import merge_media
```
```diff
@@ -358,7 +358,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
             debug.error(e)
         try:
             model = cls.get_model(model)
-        except ModelNotSupportedError:
+        except ModelNotFoundError:
             pass
         if conversation is None:
             conversation = Conversation(None, str(uuid.uuid4()), getattr(auth_result, "cookies", {}).get("oai-did"))
```
```diff
@@ -5,7 +5,7 @@ import requests
 from ....providers.types import Messages
 from ....typing import MediaListType
 from ....requests import StreamSession, raise_for_status
-from ....errors import ModelNotSupportedError, PaymentRequiredError
+from ....errors import ModelNotFoundError, PaymentRequiredError
 from ....providers.response import ProviderInfo
 from ...template.OpenaiTemplate import OpenaiTemplate
 from .models import model_aliases, vision_models, default_llama_model, default_vision_model, text_models
```
```diff
@@ -34,7 +34,7 @@ class HuggingFaceAPI(OpenaiTemplate):
     def get_model(cls, model: str, **kwargs) -> str:
         try:
             return super().get_model(model, **kwargs)
-        except ModelNotSupportedError:
+        except ModelNotFoundError:
             return model

     @classmethod
```
```diff
@@ -87,14 +87,14 @@ class HuggingFaceAPI(OpenaiTemplate):
         model = cls.get_model(model)
         provider_mapping = await cls.get_mapping(model, api_key)
         if not provider_mapping:
-            raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
+            raise ModelNotFoundError(f"Model is not supported: {model} in: {cls.__name__}")
         error = None
         for provider_key in provider_mapping:
             api_path = provider_key if provider_key == "novita" else f"{provider_key}/v1"
             api_base = f"https://router.huggingface.co/{api_path}"
             task = provider_mapping[provider_key]["task"]
             if task != "conversational":
-                raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} task: {task}")
+                raise ModelNotFoundError(f"Model is not supported: {model} in: {cls.__name__} task: {task}")
             model = provider_mapping[provider_key]["providerId"]
             yield ProviderInfo(**{**cls.get_dict(), "label": f"HuggingFace ({provider_key})"})
             # start = calculate_lenght(messages)
```
```diff
@@ -7,7 +7,7 @@ import requests

 from ....typing import AsyncResult, Messages
 from ...base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_prompt
-from ....errors import ModelNotSupportedError, ResponseError
+from ....errors import ModelNotFoundError, ResponseError
 from ....requests import StreamSession, raise_for_status
 from ....providers.response import FinishReason, ImageResponse
 from ....image.copy_images import save_response_media
```
```diff
@@ -58,7 +58,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
             return cls.model_data[model]
         async with session.get(f"https://huggingface.co/api/models/{model}") as response:
             if response.status == 404:
-                raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
+                raise ModelNotFoundError(f"Model not found: {model} in: {cls.__name__}")
             await raise_for_status(response)
             cls.model_data[model] = await response.json()
             return cls.model_data[model]
```
```diff
@@ -77,7 +77,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
         temperature: float = None,
         prompt: str = None,
         action: str = None,
-        extra_body: dict = {},
+        extra_body: dict = None,
         seed: int = None,
         aspect_ratio: str = None,
         width: int = None,
```
```diff
@@ -86,7 +86,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
     ) -> AsyncResult:
         try:
             model = cls.get_model(model)
-        except ModelNotSupportedError:
+        except ModelNotFoundError:
             pass
         headers = {
             'Accept-Encoding': 'gzip, deflate',
```
```diff
@@ -94,6 +94,8 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
         }
         if api_key is not None:
             headers["Authorization"] = f"Bearer {api_key}"
+        if extra_body is None:
+            extra_body = {}
         image_extra_body = use_aspect_ratio({
             "width": width,
             "height": height,
```
```diff
@@ -114,12 +116,12 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
                 }
                 async with session.post(provider_together_urls[model], json=data) as response:
                     if response.status == 404:
-                        raise ModelNotSupportedError(f"Model is not supported: {model}")
+                        raise ModelNotFoundError(f"Model not found: {model}")
                     await raise_for_status(response)
                     result = await response.json()
                     yield ImageResponse([item["url"] for item in result["data"]], data["prompt"])
                     return
-            except ModelNotSupportedError:
+            except ModelNotFoundError:
                 pass
         payload = None
         params = {
```
```diff
@@ -156,11 +158,11 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
                 params["seed"] = seed
             payload = {"inputs": inputs, "parameters": params, "stream": stream}
         else:
-            raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}")
+            raise ModelNotFoundError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}")

         async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response:
             if response.status == 404:
-                raise ModelNotSupportedError(f"Model is not supported: {model}")
+                raise ModelNotFoundError(f"Model not found: {model}")
             await raise_for_status(response)
             if stream:
                 first = True
```
```diff
@@ -7,7 +7,7 @@ import requests

 from ....providers.types import Messages
 from ....requests import StreamSession, raise_for_status
-from ....errors import ModelNotSupportedError
+from ....errors import ModelNotFoundError
 from ....providers.helper import format_image_prompt
 from ....providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ....providers.response import ProviderInfo, ImageResponse, VideoResponse, Reasoning
```
```diff
@@ -98,7 +98,7 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
         model: str,
         messages: Messages,
         api_key: str = None,
-        extra_body: dict = {},
+        extra_body: dict = None,
         prompt: str = None,
         proxy: str = None,
         timeout: int = 0,
```
```diff
@@ -112,6 +112,8 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
         resolution: str = "480p",
         **kwargs
     ):
+        if extra_body is None:
+            extra_body = {}
         selected_provider = None
         if model and ":" in model:
             model, selected_provider = model.split(":", 1)
```
```diff
@@ -130,7 +132,7 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
             }
             provider_mapping = {**new_mapping, **provider_mapping}
         if not provider_mapping:
-            raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
+            raise ModelNotFoundError(f"Model is not supported: {model} in: {cls.__name__}")
         async def generate(extra_body: dict, aspect_ratio: str = None):
             last_response = None
             for provider_key, provider in provider_mapping.items():
```
```diff
@@ -142,7 +144,7 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
                 task = provider["task"]
                 provider_id = provider["providerId"]
                 if task not in cls.tasks:
-                    raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} task: {task}")
+                    raise ModelNotFoundError(f"Model is not supported: {model} in: {cls.__name__} task: {task}")

                 if aspect_ratio is None:
                     aspect_ratio = "1:1" if task == "text-to-image" else "16:9"
```
```diff
@@ -209,7 +211,7 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
                         debug.error(f"{cls.__name__}: Error {response.status} with {provider_key} and {provider_id}")
                         continue
                     if response.status == 404:
-                        raise ModelNotSupportedError(f"Model is not supported: {model}")
+                        raise ModelNotFoundError(f"Model not found: {model}")
                     await raise_for_status(response)
                     if response.headers.get("Content-Type", "").startswith("application/json"):
                         result = await response.json()
```
```diff
@@ -4,7 +4,7 @@ import random

 from ....typing import AsyncResult, Messages
 from ....providers.response import ImageResponse
-from ....errors import ModelNotSupportedError, MissingAuthError
+from ....errors import ModelNotFoundError, MissingAuthError
 from ...base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from .HuggingChat import HuggingChat
 from .HuggingFaceAPI import HuggingFaceAPI
```
```diff
@@ -58,7 +58,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
                 async for chunk in HuggingFaceMedia.create_async_generator(model, messages, **kwargs):
                     yield chunk
                 return
-            except ModelNotSupportedError:
+            except ModelNotFoundError:
                 pass
         if model in cls.image_models:
             if "api_key" not in kwargs:
```
```diff
@@ -71,6 +71,6 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
         try:
             async for chunk in HuggingFaceAPI.create_async_generator(model, messages, **kwargs):
                 yield chunk
-        except (ModelNotSupportedError, MissingAuthError):
+        except (ModelNotFoundError, MissingAuthError):
             async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):
                 yield chunk
```
```diff
@@ -66,7 +66,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
         headers: dict = None,
         impersonate: str = None,
         extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias", "modalities", "audio"],
-        extra_body: dict = {},
+        extra_body: dict = None,
         **kwargs
     ) -> AsyncResult:
         if api_key is None and cls.api_key is not None:
```
```diff
@@ -98,6 +98,8 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
             return

         extra_parameters = {key: kwargs[key] for key in extra_parameters if key in kwargs}
+        if extra_body is None:
+            extra_body = {}
         data = filter_none(
             messages=list(render_messages(messages, media)),
             model=model,
```
```diff
@@ -1,5 +1,6 @@
 PACKAGE_NAME = "g4f"
-GITHUB_REPOSITORY = "xtekky/gpt4free"
-STATIC_DOMAIN = "gpt4free.github.io"
+ORGANIZATION = "gpt4free"
+GITHUB_REPOSITORY = f"xtekky/{ORGANIZATION}"
+STATIC_DOMAIN = f"{ORGANIZATION}.github.io"
 STATIC_URL = f"https://{STATIC_DOMAIN}/"
 DIST_DIR = f"./{STATIC_DOMAIN}/dist"
```
```diff
@@ -22,9 +22,6 @@ class RetryNoProviderError(Exception):
 class VersionNotFoundError(Exception):
     ...

-class ModelNotSupportedError(Exception):
-    ...
-
 class MissingRequirementsError(Exception):
     ...

```
```diff
@@ -257,7 +257,7 @@ class Api:
         }

     def handle_provider(self, provider_handler, model):
-        if model:
+        if not getattr(provider_handler, "model", False):
             return self._format_json("provider", {**provider_handler.get_dict(), "model": model})
         return self._format_json("provider", provider_handler.get_dict())
```
```diff
@@ -25,7 +25,7 @@ EXTENSIONS_MAP: dict[str, str] = {
     "gif": "image/gif",
     "webp": "image/webp",
     # Audio
-    "wav": "audio/x-wav",
+    "wav": "audio/wav",
     "mp3": "audio/mpeg",
     "flac": "audio/flac",
     "opus": "audio/opus",
```
```diff
@@ -4,7 +4,7 @@ import os
 import time
 import asyncio
 import hashlib
-import re
+import base64
 from typing import AsyncIterator
 from urllib.parse import quote, unquote
 from aiohttp import ClientSession, ClientError
```
```diff
@@ -54,24 +54,36 @@ def get_source_url(image: str, default: str = None) -> str:
             return decoded_url
     return default

-async def save_response_media(response: StreamResponse, prompt: str, tags: list[str]) -> AsyncIterator:
+async def save_response_media(response, prompt: str, tags: list[str]) -> AsyncIterator:
     """Save media from response to local file and return URL"""
-    content_type = response.headers["content-type"]
+    if isinstance(response, str):
+        response = base64.b64decode(response)
+    content_type = response.headers["content-type"] if hasattr(response, "headers") else "audio/mpeg"
     extension = MEDIA_TYPE_MAP.get(content_type)
     if extension is None:
         raise ValueError(f"Unsupported media type: {content_type}")

     filename = get_filename(tags, prompt, f".{extension}", prompt)
     target_path = os.path.join(get_media_dir(), filename)
     ensure_media_dir()
     with open(target_path, 'wb') as f:
-        async for chunk in response.iter_content() if hasattr(response, "iter_content") else response.content.iter_any():
-            f.write(chunk)
+        if isinstance(response, bytes):
+            f.write(response)
+        else:
+            if hasattr(response, "iter_content"):
+                iter_response = response.iter_content()
+            else:
+                iter_response = response.content.iter_any()
+            async for chunk in iter_response:
+                f.write(chunk)

     # Base URL without request parameters
     media_url = f"/media/{filename}"

     # Save the original URL in the metadata, but not in the file path itself
-    source_url = str(response.url) if response.method == "GET" else None
+    source_url = None
+    if hasattr(response, "url") and response.method == "GET":
+        source_url = str(response.url)

     if content_type.startswith("audio/"):
         yield AudioResponse(media_url, text=prompt, source_url=source_url)
```
```diff
@@ -20,7 +20,7 @@ from .asyncio import get_running_loop, to_sync_generator, to_async_iterator
 from .response import BaseConversation, AuthResult
 from .helper import concat_chunks
 from ..cookies import get_cookies_dir
-from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError, NoValidHarFileError, PaymentRequiredError
+from ..errors import ModelNotFoundError, ResponseError, MissingAuthError, NoValidHarFileError, PaymentRequiredError
 from .. import debug

 SAFE_PARAMETERS = [
```
```diff
@@ -363,7 +363,7 @@ class ProviderModelMixin:
             model = cls.model_aliases[model]
         else:
             if model not in cls.get_models(**kwargs) and cls.models:
-                raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} Valid models: {cls.models}")
+                raise ModelNotFoundError(f"Model is not supported: {model} in: {cls.__name__} Valid models: {cls.models}")
         cls.last_model = model
         debug.last_model = model
         return model
```
```diff
@@ -28,9 +28,12 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
     content_type = response.headers.get("content-type", "")
     if content_type.startswith("application/json"):
         message = await response.json()
-        message = message.get("error", message)
-        message = message.get("message", message)
+        if isinstance(message, dict):
+            error = message.get("error")
+            if isinstance(error, dict):
+                message = error.get("message")
+            message = message.get("message", message)
+            if isinstance(error, str):
+                message = f"{error}: {message}"
     else:
         message = (await response.text()).strip()
     is_html = content_type.startswith("text/html") or message.startswith("<!DOCTYPE")
```
setup.py (19 changed lines):
```diff
@@ -1,18 +1,13 @@
-import codecs
 import os

 from setuptools import find_packages, setup

-here = os.path.abspath(os.path.dirname(__file__))
-
-with codecs.open(os.path.join(here, 'README.md'), encoding='utf-8') as fh:
-    long_description = '\n' + fh.read()
+STATIC_HOST = "gpt4free.github.io"
+current_dir = os.path.abspath(os.path.dirname(__file__))
+
+with open(os.path.join(current_dir, 'README.md')) as f:
+    long_description = f.read()

 long_description = long_description.replace("[!NOTE]", "")
 long_description = long_description.replace("(docs/images/", f"(https://{STATIC_HOST}/docs/images/")
 long_description = long_description.replace("(docs/", f"(https://github.com/gpt4free/{STATIC_HOST}/blob/main/docs/")

 INSTALL_REQUIRE = [
     "requests",
```
```diff
@@ -41,9 +36,6 @@ EXTRA_REQUIRE = {
         "pywebview",
         "plyer",
         "setuptools",
-        "odfpy", # files
-        "ebooklib",
-        "openpyxl",
         "markitdown[all]"
     ],
     'slim': [
```
```diff
@@ -58,7 +50,7 @@ EXTRA_REQUIRE = {
         "fastapi", # api
         "uvicorn", # api
         "python-multipart",
-        "markitdown[pdf, docx, pptx]"
+        "markitdown[all]"
     ],
     "image": [
         "pillow",
```
```diff
@@ -91,9 +83,6 @@ EXTRA_REQUIRE = {
     ],
     "files": [
         "beautifulsoup4",
-        "odfpy",
-        "ebooklib",
-        "openpyxl",
         "markitdown[all]"
     ]
 }
```