feat: refactor provider create functions to class attributes and update calls

- Added `create_function` and `async_create_function` class attributes with default implementations in `base_provider.py` for `AbstractProvider`, `AsyncProvider`, and `AsyncGeneratorProvider`
- Updated `get_create_function` and `get_async_create_function` methods to return these class attributes
- Replaced calls to `provider.get_create_function()` and `provider.get_async_create_function()` with direct attribute access `provider.create_function` and `provider.async_create_function` across `g4f/__init__.py`, `g4f/client/__init__.py`, `g4f/providers/retry_provider.py`, and `g4f/tools/run_tools.py`
- Removed redundant `get_create_function` and `get_async_create_function` methods from `providers/base_provider.py` and `providers/types.py`
- Ensured all provider response calls now use the class attributes for creating completions asynchronously and synchronously as needed
This commit is contained in:
hlohaus 2025-06-12 12:45:55 +02:00
parent 5734a06193
commit 2befef988b
15 changed files with 142 additions and 111 deletions

View file

@ -10,10 +10,10 @@ from ..requests import StreamSession, raise_for_status
from ..errors import ModelNotFoundError
from .. import debug
class Together(OpenaiTemplate):
label = "Together"
url = "https://together.xyz"
login_url = "https://api.together.ai/"
api_base = "https://api.together.xyz/v1"
activation_endpoint = "https://www.codegeneration.ai/activate-v2"
models_endpoint = "https://api.together.xyz/v1/models"

View file

@ -682,7 +682,12 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
page.add_handler(nodriver.cdp.network.RequestWillBeSent, on_request)
page = await browser.get(cls.url)
user_agent = await page.evaluate("window.navigator.userAgent", return_by_value=True)
while not await page.evaluate("document.getElementById('prompt-textarea')?.id"):
textarea = None
while not textarea:
try:
textarea = await page.evaluate("document.getElementById('prompt-textarea')?.id")
except:
pass
await asyncio.sleep(1)
while not await page.evaluate("document.querySelector('[data-testid=\"send-button\"]')?.type"):
await asyncio.sleep(1)

View file

@ -15,6 +15,7 @@ from .. import debug
class PuterJS(AsyncGeneratorProvider, ProviderModelMixin):
label = "Puter.js"
parent = "Puter"
url = "https://docs.puter.com/playground"
login_url = "https://github.com/HeyPuter/puter-cli"
api_endpoint = "https://api.puter.com/drivers/call"

View file

@ -6,7 +6,7 @@ from ..helper import filter_none, format_media_prompt
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
from ...typing import Union, AsyncResult, Messages, MediaListType
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse
from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse, ProviderInfo
from ...tools.media import render_messages
from ...errors import MissingAuthError, ResponseError
from ... import debug
@ -93,6 +93,9 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
async with session.post(f"{api_base.rstrip('/')}/images/generations", json=data, ssl=cls.ssl) as response:
data = await response.json()
cls.raise_error(data, response.status)
model = data.get("model")
if model:
yield ProviderInfo(**cls.get_dict(), model=model)
await raise_for_status(response)
yield ImageResponse([image["url"] for image in data["data"]], prompt)
return
@ -121,6 +124,9 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
data = await response.json()
cls.raise_error(data, response.status)
await raise_for_status(response)
model = data.get("model")
if model:
yield ProviderInfo(**cls.get_dict(), model=model)
choice = data["choices"][0]
if "content" in choice["message"] and choice["message"]["content"]:
yield choice["message"]["content"].strip()
@ -134,8 +140,13 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
elif content_type.startswith("text/event-stream"):
await raise_for_status(response)
first = True
model_returned = False
async for data in response.sse():
cls.raise_error(data)
model = data.get("model")
if not model_returned and model:
yield ProviderInfo(**cls.get_dict(), model=model)
model_returned = True
choice = data["choices"][0]
if "content" in choice["delta"] and choice["delta"]["content"]:
delta = choice["delta"]["content"]

View file

@ -50,7 +50,7 @@ class ChatCompletion:
if ignore_stream:
kwargs["ignore_stream"] = True
result = provider.get_create_function()(model, messages, stream=stream, **kwargs)
result = provider.create_function(model, messages, stream=stream, **kwargs)
return result if stream or ignore_stream else concat_chunks(result)
@ -76,7 +76,7 @@ class ChatCompletion:
if ignore_stream:
kwargs["ignore_stream"] = True
result = provider.get_async_create_function()(model, messages, stream=stream, **kwargs)
result = provider.async_create_function(model, messages, stream=stream, **kwargs)
if not stream and not ignore_stream:
if hasattr(result, "__aiter__"):

View file

@ -490,7 +490,7 @@ class Api:
config.provider = provider
if config.provider is None:
config.provider = AppConfig.media_provider
if credentials is not None and credentials.credentials != "secret":
if config.api_key is None and credentials is not None and credentials.credentials != "secret":
config.api_key = credentials.credentials
try:
response = await self.client.images.generate(

View file

@ -16,7 +16,7 @@ class RequestConfig(BaseModel):
top_p: Optional[float] = None
max_tokens: Optional[int] = None
stop: Union[list[str], str, None] = None
api_key: Optional[str] = None
api_key: Optional[Union[str, dict[str, str]]] = None
api_base: str = None
web_search: Optional[bool] = None
proxy: Optional[str] = None
@ -70,6 +70,8 @@ class ImageGenerationConfig(BaseModel):
negative_prompt: Optional[str] = None
resolution: Optional[str] = None
audio: Optional[dict] = None
download_media: bool = True
@model_validator(mode='before')
def parse_size(cls, values):

View file

@ -375,7 +375,7 @@ class Completions:
kwargs["ignore_stream"] = True
response = iter_run_tools(
provider.get_create_function(),
provider.create_function,
model=model,
messages=messages,
stream=stream,
@ -462,7 +462,7 @@ class Images:
if isinstance(provider_handler, IterListProvider):
for provider in provider_handler.providers:
try:
response = await self._generate_image_response(provider, provider.__name__, model, prompt, proxy=proxy, **kwargs)
response = await self._generate_image_response(provider, provider.__name__, model, prompt, proxy=proxy, api_key=api_key, **kwargs)
if response is not None:
provider_name = provider.__name__
break
@ -485,21 +485,25 @@ class Images:
async def _generate_image_response(
self,
provider_handler,
provider_name,
provider_handler: ProviderType,
provider_name: str,
model: str,
prompt: str,
prompt_prefix: str = "Generate a image: ",
api_key: str = None,
**kwargs
) -> MediaResponse:
messages = [{"role": "user", "content": f"{prompt_prefix}{prompt}"}]
items: list[MediaResponse] = []
if isinstance(api_key, dict):
api_key = api_key.get(provider_handler.get_parent())
if hasattr(provider_handler, "create_async_generator"):
async for item in provider_handler.create_async_generator(
model,
messages,
stream=True,
prompt=prompt,
api_key=api_key,
**kwargs
):
if isinstance(item, (MediaResponse, AudioResponse)):
@ -510,6 +514,7 @@ class Images:
messages,
True,
prompt=prompt,
api_key=api_key,
**kwargs
):
if isinstance(item, (MediaResponse, AudioResponse)):

View file

@ -157,7 +157,7 @@ async def copy_media(
if media_type not in ("application/octet-stream", "binary/octet-stream"):
if media_type not in MEDIA_TYPE_MAP:
raise ValueError(f"Unsupported media type: {media_type}")
if not media_extension:
if target is None and not media_extension:
media_extension = f".{MEDIA_TYPE_MAP[media_type]}"
target_path = f"{target_path}{media_extension}"
with open(target_path, "wb") as f:

View file

@ -1,22 +1,32 @@
from __future__ import annotations
import re
from typing import Dict, List, Set, Optional, Tuple, Any
from ..typing import AsyncResult, Messages, MediaListType
from ..typing import AsyncResult, Messages, MediaListType, Union
from ..errors import ModelNotFoundError
from ..image import is_data_an_audio
from ..providers.retry_provider import IterListProvider
from ..providers.types import ProviderType
from ..Provider.needs_auth import OpenaiChat, CopilotAccount
from ..Provider.hf_space import HuggingSpace
from ..Provider import Cloudflare, Gemini, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI
from ..Provider import __map__
from ..Provider import Cloudflare, Gemini, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI, PuterJS
from ..Provider import Microsoft_Phi_4_Multimodal, DeepInfraChat, Blackbox, OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena
from ..Provider import EdgeTTS, gTTS, MarkItDown
from ..Provider import EdgeTTS, gTTS, MarkItDown, OpenAIFM
from ..Provider import HarProvider, HuggingFace, HuggingFaceMedia
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .. import Provider
from .. import models
MAIN_PROVIERS = [
OpenaiChat, Cloudflare, HarProvider, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox,
OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena,
HuggingSpace, LambdaChat, CopilotAccount, PollinationsAI, DeepInfraChat, HuggingFace, HuggingFaceMedia
]
SPECIAL_PROVIDERS = [OpenaiChat, CopilotAccount, PollinationsAI, HuggingSpace, Cloudflare, PerplexityLabs, Gemini, Grok, LegacyLMArena, ARTA]
SPECIAL_PROVIDERS2 = [HarProvider, LambdaChat, DeepInfraChat, HuggingFace, HuggingFaceMedia, PuterJS]
LABELS = {
"default": "Default",
"openai": "OpenAI: ChatGPT",
@ -31,6 +41,7 @@ LABELS = {
"mistral": "Mistral",
"PollinationsAI": "Pollinations AI",
"perplexity": "Perplexity Labs",
"openrouter": "OpenRouter",
"video": "Video Generation",
"image": "Image Generation",
"other": "Other Models",
@ -109,6 +120,10 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
elif model.startswith(("gpt-", "chatgpt-", "o1", "o1-", "o3-", "o4-")) or model in ("auto", "dall-e-3", "searchgpt"):
groups["openai"].append(model)
added = True
# Check for openrouter models
elif model.startswith(("openrouter:")):
groups["openrouter"].append(model)
added = True
# Check for video models
elif model in cls.video_models:
groups["video"].append(model)
@ -159,7 +174,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
all_models = [cls.default_model] + list(model_with_providers.keys())
# Process special providers
for provider in [OpenaiChat, CopilotAccount, PollinationsAI, HuggingSpace, Cloudflare, PerplexityLabs, Gemini, Grok, LegacyLMArena, ARTA]:
for provider in SPECIAL_PROVIDERS:
provider: ProviderType = provider
if not provider.working or provider.get_parent() in ignored:
continue
@ -214,7 +229,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
return name
# Process HAR providers
for provider in [HarProvider, LambdaChat, DeepInfraChat, HuggingFace, HuggingFaceMedia]:
for provider in SPECIAL_PROVIDERS2:
if not provider.working or provider.get_parent() in ignored:
continue
new_models = provider.get_models()
@ -224,7 +239,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
# Add original models too, not just cleaned names
all_models.extend(new_models)
model_map = {clean_name(model): model for model in new_models}
model_map = {model if model.startswith("openrouter:") else clean_name(model): model for model in new_models}
if not provider.model_aliases:
provider.model_aliases = {}
provider.model_aliases.update(model_map)
@ -262,6 +277,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
stream: bool = True,
media: MediaListType = None,
ignored: list[str] = [],
api_key: Union[str, dict[str, str]] = None,
**kwargs
) -> AsyncResult:
cls.get_models(ignored=ignored)
@ -284,7 +300,10 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
if "tools" in kwargs:
providers = [PollinationsAI]
elif "audio" in kwargs or "audio" in kwargs.get("modalities", []):
providers = [PollinationsAI, EdgeTTS, gTTS]
if kwargs.get("audio", {}).get("language") is None:
providers = [PollinationsAI, OpenAIFM, Gemini]
else:
providers = [PollinationsAI, OpenAIFM, EdgeTTS, gTTS]
elif has_audio:
providers = [PollinationsAI, Microsoft_Phi_4_Multimodal, MarkItDown]
elif has_image:
@ -297,18 +316,18 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
model = None
providers.append(provider)
else:
for provider in [
OpenaiChat, Cloudflare, HarProvider, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox,
OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena,
HuggingSpace, LambdaChat, CopilotAccount, PollinationsAI, DeepInfraChat, HuggingFace, HuggingFaceMedia,
]:
extra_providers = []
if isinstance(api_key, dict):
for provider in api_key:
if provider in __map__ and __map__[provider] not in MAIN_PROVIERS:
extra_providers.append(__map__[provider])
for provider in MAIN_PROVIERS + extra_providers:
if provider.working:
if not model or model in provider.get_models() or model in provider.model_aliases:
providers.append(provider)
if model in models.__models__:
for provider in models.__models__[model][1]:
providers.append(provider)
providers = [provider for provider in providers if provider.working and provider.get_parent() not in ignored]
providers = list({provider.__name__: provider for provider in providers}.values())
@ -320,6 +339,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
messages,
stream=stream,
media=media,
api_key=api_key,
**kwargs
):
yield chunk

View file

@ -126,12 +126,30 @@ class AbstractProvider(BaseProvider):
)
@classmethod
def get_create_function(cls) -> callable:
return cls.create_completion
def create_function(cls, *args, **kwargs) -> CreateResult:
"""
Creates a completion using the synchronous method.
Args:
**kwargs: Additional keyword arguments.
Returns:
CreateResult: The result of the completion creation.
"""
return cls.create_completion(*args, **kwargs)
@classmethod
def get_async_create_function(cls) -> callable:
return cls.create_async
def async_create_function(cls, *args, **kwargs) -> AsyncResult:
"""
Creates a completion using the asynchronous method.
Args:
**kwargs: Additional keyword arguments.
Returns:
AsyncResult: The awaitable result of the completion creation.
"""
return cls.create_async(*args, **kwargs)
@classmethod
def get_parameters(cls, as_json: bool = False) -> dict[str, Parameter]:
@ -264,14 +282,6 @@ class AsyncProvider(AbstractProvider):
"""
raise NotImplementedError()
@classmethod
def get_create_function(cls) -> callable:
return cls.create_completion
@classmethod
def get_async_create_function(cls) -> callable:
return cls.create_async
class AsyncGeneratorProvider(AbstractProvider):
"""
Provides asynchronous generator functionality for streaming results.
@ -331,12 +341,17 @@ class AsyncGeneratorProvider(AbstractProvider):
raise NotImplementedError()
@classmethod
def get_create_function(cls) -> callable:
return cls.create_completion
def async_create_function(cls, *args, **kwargs) -> AsyncResult:
"""
Creates a completion using the asynchronous generator method.
@classmethod
def get_async_create_function(cls) -> callable:
return cls.create_async_generator
Args:
**kwargs: Additional keyword arguments.
Returns:
AsyncResult: The async generator yielding the completion chunks.
"""
return cls.create_async_generator(*args, **kwargs)
class ProviderModelMixin:
default_model: str = None
@ -417,14 +432,6 @@ class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin):
return to_sync_generator(auth_result)
return asyncio.run(auth_result)
@classmethod
def get_create_function(cls) -> callable:
return cls.create_completion
@classmethod
def get_async_create_function(cls) -> callable:
return cls.create_async_generator
@classmethod
def write_cache_file(cls, cache_file: Path, auth_result: AuthResult = None):
if auth_result is not None:

View file

@ -26,7 +26,6 @@ class IterListProvider(BaseRetryProvider):
self.shuffle = shuffle
self.working = True
self.last_provider: Type[BaseProvider] = None
self.add_api_key = False
def create_completion(
self,
@ -35,6 +34,7 @@ class IterListProvider(BaseRetryProvider):
stream: bool = False,
ignore_stream: bool = False,
ignored: list[str] = [],
api_key: str = None,
**kwargs,
) -> CreateResult:
"""
@ -55,8 +55,11 @@ class IterListProvider(BaseRetryProvider):
self.last_provider = provider
debug.log(f"Using {provider.__name__} provider")
yield ProviderInfo(**provider.get_dict(), model=model if model else getattr(provider, "default_model"))
extra_body = kwargs.copy()
if isinstance(api_key, dict):
extra_body["api_key"] = api_key.get(provider.get_parent())
try:
response = provider.get_create_function()(model, messages, stream=stream, **kwargs)
response = provider.create_function(model, messages, stream=stream, **extra_body)
for chunk in response:
if chunk:
yield chunk
@ -66,7 +69,7 @@ class IterListProvider(BaseRetryProvider):
return
except Exception as e:
exceptions[provider.__name__] = e
debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
debug.error(f"{provider.__name__}:", e)
if started:
raise e
yield e
@ -92,12 +95,12 @@ class IterListProvider(BaseRetryProvider):
debug.log(f"Using {provider.__name__} provider" + (f" and {model} model" if model else ""))
yield ProviderInfo(**provider.get_dict(), model=model if model else getattr(provider, "default_model"))
extra_body = kwargs.copy()
if self.add_api_key or provider.__name__ in ["HuggingFace", "HuggingFaceMedia"]:
extra_body["api_key"] = api_key
if isinstance(api_key, dict):
extra_body["api_key"] = api_key.get(provider.get_parent())
if conversation is not None and hasattr(conversation, provider.__name__):
extra_body["conversation"] = JsonConversation(**getattr(conversation, provider.__name__))
try:
response = provider.get_async_create_function()(model, messages, stream=stream, **extra_body)
response = provider.async_create_function(model, messages, stream=stream, **extra_body)
if hasattr(response, "__aiter__"):
async for chunk in response:
if isinstance(chunk, JsonConversation):
@ -118,18 +121,15 @@ class IterListProvider(BaseRetryProvider):
return
except Exception as e:
exceptions[provider.__name__] = e
debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
debug.error(f"{provider.__name__}:", e)
if started:
raise e
yield e
raise_exceptions(exceptions)
def get_create_function(self) -> callable:
return self.create_completion
def get_async_create_function(self) -> callable:
return self.create_async_generator
create_function = create_completion
async_create_function = create_async_generator
def get_providers(self, stream: bool, ignored: list[str]) -> list[ProviderType]:
providers = [p for p in self.providers if (p.supports_stream or not stream) and p.__name__ not in ignored]
@ -156,7 +156,6 @@ class RetryProvider(IterListProvider):
super().__init__(providers, shuffle)
self.single_provider_retry = single_provider_retry
self.max_retries = max_retries
self.add_api_key = True
def create_completion(
self,
@ -185,7 +184,7 @@ class RetryProvider(IterListProvider):
try:
if debug.logging:
print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
response = provider.get_create_function()(model, messages, stream=stream, **kwargs)
response = provider.create_function(model, messages, stream=stream, **kwargs)
for chunk in response:
yield chunk
if is_content(chunk):
@ -218,7 +217,7 @@ class RetryProvider(IterListProvider):
for attempt in range(self.max_retries):
try:
debug.log(f"Using {provider.__name__} provider (attempt {attempt + 1})")
response = provider.get_async_create_function()(model, messages, stream=stream, **kwargs)
response = provider.async_create_function(model, messages, stream=stream, **kwargs)
if hasattr(response, "__aiter__"):
async for chunk in response:
yield chunk

View file

@ -45,7 +45,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
finish = None
chunks = []
has_usage = False
async for chunk in provider.get_async_create_function()(
async for chunk in provider.async_create_function(
model,
messages,
stream=stream,

View file

@ -25,26 +25,8 @@ class BaseProvider(ABC):
supports_message_history: bool = False
supports_system_message: bool = False
params: str
@abstractmethod
def get_create_function() -> callable:
"""
Get the create function for the provider.
Returns:
callable: The create function.
"""
raise NotImplementedError()
@abstractmethod
def get_async_create_function() -> callable:
"""
Get the async create function for the provider.
Returns:
callable: The create function.
"""
raise NotImplementedError()
create_function: callable
async_create_function: callable
@classmethod
def get_dict(cls) -> Dict[str, str]:

View file

@ -246,8 +246,7 @@ async def async_iter_run_tools(
kwargs.update(extra_kwargs)
# Generate response
create_function = provider.get_async_create_function()
response = to_async_iterator(create_function(model=model, messages=messages, **kwargs))
response = to_async_iterator(provider.async_create_function(model=model, messages=messages, **kwargs))
async for chunk in response:
yield chunk