feat: refactor provider create functions to class attributes and update calls

- Added `create_function` and `async_create_function` class attributes with default implementations in `base_provider.py` for `AbstractProvider`, `AsyncProvider`, and `AsyncGeneratorProvider`
- Updated `get_create_function` and `get_async_create_function` methods to return these class attributes
- Replaced calls to `provider.get_create_function()` and `provider.get_async_create_function()` with direct attribute access `provider.create_function` and `provider.async_create_function` across `g4f/__init__.py`, `g4f/client/__init__.py`, `g4f/providers/retry_provider.py`, and `g4f/tools/run_tools.py`
- Removed redundant `get_create_function` and `get_async_create_function` methods from `providers/base_provider.py` and `providers/types.py`
- Ensured all provider response calls now use the class attributes for creating completions asynchronously and synchronously as needed
This commit is contained in:
hlohaus 2025-06-12 12:45:55 +02:00
parent 5734a06193
commit 2befef988b
15 changed files with 142 additions and 111 deletions

View file

@ -10,20 +10,20 @@ from ..requests import StreamSession, raise_for_status
from ..errors import ModelNotFoundError from ..errors import ModelNotFoundError
from .. import debug from .. import debug
class Together(OpenaiTemplate): class Together(OpenaiTemplate):
label = "Together" label = "Together"
url = "https://together.xyz" url = "https://together.xyz"
login_url = "https://api.together.ai/"
api_base = "https://api.together.xyz/v1" api_base = "https://api.together.xyz/v1"
activation_endpoint = "https://www.codegeneration.ai/activate-v2" activation_endpoint = "https://www.codegeneration.ai/activate-v2"
models_endpoint = "https://api.together.xyz/v1/models" models_endpoint = "https://api.together.xyz/v1/models"
working = True working = True
needs_auth = False needs_auth = False
supports_stream = True supports_stream = True
supports_system_message = True supports_system_message = True
supports_message_history = True supports_message_history = True
default_model = 'meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8' default_model = 'meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8'
default_vision_model = default_model default_vision_model = default_model
default_image_model = 'black-forest-labs/FLUX.1.1-pro' default_image_model = 'black-forest-labs/FLUX.1.1-pro'
@ -43,7 +43,7 @@ class Together(OpenaiTemplate):
model_configs = {} # Store model configurations including stop tokens model_configs = {} # Store model configurations including stop tokens
_models_cached = False _models_cached = False
_api_key_cache = None _api_key_cache = None
model_aliases = { model_aliases = {
### Models Chat/Language ### ### Models Chat/Language ###
# meta-llama # meta-llama

View file

@ -682,7 +682,12 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
page.add_handler(nodriver.cdp.network.RequestWillBeSent, on_request) page.add_handler(nodriver.cdp.network.RequestWillBeSent, on_request)
page = await browser.get(cls.url) page = await browser.get(cls.url)
user_agent = await page.evaluate("window.navigator.userAgent", return_by_value=True) user_agent = await page.evaluate("window.navigator.userAgent", return_by_value=True)
while not await page.evaluate("document.getElementById('prompt-textarea')?.id"): textarea = None
while not textarea:
try:
textarea = await page.evaluate("document.getElementById('prompt-textarea')?.id")
except:
pass
await asyncio.sleep(1) await asyncio.sleep(1)
while not await page.evaluate("document.querySelector('[data-testid=\"send-button\"]')?.type"): while not await page.evaluate("document.querySelector('[data-testid=\"send-button\"]')?.type"):
await asyncio.sleep(1) await asyncio.sleep(1)

View file

@ -15,6 +15,7 @@ from .. import debug
class PuterJS(AsyncGeneratorProvider, ProviderModelMixin): class PuterJS(AsyncGeneratorProvider, ProviderModelMixin):
label = "Puter.js" label = "Puter.js"
parent = "Puter"
url = "https://docs.puter.com/playground" url = "https://docs.puter.com/playground"
login_url = "https://github.com/HeyPuter/puter-cli" login_url = "https://github.com/HeyPuter/puter-cli"
api_endpoint = "https://api.puter.com/drivers/call" api_endpoint = "https://api.puter.com/drivers/call"

View file

@ -6,7 +6,7 @@ from ..helper import filter_none, format_media_prompt
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
from ...typing import Union, AsyncResult, Messages, MediaListType from ...typing import Union, AsyncResult, Messages, MediaListType
from ...requests import StreamSession, raise_for_status from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse, ProviderInfo
from ...tools.media import render_messages from ...tools.media import render_messages
from ...errors import MissingAuthError, ResponseError from ...errors import MissingAuthError, ResponseError
from ... import debug from ... import debug
@ -93,6 +93,9 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
async with session.post(f"{api_base.rstrip('/')}/images/generations", json=data, ssl=cls.ssl) as response: async with session.post(f"{api_base.rstrip('/')}/images/generations", json=data, ssl=cls.ssl) as response:
data = await response.json() data = await response.json()
cls.raise_error(data, response.status) cls.raise_error(data, response.status)
model = data.get("model")
if model:
yield ProviderInfo(**cls.get_dict(), model=model)
await raise_for_status(response) await raise_for_status(response)
yield ImageResponse([image["url"] for image in data["data"]], prompt) yield ImageResponse([image["url"] for image in data["data"]], prompt)
return return
@ -121,6 +124,9 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
data = await response.json() data = await response.json()
cls.raise_error(data, response.status) cls.raise_error(data, response.status)
await raise_for_status(response) await raise_for_status(response)
model = data.get("model")
if model:
yield ProviderInfo(**cls.get_dict(), model=model)
choice = data["choices"][0] choice = data["choices"][0]
if "content" in choice["message"] and choice["message"]["content"]: if "content" in choice["message"] and choice["message"]["content"]:
yield choice["message"]["content"].strip() yield choice["message"]["content"].strip()
@ -134,8 +140,13 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
elif content_type.startswith("text/event-stream"): elif content_type.startswith("text/event-stream"):
await raise_for_status(response) await raise_for_status(response)
first = True first = True
model_returned = False
async for data in response.sse(): async for data in response.sse():
cls.raise_error(data) cls.raise_error(data)
model = data.get("model")
if not model_returned and model:
yield ProviderInfo(**cls.get_dict(), model=model)
model_returned = True
choice = data["choices"][0] choice = data["choices"][0]
if "content" in choice["delta"] and choice["delta"]["content"]: if "content" in choice["delta"] and choice["delta"]["content"]:
delta = choice["delta"]["content"] delta = choice["delta"]["content"]

View file

@ -50,7 +50,7 @@ class ChatCompletion:
if ignore_stream: if ignore_stream:
kwargs["ignore_stream"] = True kwargs["ignore_stream"] = True
result = provider.get_create_function()(model, messages, stream=stream, **kwargs) result = provider.create_function(model, messages, stream=stream, **kwargs)
return result if stream or ignore_stream else concat_chunks(result) return result if stream or ignore_stream else concat_chunks(result)
@ -76,7 +76,7 @@ class ChatCompletion:
if ignore_stream: if ignore_stream:
kwargs["ignore_stream"] = True kwargs["ignore_stream"] = True
result = provider.get_async_create_function()(model, messages, stream=stream, **kwargs) result = provider.async_create_function(model, messages, stream=stream, **kwargs)
if not stream and not ignore_stream: if not stream and not ignore_stream:
if hasattr(result, "__aiter__"): if hasattr(result, "__aiter__"):

View file

@ -490,7 +490,7 @@ class Api:
config.provider = provider config.provider = provider
if config.provider is None: if config.provider is None:
config.provider = AppConfig.media_provider config.provider = AppConfig.media_provider
if credentials is not None and credentials.credentials != "secret": if config.api_key is None and credentials is not None and credentials.credentials != "secret":
config.api_key = credentials.credentials config.api_key = credentials.credentials
try: try:
response = await self.client.images.generate( response = await self.client.images.generate(

View file

@ -16,7 +16,7 @@ class RequestConfig(BaseModel):
top_p: Optional[float] = None top_p: Optional[float] = None
max_tokens: Optional[int] = None max_tokens: Optional[int] = None
stop: Union[list[str], str, None] = None stop: Union[list[str], str, None] = None
api_key: Optional[str] = None api_key: Optional[Union[str, dict[str, str]]] = None
api_base: str = None api_base: str = None
web_search: Optional[bool] = None web_search: Optional[bool] = None
proxy: Optional[str] = None proxy: Optional[str] = None
@ -70,6 +70,8 @@ class ImageGenerationConfig(BaseModel):
negative_prompt: Optional[str] = None negative_prompt: Optional[str] = None
resolution: Optional[str] = None resolution: Optional[str] = None
audio: Optional[dict] = None audio: Optional[dict] = None
download_media: bool = True
@model_validator(mode='before') @model_validator(mode='before')
def parse_size(cls, values): def parse_size(cls, values):

View file

@ -375,7 +375,7 @@ class Completions:
kwargs["ignore_stream"] = True kwargs["ignore_stream"] = True
response = iter_run_tools( response = iter_run_tools(
provider.get_create_function(), provider.create_function,
model=model, model=model,
messages=messages, messages=messages,
stream=stream, stream=stream,
@ -462,7 +462,7 @@ class Images:
if isinstance(provider_handler, IterListProvider): if isinstance(provider_handler, IterListProvider):
for provider in provider_handler.providers: for provider in provider_handler.providers:
try: try:
response = await self._generate_image_response(provider, provider.__name__, model, prompt, proxy=proxy, **kwargs) response = await self._generate_image_response(provider, provider.__name__, model, prompt, proxy=proxy, api_key=api_key, **kwargs)
if response is not None: if response is not None:
provider_name = provider.__name__ provider_name = provider.__name__
break break
@ -485,21 +485,25 @@ class Images:
async def _generate_image_response( async def _generate_image_response(
self, self,
provider_handler, provider_handler: ProviderType,
provider_name, provider_name: str,
model: str, model: str,
prompt: str, prompt: str,
prompt_prefix: str = "Generate a image: ", prompt_prefix: str = "Generate a image: ",
api_key: str = None,
**kwargs **kwargs
) -> MediaResponse: ) -> MediaResponse:
messages = [{"role": "user", "content": f"{prompt_prefix}{prompt}"}] messages = [{"role": "user", "content": f"{prompt_prefix}{prompt}"}]
items: list[MediaResponse] = [] items: list[MediaResponse] = []
if isinstance(api_key, dict):
api_key = api_key.get(provider_handler.get_parent())
if hasattr(provider_handler, "create_async_generator"): if hasattr(provider_handler, "create_async_generator"):
async for item in provider_handler.create_async_generator( async for item in provider_handler.create_async_generator(
model, model,
messages, messages,
stream=True, stream=True,
prompt=prompt, prompt=prompt,
api_key=api_key,
**kwargs **kwargs
): ):
if isinstance(item, (MediaResponse, AudioResponse)): if isinstance(item, (MediaResponse, AudioResponse)):
@ -510,6 +514,7 @@ class Images:
messages, messages,
True, True,
prompt=prompt, prompt=prompt,
api_key=api_key,
**kwargs **kwargs
): ):
if isinstance(item, (MediaResponse, AudioResponse)): if isinstance(item, (MediaResponse, AudioResponse)):

View file

@ -157,7 +157,7 @@ async def copy_media(
if media_type not in ("application/octet-stream", "binary/octet-stream"): if media_type not in ("application/octet-stream", "binary/octet-stream"):
if media_type not in MEDIA_TYPE_MAP: if media_type not in MEDIA_TYPE_MAP:
raise ValueError(f"Unsupported media type: {media_type}") raise ValueError(f"Unsupported media type: {media_type}")
if not media_extension: if target is None and not media_extension:
media_extension = f".{MEDIA_TYPE_MAP[media_type]}" media_extension = f".{MEDIA_TYPE_MAP[media_type]}"
target_path = f"{target_path}{media_extension}" target_path = f"{target_path}{media_extension}"
with open(target_path, "wb") as f: with open(target_path, "wb") as f:

View file

@ -1,22 +1,32 @@
from __future__ import annotations from __future__ import annotations
import re import re
from typing import Dict, List, Set, Optional, Tuple, Any from ..typing import AsyncResult, Messages, MediaListType, Union
from ..typing import AsyncResult, Messages, MediaListType
from ..errors import ModelNotFoundError from ..errors import ModelNotFoundError
from ..image import is_data_an_audio from ..image import is_data_an_audio
from ..providers.retry_provider import IterListProvider from ..providers.retry_provider import IterListProvider
from ..providers.types import ProviderType from ..providers.types import ProviderType
from ..Provider.needs_auth import OpenaiChat, CopilotAccount from ..Provider.needs_auth import OpenaiChat, CopilotAccount
from ..Provider.hf_space import HuggingSpace from ..Provider.hf_space import HuggingSpace
from ..Provider import Cloudflare, Gemini, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI from ..Provider import __map__
from ..Provider import Cloudflare, Gemini, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI, PuterJS
from ..Provider import Microsoft_Phi_4_Multimodal, DeepInfraChat, Blackbox, OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena from ..Provider import Microsoft_Phi_4_Multimodal, DeepInfraChat, Blackbox, OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena
from ..Provider import EdgeTTS, gTTS, MarkItDown from ..Provider import EdgeTTS, gTTS, MarkItDown, OpenAIFM
from ..Provider import HarProvider, HuggingFace, HuggingFaceMedia from ..Provider import HarProvider, HuggingFace, HuggingFaceMedia
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .. import Provider from .. import Provider
from .. import models from .. import models
MAIN_PROVIERS = [
OpenaiChat, Cloudflare, HarProvider, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox,
OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena,
HuggingSpace, LambdaChat, CopilotAccount, PollinationsAI, DeepInfraChat, HuggingFace, HuggingFaceMedia
]
SPECIAL_PROVIDERS = [OpenaiChat, CopilotAccount, PollinationsAI, HuggingSpace, Cloudflare, PerplexityLabs, Gemini, Grok, LegacyLMArena, ARTA]
SPECIAL_PROVIDERS2 = [HarProvider, LambdaChat, DeepInfraChat, HuggingFace, HuggingFaceMedia, PuterJS]
LABELS = { LABELS = {
"default": "Default", "default": "Default",
"openai": "OpenAI: ChatGPT", "openai": "OpenAI: ChatGPT",
@ -31,6 +41,7 @@ LABELS = {
"mistral": "Mistral", "mistral": "Mistral",
"PollinationsAI": "Pollinations AI", "PollinationsAI": "Pollinations AI",
"perplexity": "Perplexity Labs", "perplexity": "Perplexity Labs",
"openrouter": "OpenRouter",
"video": "Video Generation", "video": "Video Generation",
"image": "Image Generation", "image": "Image Generation",
"other": "Other Models", "other": "Other Models",
@ -45,16 +56,16 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
def get_grouped_models(cls, ignored: list[str] = []) -> dict[str, list[str]]: def get_grouped_models(cls, ignored: list[str] = []) -> dict[str, list[str]]:
unsorted_models = cls.get_models(ignored=ignored) unsorted_models = cls.get_models(ignored=ignored)
groups = {key: [] for key in LABELS.keys()} groups = {key: [] for key in LABELS.keys()}
# Always add default first # Always add default first
groups["default"].append("default") groups["default"].append("default")
for model in unsorted_models: for model in unsorted_models:
if model == "default": if model == "default":
continue # Already added continue # Already added
added = False added = False
# Check for PollinationsAI models (with prefix) # Check for PollinationsAI models (with prefix)
if model.startswith("PollinationsAI:"): if model.startswith("PollinationsAI:"):
groups["PollinationsAI"].append(model) groups["PollinationsAI"].append(model)
@ -109,6 +120,10 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
elif model.startswith(("gpt-", "chatgpt-", "o1", "o1-", "o3-", "o4-")) or model in ("auto", "dall-e-3", "searchgpt"): elif model.startswith(("gpt-", "chatgpt-", "o1", "o1-", "o3-", "o4-")) or model in ("auto", "dall-e-3", "searchgpt"):
groups["openai"].append(model) groups["openai"].append(model)
added = True added = True
# Check for openrouter models
elif model.startswith(("openrouter:")):
groups["openrouter"].append(model)
added = True
# Check for video models # Check for video models
elif model in cls.video_models: elif model in cls.video_models:
groups["video"].append(model) groups["video"].append(model)
@ -117,7 +132,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
elif model in cls.image_models or "flux" in model.lower() or "stable-diffusion" in model.lower() or "sdxl" in model.lower() or "gpt-image" in model.lower(): elif model in cls.image_models or "flux" in model.lower() or "stable-diffusion" in model.lower() or "sdxl" in model.lower() or "gpt-image" in model.lower():
groups["image"].append(model) groups["image"].append(model)
added = True added = True
# If not categorized, check for special cases then put in other # If not categorized, check for special cases then put in other
if not added: if not added:
# CodeLlama is Meta's model # CodeLlama is Meta's model
@ -128,7 +143,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
groups["phi"].append(model) groups["phi"].append(model)
else: else:
groups["other"].append(model) groups["other"].append(model)
return [ return [
{"group": LABELS[group], "models": names} for group, names in groups.items() {"group": LABELS[group], "models": names} for group, names in groups.items()
] ]
@ -157,9 +172,9 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
model: len(providers) for model, providers in model_with_providers.items() if len(providers) > 1 model: len(providers) for model, providers in model_with_providers.items() if len(providers) > 1
} }
all_models = [cls.default_model] + list(model_with_providers.keys()) all_models = [cls.default_model] + list(model_with_providers.keys())
# Process special providers # Process special providers
for provider in [OpenaiChat, CopilotAccount, PollinationsAI, HuggingSpace, Cloudflare, PerplexityLabs, Gemini, Grok, LegacyLMArena, ARTA]: for provider in SPECIAL_PROVIDERS:
provider: ProviderType = provider provider: ProviderType = provider
if not provider.working or provider.get_parent() in ignored: if not provider.working or provider.get_parent() in ignored:
continue continue
@ -186,7 +201,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
cls.image_models.extend(arta_models) cls.image_models.extend(arta_models)
else: else:
all_models.extend(provider.get_models()) all_models.extend(provider.get_models())
# Update special model lists # Update special model lists
if hasattr(provider, 'image_models'): if hasattr(provider, 'image_models'):
cls.image_models.extend(provider.image_models) cls.image_models.extend(provider.image_models)
@ -194,7 +209,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
cls.vision_models.extend(provider.vision_models) cls.vision_models.extend(provider.vision_models)
if hasattr(provider, 'video_models'): if hasattr(provider, 'video_models'):
cls.video_models.extend(provider.video_models) cls.video_models.extend(provider.video_models)
# Clean model names function # Clean model names function
def clean_name(name: str) -> str: def clean_name(name: str) -> str:
name = name.split("/")[-1].split(":")[0].lower() name = name.split("/")[-1].split(":")[0].lower()
@ -212,24 +227,24 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
name = name.replace("llama3", "llama-3") name = name.replace("llama3", "llama-3")
name = name.replace("flux.1-", "flux-") name = name.replace("flux.1-", "flux-")
return name return name
# Process HAR providers # Process HAR providers
for provider in [HarProvider, LambdaChat, DeepInfraChat, HuggingFace, HuggingFaceMedia]: for provider in SPECIAL_PROVIDERS2:
if not provider.working or provider.get_parent() in ignored: if not provider.working or provider.get_parent() in ignored:
continue continue
new_models = provider.get_models() new_models = provider.get_models()
if provider == HuggingFaceMedia: if provider == HuggingFaceMedia:
new_models = provider.video_models new_models = provider.video_models
# Add original models too, not just cleaned names # Add original models too, not just cleaned names
all_models.extend(new_models) all_models.extend(new_models)
model_map = {clean_name(model): model for model in new_models} model_map = {model if model.startswith("openrouter:") else clean_name(model): model for model in new_models}
if not provider.model_aliases: if not provider.model_aliases:
provider.model_aliases = {} provider.model_aliases = {}
provider.model_aliases.update(model_map) provider.model_aliases.update(model_map)
all_models.extend(list(model_map.keys())) all_models.extend(list(model_map.keys()))
# Update special model lists with both original and cleaned names # Update special model lists with both original and cleaned names
if hasattr(provider, 'image_models'): if hasattr(provider, 'image_models'):
cls.image_models.extend(provider.image_models) cls.image_models.extend(provider.image_models)
@ -240,18 +255,18 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
if hasattr(provider, 'video_models'): if hasattr(provider, 'video_models'):
cls.video_models.extend(provider.video_models) cls.video_models.extend(provider.video_models)
cls.video_models.extend([clean_name(model) for model in provider.video_models]) cls.video_models.extend([clean_name(model) for model in provider.video_models])
# Process audio providers # Process audio providers
for provider in [Microsoft_Phi_4_Multimodal, PollinationsAI]: for provider in [Microsoft_Phi_4_Multimodal, PollinationsAI]:
if provider.working and provider.get_parent() not in ignored: if provider.working and provider.get_parent() not in ignored:
cls.audio_models.update(provider.audio_models) cls.audio_models.update(provider.audio_models)
# Update model counts # Update model counts
cls.models_count.update({model: all_models.count(model) for model in all_models if all_models.count(model) > cls.models_count.get(model, 0)}) cls.models_count.update({model: all_models.count(model) for model in all_models if all_models.count(model) > cls.models_count.get(model, 0)})
# Deduplicate and store # Deduplicate and store
cls.models_storage[ignored_key] = list(dict.fromkeys([model if model else cls.default_model for model in all_models])) cls.models_storage[ignored_key] = list(dict.fromkeys([model if model else cls.default_model for model in all_models]))
return cls.models_storage[ignored_key] return cls.models_storage[ignored_key]
@classmethod @classmethod
@ -262,6 +277,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
stream: bool = True, stream: bool = True,
media: MediaListType = None, media: MediaListType = None,
ignored: list[str] = [], ignored: list[str] = [],
api_key: Union[str, dict[str, str]] = None,
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
cls.get_models(ignored=ignored) cls.get_models(ignored=ignored)
@ -284,7 +300,10 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
if "tools" in kwargs: if "tools" in kwargs:
providers = [PollinationsAI] providers = [PollinationsAI]
elif "audio" in kwargs or "audio" in kwargs.get("modalities", []): elif "audio" in kwargs or "audio" in kwargs.get("modalities", []):
providers = [PollinationsAI, EdgeTTS, gTTS] if kwargs.get("audio", {}).get("language") is None:
providers = [PollinationsAI, OpenAIFM, Gemini]
else:
providers = [PollinationsAI, OpenAIFM, EdgeTTS, gTTS]
elif has_audio: elif has_audio:
providers = [PollinationsAI, Microsoft_Phi_4_Multimodal, MarkItDown] providers = [PollinationsAI, Microsoft_Phi_4_Multimodal, MarkItDown]
elif has_image: elif has_image:
@ -297,29 +316,30 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
model = None model = None
providers.append(provider) providers.append(provider)
else: else:
for provider in [ extra_providers = []
OpenaiChat, Cloudflare, HarProvider, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox, if isinstance(api_key, dict):
OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena, for provider in api_key:
HuggingSpace, LambdaChat, CopilotAccount, PollinationsAI, DeepInfraChat, HuggingFace, HuggingFaceMedia, if provider in __map__ and __map__[provider] not in MAIN_PROVIERS:
]: extra_providers.append(__map__[provider])
for provider in MAIN_PROVIERS + extra_providers:
if provider.working: if provider.working:
if not model or model in provider.get_models() or model in provider.model_aliases: if not model or model in provider.get_models() or model in provider.model_aliases:
providers.append(provider) providers.append(provider)
if model in models.__models__: if model in models.__models__:
for provider in models.__models__[model][1]: for provider in models.__models__[model][1]:
providers.append(provider) providers.append(provider)
providers = [provider for provider in providers if provider.working and provider.get_parent() not in ignored] providers = [provider for provider in providers if provider.working and provider.get_parent() not in ignored]
providers = list({provider.__name__: provider for provider in providers}.values()) providers = list({provider.__name__: provider for provider in providers}.values())
if len(providers) == 0: if len(providers) == 0:
raise ModelNotFoundError(f"AnyProvider: Model {model} not found in any provider.") raise ModelNotFoundError(f"AnyProvider: Model {model} not found in any provider.")
async for chunk in IterListProvider(providers).create_async_generator( async for chunk in IterListProvider(providers).create_async_generator(
model, model,
messages, messages,
stream=stream, stream=stream,
media=media, media=media,
api_key=api_key,
**kwargs **kwargs
): ):
yield chunk yield chunk

View file

@ -126,12 +126,30 @@ class AbstractProvider(BaseProvider):
) )
@classmethod @classmethod
def get_create_function(cls) -> callable: def create_function(cls, *args, **kwargs) -> CreateResult:
return cls.create_completion """
Creates a completion using the synchronous method.
Args:
**kwargs: Additional keyword arguments.
Returns:
CreateResult: The result of the completion creation.
"""
return cls.create_completion(*args, **kwargs)
@classmethod @classmethod
def get_async_create_function(cls) -> callable: def async_create_function(cls, *args, **kwargs) -> AsyncResult:
return cls.create_async """
Creates a completion using the synchronous method.
Args:
**kwargs: Additional keyword arguments.
Returns:
CreateResult: The result of the completion creation.
"""
return cls.create_async(*args, **kwargs)
@classmethod @classmethod
def get_parameters(cls, as_json: bool = False) -> dict[str, Parameter]: def get_parameters(cls, as_json: bool = False) -> dict[str, Parameter]:
@ -264,14 +282,6 @@ class AsyncProvider(AbstractProvider):
""" """
raise NotImplementedError() raise NotImplementedError()
@classmethod
def get_create_function(cls) -> callable:
return cls.create_completion
@classmethod
def get_async_create_function(cls) -> callable:
return cls.create_async
class AsyncGeneratorProvider(AbstractProvider): class AsyncGeneratorProvider(AbstractProvider):
""" """
Provides asynchronous generator functionality for streaming results. Provides asynchronous generator functionality for streaming results.
@ -331,12 +341,17 @@ class AsyncGeneratorProvider(AbstractProvider):
raise NotImplementedError() raise NotImplementedError()
@classmethod @classmethod
def get_create_function(cls) -> callable: def async_create_function(cls, *args, **kwargs) -> AsyncResult:
return cls.create_completion """
Creates a completion using the synchronous method.
@classmethod Args:
def get_async_create_function(cls) -> callable: **kwargs: Additional keyword arguments.
return cls.create_async_generator
Returns:
CreateResult: The result of the completion creation.
"""
return cls.create_async_generator(*args, **kwargs)
class ProviderModelMixin: class ProviderModelMixin:
default_model: str = None default_model: str = None
@ -417,14 +432,6 @@ class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin):
return to_sync_generator(auth_result) return to_sync_generator(auth_result)
return asyncio.run(auth_result) return asyncio.run(auth_result)
@classmethod
def get_create_function(cls) -> callable:
return cls.create_completion
@classmethod
def get_async_create_function(cls) -> callable:
return cls.create_async_generator
@classmethod @classmethod
def write_cache_file(cls, cache_file: Path, auth_result: AuthResult = None): def write_cache_file(cls, cache_file: Path, auth_result: AuthResult = None):
if auth_result is not None: if auth_result is not None:

View file

@ -26,7 +26,6 @@ class IterListProvider(BaseRetryProvider):
self.shuffle = shuffle self.shuffle = shuffle
self.working = True self.working = True
self.last_provider: Type[BaseProvider] = None self.last_provider: Type[BaseProvider] = None
self.add_api_key = False
def create_completion( def create_completion(
self, self,
@ -35,6 +34,7 @@ class IterListProvider(BaseRetryProvider):
stream: bool = False, stream: bool = False,
ignore_stream: bool = False, ignore_stream: bool = False,
ignored: list[str] = [], ignored: list[str] = [],
api_key: str = None,
**kwargs, **kwargs,
) -> CreateResult: ) -> CreateResult:
""" """
@ -55,8 +55,11 @@ class IterListProvider(BaseRetryProvider):
self.last_provider = provider self.last_provider = provider
debug.log(f"Using {provider.__name__} provider") debug.log(f"Using {provider.__name__} provider")
yield ProviderInfo(**provider.get_dict(), model=model if model else getattr(provider, "default_model")) yield ProviderInfo(**provider.get_dict(), model=model if model else getattr(provider, "default_model"))
extra_body = kwargs.copy()
if isinstance(api_key, dict):
extra_body["api_key"] = api_key.get(provider.get_parent())
try: try:
response = provider.get_create_function()(model, messages, stream=stream, **kwargs) response = provider.create_function(model, messages, stream=stream, **extra_body)
for chunk in response: for chunk in response:
if chunk: if chunk:
yield chunk yield chunk
@ -66,7 +69,7 @@ class IterListProvider(BaseRetryProvider):
return return
except Exception as e: except Exception as e:
exceptions[provider.__name__] = e exceptions[provider.__name__] = e
debug.error(f"{provider.__name__} {type(e).__name__}: {e}") debug.error(f"{provider.__name__}:", e)
if started: if started:
raise e raise e
yield e yield e
@ -92,12 +95,12 @@ class IterListProvider(BaseRetryProvider):
debug.log(f"Using {provider.__name__} provider" + (f" and {model} model" if model else "")) debug.log(f"Using {provider.__name__} provider" + (f" and {model} model" if model else ""))
yield ProviderInfo(**provider.get_dict(), model=model if model else getattr(provider, "default_model")) yield ProviderInfo(**provider.get_dict(), model=model if model else getattr(provider, "default_model"))
extra_body = kwargs.copy() extra_body = kwargs.copy()
if self.add_api_key or provider.__name__ in ["HuggingFace", "HuggingFaceMedia"]: if isinstance(api_key, dict):
extra_body["api_key"] = api_key extra_body["api_key"] = api_key.get(provider.get_parent())
if conversation is not None and hasattr(conversation, provider.__name__): if conversation is not None and hasattr(conversation, provider.__name__):
extra_body["conversation"] = JsonConversation(**getattr(conversation, provider.__name__)) extra_body["conversation"] = JsonConversation(**getattr(conversation, provider.__name__))
try: try:
response = provider.get_async_create_function()(model, messages, stream=stream, **extra_body) response = provider.async_create_function(model, messages, stream=stream, **extra_body)
if hasattr(response, "__aiter__"): if hasattr(response, "__aiter__"):
async for chunk in response: async for chunk in response:
if isinstance(chunk, JsonConversation): if isinstance(chunk, JsonConversation):
@ -118,18 +121,15 @@ class IterListProvider(BaseRetryProvider):
return return
except Exception as e: except Exception as e:
exceptions[provider.__name__] = e exceptions[provider.__name__] = e
debug.error(f"{provider.__name__} {type(e).__name__}: {e}") debug.error(f"{provider.__name__}:", e)
if started: if started:
raise e raise e
yield e yield e
raise_exceptions(exceptions) raise_exceptions(exceptions)
def get_create_function(self) -> callable: create_function = create_completion
return self.create_completion async_create_function = create_async_generator
def get_async_create_function(self) -> callable:
return self.create_async_generator
def get_providers(self, stream: bool, ignored: list[str]) -> list[ProviderType]: def get_providers(self, stream: bool, ignored: list[str]) -> list[ProviderType]:
providers = [p for p in self.providers if (p.supports_stream or not stream) and p.__name__ not in ignored] providers = [p for p in self.providers if (p.supports_stream or not stream) and p.__name__ not in ignored]
@ -156,7 +156,6 @@ class RetryProvider(IterListProvider):
super().__init__(providers, shuffle) super().__init__(providers, shuffle)
self.single_provider_retry = single_provider_retry self.single_provider_retry = single_provider_retry
self.max_retries = max_retries self.max_retries = max_retries
self.add_api_key = True
def create_completion( def create_completion(
self, self,
@ -185,7 +184,7 @@ class RetryProvider(IterListProvider):
try: try:
if debug.logging: if debug.logging:
print(f"Using {provider.__name__} provider (attempt {attempt + 1})") print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
response = provider.get_create_function()(model, messages, stream=stream, **kwargs) response = provider.create_function(model, messages, stream=stream, **kwargs)
for chunk in response: for chunk in response:
yield chunk yield chunk
if is_content(chunk): if is_content(chunk):
@ -218,7 +217,7 @@ class RetryProvider(IterListProvider):
for attempt in range(self.max_retries): for attempt in range(self.max_retries):
try: try:
debug.log(f"Using {provider.__name__} provider (attempt {attempt + 1})") debug.log(f"Using {provider.__name__} provider (attempt {attempt + 1})")
response = provider.get_async_create_function()(model, messages, stream=stream, **kwargs) response = provider.async_create_function(model, messages, stream=stream, **kwargs)
if hasattr(response, "__aiter__"): if hasattr(response, "__aiter__"):
async for chunk in response: async for chunk in response:
yield chunk yield chunk

View file

@ -45,7 +45,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
finish = None finish = None
chunks = [] chunks = []
has_usage = False has_usage = False
async for chunk in provider.get_async_create_function()( async for chunk in provider.async_create_function(
model, model,
messages, messages,
stream=stream, stream=stream,

View file

@ -25,26 +25,8 @@ class BaseProvider(ABC):
supports_message_history: bool = False supports_message_history: bool = False
supports_system_message: bool = False supports_system_message: bool = False
params: str params: str
create_function: callable
@abstractmethod async_create_function: callable
def get_create_function() -> callable:
"""
Get the create function for the provider.
Returns:
callable: The create function.
"""
raise NotImplementedError()
@abstractmethod
def get_async_create_function() -> callable:
"""
Get the async create function for the provider.
Returns:
callable: The create function.
"""
raise NotImplementedError()
@classmethod @classmethod
def get_dict(cls) -> Dict[str, str]: def get_dict(cls) -> Dict[str, str]:

View file

@ -246,8 +246,7 @@ async def async_iter_run_tools(
kwargs.update(extra_kwargs) kwargs.update(extra_kwargs)
# Generate response # Generate response
create_function = provider.get_async_create_function() response = to_async_iterator(provider.async_create_function(model=model, messages=messages, **kwargs))
response = to_async_iterator(create_function(model=model, messages=messages, **kwargs))
async for chunk in response: async for chunk in response:
yield chunk yield chunk