mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
feat: refactor provider create functions to class attributes and update calls
- Added `create_function` and `async_create_function` class attributes with default implementations in `base_provider.py` for `AbstractProvider`, `AsyncProvider`, and `AsyncGeneratorProvider` - Updated `get_create_function` and `get_async_create_function` methods to return these class attributes - Replaced calls to `provider.get_create_function()` and `provider.get_async_create_function()` with direct attribute access `provider.create_function` and `provider.async_create_function` across `g4f/__init__.py`, `g4f/client/__init__.py`, `g4f/providers/retry_provider.py`, and `g4f/tools/run_tools.py` - Removed redundant `get_create_function` and `get_async_create_function` methods from `providers/base_provider.py` and `providers/types.py` - Ensured all provider response calls now use the class attributes for creating completions asynchronously and synchronously as needed
This commit is contained in:
parent
5734a06193
commit
2befef988b
15 changed files with 142 additions and 111 deletions
|
|
@ -10,10 +10,10 @@ from ..requests import StreamSession, raise_for_status
|
|||
from ..errors import ModelNotFoundError
|
||||
from .. import debug
|
||||
|
||||
|
||||
class Together(OpenaiTemplate):
|
||||
label = "Together"
|
||||
url = "https://together.xyz"
|
||||
login_url = "https://api.together.ai/"
|
||||
api_base = "https://api.together.xyz/v1"
|
||||
activation_endpoint = "https://www.codegeneration.ai/activate-v2"
|
||||
models_endpoint = "https://api.together.xyz/v1/models"
|
||||
|
|
|
|||
|
|
@ -682,7 +682,12 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
page.add_handler(nodriver.cdp.network.RequestWillBeSent, on_request)
|
||||
page = await browser.get(cls.url)
|
||||
user_agent = await page.evaluate("window.navigator.userAgent", return_by_value=True)
|
||||
while not await page.evaluate("document.getElementById('prompt-textarea')?.id"):
|
||||
textarea = None
|
||||
while not textarea:
|
||||
try:
|
||||
textarea = await page.evaluate("document.getElementById('prompt-textarea')?.id")
|
||||
except:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
while not await page.evaluate("document.querySelector('[data-testid=\"send-button\"]')?.type"):
|
||||
await asyncio.sleep(1)
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from .. import debug
|
|||
|
||||
class PuterJS(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
label = "Puter.js"
|
||||
parent = "Puter"
|
||||
url = "https://docs.puter.com/playground"
|
||||
login_url = "https://github.com/HeyPuter/puter-cli"
|
||||
api_endpoint = "https://api.puter.com/drivers/call"
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from ..helper import filter_none, format_media_prompt
|
|||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||
from ...typing import Union, AsyncResult, Messages, MediaListType
|
||||
from ...requests import StreamSession, raise_for_status
|
||||
from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse
|
||||
from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse, ProviderInfo
|
||||
from ...tools.media import render_messages
|
||||
from ...errors import MissingAuthError, ResponseError
|
||||
from ... import debug
|
||||
|
|
@ -93,6 +93,9 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
|||
async with session.post(f"{api_base.rstrip('/')}/images/generations", json=data, ssl=cls.ssl) as response:
|
||||
data = await response.json()
|
||||
cls.raise_error(data, response.status)
|
||||
model = data.get("model")
|
||||
if model:
|
||||
yield ProviderInfo(**cls.get_dict(), model=model)
|
||||
await raise_for_status(response)
|
||||
yield ImageResponse([image["url"] for image in data["data"]], prompt)
|
||||
return
|
||||
|
|
@ -121,6 +124,9 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
|||
data = await response.json()
|
||||
cls.raise_error(data, response.status)
|
||||
await raise_for_status(response)
|
||||
model = data.get("model")
|
||||
if model:
|
||||
yield ProviderInfo(**cls.get_dict(), model=model)
|
||||
choice = data["choices"][0]
|
||||
if "content" in choice["message"] and choice["message"]["content"]:
|
||||
yield choice["message"]["content"].strip()
|
||||
|
|
@ -134,8 +140,13 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
|||
elif content_type.startswith("text/event-stream"):
|
||||
await raise_for_status(response)
|
||||
first = True
|
||||
model_returned = False
|
||||
async for data in response.sse():
|
||||
cls.raise_error(data)
|
||||
model = data.get("model")
|
||||
if not model_returned and model:
|
||||
yield ProviderInfo(**cls.get_dict(), model=model)
|
||||
model_returned = True
|
||||
choice = data["choices"][0]
|
||||
if "content" in choice["delta"] and choice["delta"]["content"]:
|
||||
delta = choice["delta"]["content"]
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ class ChatCompletion:
|
|||
if ignore_stream:
|
||||
kwargs["ignore_stream"] = True
|
||||
|
||||
result = provider.get_create_function()(model, messages, stream=stream, **kwargs)
|
||||
result = provider.create_function(model, messages, stream=stream, **kwargs)
|
||||
|
||||
return result if stream or ignore_stream else concat_chunks(result)
|
||||
|
||||
|
|
@ -76,7 +76,7 @@ class ChatCompletion:
|
|||
if ignore_stream:
|
||||
kwargs["ignore_stream"] = True
|
||||
|
||||
result = provider.get_async_create_function()(model, messages, stream=stream, **kwargs)
|
||||
result = provider.async_create_function(model, messages, stream=stream, **kwargs)
|
||||
|
||||
if not stream and not ignore_stream:
|
||||
if hasattr(result, "__aiter__"):
|
||||
|
|
|
|||
|
|
@ -490,7 +490,7 @@ class Api:
|
|||
config.provider = provider
|
||||
if config.provider is None:
|
||||
config.provider = AppConfig.media_provider
|
||||
if credentials is not None and credentials.credentials != "secret":
|
||||
if config.api_key is None and credentials is not None and credentials.credentials != "secret":
|
||||
config.api_key = credentials.credentials
|
||||
try:
|
||||
response = await self.client.images.generate(
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class RequestConfig(BaseModel):
|
|||
top_p: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
stop: Union[list[str], str, None] = None
|
||||
api_key: Optional[str] = None
|
||||
api_key: Optional[Union[str, dict[str, str]]] = None
|
||||
api_base: str = None
|
||||
web_search: Optional[bool] = None
|
||||
proxy: Optional[str] = None
|
||||
|
|
@ -70,6 +70,8 @@ class ImageGenerationConfig(BaseModel):
|
|||
negative_prompt: Optional[str] = None
|
||||
resolution: Optional[str] = None
|
||||
audio: Optional[dict] = None
|
||||
download_media: bool = True
|
||||
|
||||
|
||||
@model_validator(mode='before')
|
||||
def parse_size(cls, values):
|
||||
|
|
|
|||
|
|
@ -375,7 +375,7 @@ class Completions:
|
|||
kwargs["ignore_stream"] = True
|
||||
|
||||
response = iter_run_tools(
|
||||
provider.get_create_function(),
|
||||
provider.create_function,
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
|
|
@ -462,7 +462,7 @@ class Images:
|
|||
if isinstance(provider_handler, IterListProvider):
|
||||
for provider in provider_handler.providers:
|
||||
try:
|
||||
response = await self._generate_image_response(provider, provider.__name__, model, prompt, proxy=proxy, **kwargs)
|
||||
response = await self._generate_image_response(provider, provider.__name__, model, prompt, proxy=proxy, api_key=api_key, **kwargs)
|
||||
if response is not None:
|
||||
provider_name = provider.__name__
|
||||
break
|
||||
|
|
@ -485,21 +485,25 @@ class Images:
|
|||
|
||||
async def _generate_image_response(
|
||||
self,
|
||||
provider_handler,
|
||||
provider_name,
|
||||
provider_handler: ProviderType,
|
||||
provider_name: str,
|
||||
model: str,
|
||||
prompt: str,
|
||||
prompt_prefix: str = "Generate a image: ",
|
||||
api_key: str = None,
|
||||
**kwargs
|
||||
) -> MediaResponse:
|
||||
messages = [{"role": "user", "content": f"{prompt_prefix}{prompt}"}]
|
||||
items: list[MediaResponse] = []
|
||||
if isinstance(api_key, dict):
|
||||
api_key = api_key.get(provider_handler.get_parent())
|
||||
if hasattr(provider_handler, "create_async_generator"):
|
||||
async for item in provider_handler.create_async_generator(
|
||||
model,
|
||||
messages,
|
||||
stream=True,
|
||||
prompt=prompt,
|
||||
api_key=api_key,
|
||||
**kwargs
|
||||
):
|
||||
if isinstance(item, (MediaResponse, AudioResponse)):
|
||||
|
|
@ -510,6 +514,7 @@ class Images:
|
|||
messages,
|
||||
True,
|
||||
prompt=prompt,
|
||||
api_key=api_key,
|
||||
**kwargs
|
||||
):
|
||||
if isinstance(item, (MediaResponse, AudioResponse)):
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ async def copy_media(
|
|||
if media_type not in ("application/octet-stream", "binary/octet-stream"):
|
||||
if media_type not in MEDIA_TYPE_MAP:
|
||||
raise ValueError(f"Unsupported media type: {media_type}")
|
||||
if not media_extension:
|
||||
if target is None and not media_extension:
|
||||
media_extension = f".{MEDIA_TYPE_MAP[media_type]}"
|
||||
target_path = f"{target_path}{media_extension}"
|
||||
with open(target_path, "wb") as f:
|
||||
|
|
|
|||
|
|
@ -1,22 +1,32 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Dict, List, Set, Optional, Tuple, Any
|
||||
from ..typing import AsyncResult, Messages, MediaListType
|
||||
from ..typing import AsyncResult, Messages, MediaListType, Union
|
||||
from ..errors import ModelNotFoundError
|
||||
from ..image import is_data_an_audio
|
||||
from ..providers.retry_provider import IterListProvider
|
||||
from ..providers.types import ProviderType
|
||||
from ..Provider.needs_auth import OpenaiChat, CopilotAccount
|
||||
from ..Provider.hf_space import HuggingSpace
|
||||
from ..Provider import Cloudflare, Gemini, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI
|
||||
from ..Provider import __map__
|
||||
from ..Provider import Cloudflare, Gemini, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI, PuterJS
|
||||
from ..Provider import Microsoft_Phi_4_Multimodal, DeepInfraChat, Blackbox, OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena
|
||||
from ..Provider import EdgeTTS, gTTS, MarkItDown
|
||||
from ..Provider import EdgeTTS, gTTS, MarkItDown, OpenAIFM
|
||||
from ..Provider import HarProvider, HuggingFace, HuggingFaceMedia
|
||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from .. import Provider
|
||||
from .. import models
|
||||
|
||||
MAIN_PROVIERS = [
|
||||
OpenaiChat, Cloudflare, HarProvider, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox,
|
||||
OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena,
|
||||
HuggingSpace, LambdaChat, CopilotAccount, PollinationsAI, DeepInfraChat, HuggingFace, HuggingFaceMedia
|
||||
]
|
||||
|
||||
SPECIAL_PROVIDERS = [OpenaiChat, CopilotAccount, PollinationsAI, HuggingSpace, Cloudflare, PerplexityLabs, Gemini, Grok, LegacyLMArena, ARTA]
|
||||
|
||||
SPECIAL_PROVIDERS2 = [HarProvider, LambdaChat, DeepInfraChat, HuggingFace, HuggingFaceMedia, PuterJS]
|
||||
|
||||
LABELS = {
|
||||
"default": "Default",
|
||||
"openai": "OpenAI: ChatGPT",
|
||||
|
|
@ -31,6 +41,7 @@ LABELS = {
|
|||
"mistral": "Mistral",
|
||||
"PollinationsAI": "Pollinations AI",
|
||||
"perplexity": "Perplexity Labs",
|
||||
"openrouter": "OpenRouter",
|
||||
"video": "Video Generation",
|
||||
"image": "Image Generation",
|
||||
"other": "Other Models",
|
||||
|
|
@ -109,6 +120,10 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
elif model.startswith(("gpt-", "chatgpt-", "o1", "o1-", "o3-", "o4-")) or model in ("auto", "dall-e-3", "searchgpt"):
|
||||
groups["openai"].append(model)
|
||||
added = True
|
||||
# Check for openrouter models
|
||||
elif model.startswith(("openrouter:")):
|
||||
groups["openrouter"].append(model)
|
||||
added = True
|
||||
# Check for video models
|
||||
elif model in cls.video_models:
|
||||
groups["video"].append(model)
|
||||
|
|
@ -159,7 +174,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
all_models = [cls.default_model] + list(model_with_providers.keys())
|
||||
|
||||
# Process special providers
|
||||
for provider in [OpenaiChat, CopilotAccount, PollinationsAI, HuggingSpace, Cloudflare, PerplexityLabs, Gemini, Grok, LegacyLMArena, ARTA]:
|
||||
for provider in SPECIAL_PROVIDERS:
|
||||
provider: ProviderType = provider
|
||||
if not provider.working or provider.get_parent() in ignored:
|
||||
continue
|
||||
|
|
@ -214,7 +229,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
return name
|
||||
|
||||
# Process HAR providers
|
||||
for provider in [HarProvider, LambdaChat, DeepInfraChat, HuggingFace, HuggingFaceMedia]:
|
||||
for provider in SPECIAL_PROVIDERS2:
|
||||
if not provider.working or provider.get_parent() in ignored:
|
||||
continue
|
||||
new_models = provider.get_models()
|
||||
|
|
@ -224,7 +239,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
# Add original models too, not just cleaned names
|
||||
all_models.extend(new_models)
|
||||
|
||||
model_map = {clean_name(model): model for model in new_models}
|
||||
model_map = {model if model.startswith("openrouter:") else clean_name(model): model for model in new_models}
|
||||
if not provider.model_aliases:
|
||||
provider.model_aliases = {}
|
||||
provider.model_aliases.update(model_map)
|
||||
|
|
@ -262,6 +277,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
stream: bool = True,
|
||||
media: MediaListType = None,
|
||||
ignored: list[str] = [],
|
||||
api_key: Union[str, dict[str, str]] = None,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
cls.get_models(ignored=ignored)
|
||||
|
|
@ -284,7 +300,10 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
if "tools" in kwargs:
|
||||
providers = [PollinationsAI]
|
||||
elif "audio" in kwargs or "audio" in kwargs.get("modalities", []):
|
||||
providers = [PollinationsAI, EdgeTTS, gTTS]
|
||||
if kwargs.get("audio", {}).get("language") is None:
|
||||
providers = [PollinationsAI, OpenAIFM, Gemini]
|
||||
else:
|
||||
providers = [PollinationsAI, OpenAIFM, EdgeTTS, gTTS]
|
||||
elif has_audio:
|
||||
providers = [PollinationsAI, Microsoft_Phi_4_Multimodal, MarkItDown]
|
||||
elif has_image:
|
||||
|
|
@ -297,18 +316,18 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
model = None
|
||||
providers.append(provider)
|
||||
else:
|
||||
for provider in [
|
||||
OpenaiChat, Cloudflare, HarProvider, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox,
|
||||
OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena,
|
||||
HuggingSpace, LambdaChat, CopilotAccount, PollinationsAI, DeepInfraChat, HuggingFace, HuggingFaceMedia,
|
||||
]:
|
||||
extra_providers = []
|
||||
if isinstance(api_key, dict):
|
||||
for provider in api_key:
|
||||
if provider in __map__ and __map__[provider] not in MAIN_PROVIERS:
|
||||
extra_providers.append(__map__[provider])
|
||||
for provider in MAIN_PROVIERS + extra_providers:
|
||||
if provider.working:
|
||||
if not model or model in provider.get_models() or model in provider.model_aliases:
|
||||
providers.append(provider)
|
||||
if model in models.__models__:
|
||||
for provider in models.__models__[model][1]:
|
||||
providers.append(provider)
|
||||
|
||||
providers = [provider for provider in providers if provider.working and provider.get_parent() not in ignored]
|
||||
providers = list({provider.__name__: provider for provider in providers}.values())
|
||||
|
||||
|
|
@ -320,6 +339,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
messages,
|
||||
stream=stream,
|
||||
media=media,
|
||||
api_key=api_key,
|
||||
**kwargs
|
||||
):
|
||||
yield chunk
|
||||
|
|
|
|||
|
|
@ -126,12 +126,30 @@ class AbstractProvider(BaseProvider):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def get_create_function(cls) -> callable:
|
||||
return cls.create_completion
|
||||
def create_function(cls, *args, **kwargs) -> CreateResult:
|
||||
"""
|
||||
Creates a completion using the synchronous method.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
CreateResult: The result of the completion creation.
|
||||
"""
|
||||
return cls.create_completion(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_async_create_function(cls) -> callable:
|
||||
return cls.create_async
|
||||
def async_create_function(cls, *args, **kwargs) -> AsyncResult:
|
||||
"""
|
||||
Creates a completion using the synchronous method.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
CreateResult: The result of the completion creation.
|
||||
"""
|
||||
return cls.create_async(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_parameters(cls, as_json: bool = False) -> dict[str, Parameter]:
|
||||
|
|
@ -264,14 +282,6 @@ class AsyncProvider(AbstractProvider):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def get_create_function(cls) -> callable:
|
||||
return cls.create_completion
|
||||
|
||||
@classmethod
|
||||
def get_async_create_function(cls) -> callable:
|
||||
return cls.create_async
|
||||
|
||||
class AsyncGeneratorProvider(AbstractProvider):
|
||||
"""
|
||||
Provides asynchronous generator functionality for streaming results.
|
||||
|
|
@ -331,12 +341,17 @@ class AsyncGeneratorProvider(AbstractProvider):
|
|||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def get_create_function(cls) -> callable:
|
||||
return cls.create_completion
|
||||
def async_create_function(cls, *args, **kwargs) -> AsyncResult:
|
||||
"""
|
||||
Creates a completion using the synchronous method.
|
||||
|
||||
@classmethod
|
||||
def get_async_create_function(cls) -> callable:
|
||||
return cls.create_async_generator
|
||||
Args:
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
CreateResult: The result of the completion creation.
|
||||
"""
|
||||
return cls.create_async_generator(*args, **kwargs)
|
||||
|
||||
class ProviderModelMixin:
|
||||
default_model: str = None
|
||||
|
|
@ -417,14 +432,6 @@ class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin):
|
|||
return to_sync_generator(auth_result)
|
||||
return asyncio.run(auth_result)
|
||||
|
||||
@classmethod
|
||||
def get_create_function(cls) -> callable:
|
||||
return cls.create_completion
|
||||
|
||||
@classmethod
|
||||
def get_async_create_function(cls) -> callable:
|
||||
return cls.create_async_generator
|
||||
|
||||
@classmethod
|
||||
def write_cache_file(cls, cache_file: Path, auth_result: AuthResult = None):
|
||||
if auth_result is not None:
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@ class IterListProvider(BaseRetryProvider):
|
|||
self.shuffle = shuffle
|
||||
self.working = True
|
||||
self.last_provider: Type[BaseProvider] = None
|
||||
self.add_api_key = False
|
||||
|
||||
def create_completion(
|
||||
self,
|
||||
|
|
@ -35,6 +34,7 @@ class IterListProvider(BaseRetryProvider):
|
|||
stream: bool = False,
|
||||
ignore_stream: bool = False,
|
||||
ignored: list[str] = [],
|
||||
api_key: str = None,
|
||||
**kwargs,
|
||||
) -> CreateResult:
|
||||
"""
|
||||
|
|
@ -55,8 +55,11 @@ class IterListProvider(BaseRetryProvider):
|
|||
self.last_provider = provider
|
||||
debug.log(f"Using {provider.__name__} provider")
|
||||
yield ProviderInfo(**provider.get_dict(), model=model if model else getattr(provider, "default_model"))
|
||||
extra_body = kwargs.copy()
|
||||
if isinstance(api_key, dict):
|
||||
extra_body["api_key"] = api_key.get(provider.get_parent())
|
||||
try:
|
||||
response = provider.get_create_function()(model, messages, stream=stream, **kwargs)
|
||||
response = provider.create_function(model, messages, stream=stream, **extra_body)
|
||||
for chunk in response:
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
|
@ -66,7 +69,7 @@ class IterListProvider(BaseRetryProvider):
|
|||
return
|
||||
except Exception as e:
|
||||
exceptions[provider.__name__] = e
|
||||
debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
|
||||
debug.error(f"{provider.__name__}:", e)
|
||||
if started:
|
||||
raise e
|
||||
yield e
|
||||
|
|
@ -92,12 +95,12 @@ class IterListProvider(BaseRetryProvider):
|
|||
debug.log(f"Using {provider.__name__} provider" + (f" and {model} model" if model else ""))
|
||||
yield ProviderInfo(**provider.get_dict(), model=model if model else getattr(provider, "default_model"))
|
||||
extra_body = kwargs.copy()
|
||||
if self.add_api_key or provider.__name__ in ["HuggingFace", "HuggingFaceMedia"]:
|
||||
extra_body["api_key"] = api_key
|
||||
if isinstance(api_key, dict):
|
||||
extra_body["api_key"] = api_key.get(provider.get_parent())
|
||||
if conversation is not None and hasattr(conversation, provider.__name__):
|
||||
extra_body["conversation"] = JsonConversation(**getattr(conversation, provider.__name__))
|
||||
try:
|
||||
response = provider.get_async_create_function()(model, messages, stream=stream, **extra_body)
|
||||
response = provider.async_create_function(model, messages, stream=stream, **extra_body)
|
||||
if hasattr(response, "__aiter__"):
|
||||
async for chunk in response:
|
||||
if isinstance(chunk, JsonConversation):
|
||||
|
|
@ -118,18 +121,15 @@ class IterListProvider(BaseRetryProvider):
|
|||
return
|
||||
except Exception as e:
|
||||
exceptions[provider.__name__] = e
|
||||
debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
|
||||
debug.error(f"{provider.__name__}:", e)
|
||||
if started:
|
||||
raise e
|
||||
yield e
|
||||
|
||||
raise_exceptions(exceptions)
|
||||
|
||||
def get_create_function(self) -> callable:
|
||||
return self.create_completion
|
||||
|
||||
def get_async_create_function(self) -> callable:
|
||||
return self.create_async_generator
|
||||
create_function = create_completion
|
||||
async_create_function = create_async_generator
|
||||
|
||||
def get_providers(self, stream: bool, ignored: list[str]) -> list[ProviderType]:
|
||||
providers = [p for p in self.providers if (p.supports_stream or not stream) and p.__name__ not in ignored]
|
||||
|
|
@ -156,7 +156,6 @@ class RetryProvider(IterListProvider):
|
|||
super().__init__(providers, shuffle)
|
||||
self.single_provider_retry = single_provider_retry
|
||||
self.max_retries = max_retries
|
||||
self.add_api_key = True
|
||||
|
||||
def create_completion(
|
||||
self,
|
||||
|
|
@ -185,7 +184,7 @@ class RetryProvider(IterListProvider):
|
|||
try:
|
||||
if debug.logging:
|
||||
print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
|
||||
response = provider.get_create_function()(model, messages, stream=stream, **kwargs)
|
||||
response = provider.create_function(model, messages, stream=stream, **kwargs)
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
if is_content(chunk):
|
||||
|
|
@ -218,7 +217,7 @@ class RetryProvider(IterListProvider):
|
|||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
debug.log(f"Using {provider.__name__} provider (attempt {attempt + 1})")
|
||||
response = provider.get_async_create_function()(model, messages, stream=stream, **kwargs)
|
||||
response = provider.async_create_function(model, messages, stream=stream, **kwargs)
|
||||
if hasattr(response, "__aiter__"):
|
||||
async for chunk in response:
|
||||
yield chunk
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
|
|||
finish = None
|
||||
chunks = []
|
||||
has_usage = False
|
||||
async for chunk in provider.get_async_create_function()(
|
||||
async for chunk in provider.async_create_function(
|
||||
model,
|
||||
messages,
|
||||
stream=stream,
|
||||
|
|
|
|||
|
|
@ -25,26 +25,8 @@ class BaseProvider(ABC):
|
|||
supports_message_history: bool = False
|
||||
supports_system_message: bool = False
|
||||
params: str
|
||||
|
||||
@abstractmethod
|
||||
def get_create_function() -> callable:
|
||||
"""
|
||||
Get the create function for the provider.
|
||||
|
||||
Returns:
|
||||
callable: The create function.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_async_create_function() -> callable:
|
||||
"""
|
||||
Get the async create function for the provider.
|
||||
|
||||
Returns:
|
||||
callable: The create function.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
create_function: callable
|
||||
async_create_function: callable
|
||||
|
||||
@classmethod
|
||||
def get_dict(cls) -> Dict[str, str]:
|
||||
|
|
|
|||
|
|
@ -246,8 +246,7 @@ async def async_iter_run_tools(
|
|||
kwargs.update(extra_kwargs)
|
||||
|
||||
# Generate response
|
||||
create_function = provider.get_async_create_function()
|
||||
response = to_async_iterator(create_function(model=model, messages=messages, **kwargs))
|
||||
response = to_async_iterator(provider.async_create_function(model=model, messages=messages, **kwargs))
|
||||
|
||||
async for chunk in response:
|
||||
yield chunk
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue