feat: refactor provider create functions to class attributes and update calls
- Added `create_function` and `async_create_function` class attributes with default implementations in `base_provider.py` for `AbstractProvider`, `AsyncProvider`, and `AsyncGeneratorProvider`
- Replaced the `get_create_function` and `get_async_create_function` accessors with these class attributes
- Replaced calls to `provider.get_create_function()` and `provider.get_async_create_function()` with direct attribute access via `provider.create_function` and `provider.async_create_function` across `g4f/__init__.py`, `g4f/client/__init__.py`, `g4f/providers/retry_provider.py`, and `g4f/tools/run_tools.py`
- Removed the now-redundant `get_create_function` and `get_async_create_function` methods from `providers/base_provider.py` and `providers/types.py`
- Ensured all provider response calls use the class attributes to create completions synchronously or asynchronously as needed
parent 5734a06193
commit 2befef988b
15 changed files with 142 additions and 111 deletions
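In short, call sites move from fetching a callable through an accessor to calling a class attribute directly. A minimal before/after sketch (the prompt is a placeholder and Together is just one `AbstractProvider` subclass; any provider works the same way):

    from g4f.Provider import Together

    messages = [{"role": "user", "content": "Hello"}]

    # Before this commit: the callable had to be fetched first.
    # create = Together.get_create_function()
    # result = create(Together.default_model, messages, stream=False)

    # After this commit: the class attribute is invoked directly.
    result = Together.create_function(Together.default_model, messages, stream=False)
    print("".join(result))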
@@ -10,20 +10,20 @@ from ..requests import StreamSession, raise_for_status
 from ..errors import ModelNotFoundError
 from .. import debug


 class Together(OpenaiTemplate):
     label = "Together"
     url = "https://together.xyz"
+    login_url = "https://api.together.ai/"
     api_base = "https://api.together.xyz/v1"
     activation_endpoint = "https://www.codegeneration.ai/activate-v2"
     models_endpoint = "https://api.together.xyz/v1/models"

     working = True
     needs_auth = False
     supports_stream = True
     supports_system_message = True
     supports_message_history = True

     default_model = 'meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8'
     default_vision_model = default_model
     default_image_model = 'black-forest-labs/FLUX.1.1-pro'
@@ -43,7 +43,7 @@ class Together(OpenaiTemplate):
     model_configs = {} # Store model configurations including stop tokens
     _models_cached = False
     _api_key_cache = None

     model_aliases = {
         ### Models Chat/Language ###
         # meta-llama
@@ -682,7 +682,12 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
             page.add_handler(nodriver.cdp.network.RequestWillBeSent, on_request)
             page = await browser.get(cls.url)
             user_agent = await page.evaluate("window.navigator.userAgent", return_by_value=True)
-            while not await page.evaluate("document.getElementById('prompt-textarea')?.id"):
+            textarea = None
+            while not textarea:
+                try:
+                    textarea = await page.evaluate("document.getElementById('prompt-textarea')?.id")
+                except:
+                    pass
                 await asyncio.sleep(1)
             while not await page.evaluate("document.querySelector('[data-testid=\"send-button\"]')?.type"):
                 await asyncio.sleep(1)
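The replaced one-liner would propagate any evaluation error raised while the page was still navigating; the new loop retries instead. The same poll-until-ready pattern as a standalone sketch (`poll` stands in for the `page.evaluate` probe):

    import asyncio

    async def wait_until_ready(poll, interval: float = 1.0):
        # Keep probing until the probe returns a truthy value,
        # swallowing transient errors raised mid-navigation.
        result = None
        while not result:
            try:
                result = await poll()
            except Exception:
                pass
            await asyncio.sleep(interval)
        return result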
@@ -15,6 +15,7 @@ from .. import debug

 class PuterJS(AsyncGeneratorProvider, ProviderModelMixin):
     label = "Puter.js"
+    parent = "Puter"
     url = "https://docs.puter.com/playground"
     login_url = "https://github.com/HeyPuter/puter-cli"
     api_endpoint = "https://api.puter.com/drivers/call"
@@ -6,7 +6,7 @@ from ..helper import filter_none, format_media_prompt
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
 from ...typing import Union, AsyncResult, Messages, MediaListType
 from ...requests import StreamSession, raise_for_status
-from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse
+from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse, ProviderInfo
 from ...tools.media import render_messages
 from ...errors import MissingAuthError, ResponseError
 from ... import debug
@@ -93,6 +93,9 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
                 async with session.post(f"{api_base.rstrip('/')}/images/generations", json=data, ssl=cls.ssl) as response:
                     data = await response.json()
                     cls.raise_error(data, response.status)
+                    model = data.get("model")
+                    if model:
+                        yield ProviderInfo(**cls.get_dict(), model=model)
                     await raise_for_status(response)
                     yield ImageResponse([image["url"] for image in data["data"]], prompt)
                     return
@@ -121,6 +124,9 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
                     data = await response.json()
                     cls.raise_error(data, response.status)
                     await raise_for_status(response)
+                    model = data.get("model")
+                    if model:
+                        yield ProviderInfo(**cls.get_dict(), model=model)
                     choice = data["choices"][0]
                     if "content" in choice["message"] and choice["message"]["content"]:
                         yield choice["message"]["content"].strip()
@@ -134,8 +140,13 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
                 elif content_type.startswith("text/event-stream"):
                     await raise_for_status(response)
                     first = True
+                    model_returned = False
                     async for data in response.sse():
                         cls.raise_error(data)
+                        model = data.get("model")
+                        if not model_returned and model:
+                            yield ProviderInfo(**cls.get_dict(), model=model)
+                            model_returned = True
                         choice = data["choices"][0]
                         if "content" in choice["delta"] and choice["delta"]["content"]:
                             delta = choice["delta"]["content"]
@@ -50,7 +50,7 @@ class ChatCompletion:
         if ignore_stream:
             kwargs["ignore_stream"] = True

-        result = provider.get_create_function()(model, messages, stream=stream, **kwargs)
+        result = provider.create_function(model, messages, stream=stream, **kwargs)

         return result if stream or ignore_stream else concat_chunks(result)
@@ -76,7 +76,7 @@ class ChatCompletion:
         if ignore_stream:
             kwargs["ignore_stream"] = True

-        result = provider.get_async_create_function()(model, messages, stream=stream, **kwargs)
+        result = provider.async_create_function(model, messages, stream=stream, **kwargs)

         if not stream and not ignore_stream:
             if hasattr(result, "__aiter__"):
@@ -490,7 +490,7 @@ class Api:
             config.provider = provider
         if config.provider is None:
             config.provider = AppConfig.media_provider
-        if credentials is not None and credentials.credentials != "secret":
+        if config.api_key is None and credentials is not None and credentials.credentials != "secret":
             config.api_key = credentials.credentials
         try:
             response = await self.client.images.generate(
@@ -16,7 +16,7 @@ class RequestConfig(BaseModel):
     top_p: Optional[float] = None
     max_tokens: Optional[int] = None
     stop: Union[list[str], str, None] = None
-    api_key: Optional[str] = None
+    api_key: Optional[Union[str, dict[str, str]]] = None
     api_base: str = None
     web_search: Optional[bool] = None
     proxy: Optional[str] = None
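Since `api_key` now also accepts a mapping, an API request body can carry one key per provider family. A hedged sketch of what such a payload might look like (field values are placeholders; resolution against `provider.get_parent()` happens further down the stack):

    payload = {
        "model": "default",
        "messages": [{"role": "user", "content": "Hi"}],
        "api_key": {
            "HuggingFace": "hf_...",  # placeholder key
            "Together": "...",        # placeholder key
        },
    }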
@@ -70,6 +70,8 @@ class ImageGenerationConfig(BaseModel):
     negative_prompt: Optional[str] = None
     resolution: Optional[str] = None
     audio: Optional[dict] = None
+    download_media: bool = True
+
     @model_validator(mode='before')
     def parse_size(cls, values):
@@ -375,7 +375,7 @@ class Completions:
             kwargs["ignore_stream"] = True

         response = iter_run_tools(
-            provider.get_create_function(),
+            provider.create_function,
             model=model,
             messages=messages,
             stream=stream,
@@ -462,7 +462,7 @@ class Images:
         if isinstance(provider_handler, IterListProvider):
             for provider in provider_handler.providers:
                 try:
-                    response = await self._generate_image_response(provider, provider.__name__, model, prompt, proxy=proxy, **kwargs)
+                    response = await self._generate_image_response(provider, provider.__name__, model, prompt, proxy=proxy, api_key=api_key, **kwargs)
                     if response is not None:
                         provider_name = provider.__name__
                         break
@@ -485,21 +485,25 @@ class Images:

     async def _generate_image_response(
         self,
-        provider_handler,
-        provider_name,
+        provider_handler: ProviderType,
+        provider_name: str,
         model: str,
         prompt: str,
         prompt_prefix: str = "Generate a image: ",
+        api_key: str = None,
         **kwargs
     ) -> MediaResponse:
         messages = [{"role": "user", "content": f"{prompt_prefix}{prompt}"}]
         items: list[MediaResponse] = []
+        if isinstance(api_key, dict):
+            api_key = api_key.get(provider_handler.get_parent())
         if hasattr(provider_handler, "create_async_generator"):
             async for item in provider_handler.create_async_generator(
                 model,
                 messages,
                 stream=True,
                 prompt=prompt,
+                api_key=api_key,
                 **kwargs
             ):
                 if isinstance(item, (MediaResponse, AudioResponse)):
@@ -510,6 +514,7 @@ class Images:
                 messages,
                 True,
                 prompt=prompt,
+                api_key=api_key,
                 **kwargs
             ):
                 if isinstance(item, (MediaResponse, AudioResponse)):
@@ -157,7 +157,7 @@ async def copy_media(
         if media_type not in ("application/octet-stream", "binary/octet-stream"):
             if media_type not in MEDIA_TYPE_MAP:
                 raise ValueError(f"Unsupported media type: {media_type}")
-            if not media_extension:
+            if target is None and not media_extension:
                 media_extension = f".{MEDIA_TYPE_MAP[media_type]}"
         target_path = f"{target_path}{media_extension}"
         with open(target_path, "wb") as f:
@@ -1,22 +1,32 @@
 from __future__ import annotations

 import re
-from typing import Dict, List, Set, Optional, Tuple, Any
-from ..typing import AsyncResult, Messages, MediaListType
+from ..typing import AsyncResult, Messages, MediaListType, Union
 from ..errors import ModelNotFoundError
 from ..image import is_data_an_audio
 from ..providers.retry_provider import IterListProvider
 from ..providers.types import ProviderType
 from ..Provider.needs_auth import OpenaiChat, CopilotAccount
 from ..Provider.hf_space import HuggingSpace
-from ..Provider import Cloudflare, Gemini, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI
+from ..Provider import __map__
+from ..Provider import Cloudflare, Gemini, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI, PuterJS
 from ..Provider import Microsoft_Phi_4_Multimodal, DeepInfraChat, Blackbox, OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena
-from ..Provider import EdgeTTS, gTTS, MarkItDown
+from ..Provider import EdgeTTS, gTTS, MarkItDown, OpenAIFM
 from ..Provider import HarProvider, HuggingFace, HuggingFaceMedia
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from .. import Provider
 from .. import models

+MAIN_PROVIERS = [
+    OpenaiChat, Cloudflare, HarProvider, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox,
+    OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena,
+    HuggingSpace, LambdaChat, CopilotAccount, PollinationsAI, DeepInfraChat, HuggingFace, HuggingFaceMedia
+]
+
+SPECIAL_PROVIDERS = [OpenaiChat, CopilotAccount, PollinationsAI, HuggingSpace, Cloudflare, PerplexityLabs, Gemini, Grok, LegacyLMArena, ARTA]
+
+SPECIAL_PROVIDERS2 = [HarProvider, LambdaChat, DeepInfraChat, HuggingFace, HuggingFaceMedia, PuterJS]
+
 LABELS = {
     "default": "Default",
     "openai": "OpenAI: ChatGPT",
@@ -31,6 +41,7 @@ LABELS = {
     "mistral": "Mistral",
     "PollinationsAI": "Pollinations AI",
     "perplexity": "Perplexity Labs",
+    "openrouter": "OpenRouter",
     "video": "Video Generation",
     "image": "Image Generation",
     "other": "Other Models",
@@ -45,16 +56,16 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
     def get_grouped_models(cls, ignored: list[str] = []) -> dict[str, list[str]]:
         unsorted_models = cls.get_models(ignored=ignored)
         groups = {key: [] for key in LABELS.keys()}

         # Always add default first
         groups["default"].append("default")

         for model in unsorted_models:
             if model == "default":
                 continue # Already added

             added = False

             # Check for PollinationsAI models (with prefix)
             if model.startswith("PollinationsAI:"):
                 groups["PollinationsAI"].append(model)
@@ -109,6 +120,10 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             elif model.startswith(("gpt-", "chatgpt-", "o1", "o1-", "o3-", "o4-")) or model in ("auto", "dall-e-3", "searchgpt"):
                 groups["openai"].append(model)
                 added = True
+            # Check for openrouter models
+            elif model.startswith(("openrouter:")):
+                groups["openrouter"].append(model)
+                added = True
             # Check for video models
             elif model in cls.video_models:
                 groups["video"].append(model)
@@ -117,7 +132,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             elif model in cls.image_models or "flux" in model.lower() or "stable-diffusion" in model.lower() or "sdxl" in model.lower() or "gpt-image" in model.lower():
                 groups["image"].append(model)
                 added = True

             # If not categorized, check for special cases then put in other
             if not added:
                 # CodeLlama is Meta's model
@@ -128,7 +143,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
                     groups["phi"].append(model)
                 else:
                     groups["other"].append(model)

         return [
             {"group": LABELS[group], "models": names} for group, names in groups.items()
         ]
@@ -157,9 +172,9 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             model: len(providers) for model, providers in model_with_providers.items() if len(providers) > 1
         }
         all_models = [cls.default_model] + list(model_with_providers.keys())

         # Process special providers
-        for provider in [OpenaiChat, CopilotAccount, PollinationsAI, HuggingSpace, Cloudflare, PerplexityLabs, Gemini, Grok, LegacyLMArena, ARTA]:
+        for provider in SPECIAL_PROVIDERS:
             provider: ProviderType = provider
             if not provider.working or provider.get_parent() in ignored:
                 continue
@@ -186,7 +201,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
                     cls.image_models.extend(arta_models)
             else:
                 all_models.extend(provider.get_models())

             # Update special model lists
             if hasattr(provider, 'image_models'):
                 cls.image_models.extend(provider.image_models)
@@ -194,7 +209,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
                 cls.vision_models.extend(provider.vision_models)
             if hasattr(provider, 'video_models'):
                 cls.video_models.extend(provider.video_models)

         # Clean model names function
         def clean_name(name: str) -> str:
             name = name.split("/")[-1].split(":")[0].lower()
@@ -212,24 +227,24 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             name = name.replace("llama3", "llama-3")
             name = name.replace("flux.1-", "flux-")
             return name

         # Process HAR providers
-        for provider in [HarProvider, LambdaChat, DeepInfraChat, HuggingFace, HuggingFaceMedia]:
+        for provider in SPECIAL_PROVIDERS2:
             if not provider.working or provider.get_parent() in ignored:
                 continue
             new_models = provider.get_models()
             if provider == HuggingFaceMedia:
                 new_models = provider.video_models

             # Add original models too, not just cleaned names
             all_models.extend(new_models)

-            model_map = {clean_name(model): model for model in new_models}
+            model_map = {model if model.startswith("openrouter:") else clean_name(model): model for model in new_models}
             if not provider.model_aliases:
                 provider.model_aliases = {}
             provider.model_aliases.update(model_map)
             all_models.extend(list(model_map.keys()))

             # Update special model lists with both original and cleaned names
             if hasattr(provider, 'image_models'):
                 cls.image_models.extend(provider.image_models)
@@ -240,18 +255,18 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             if hasattr(provider, 'video_models'):
                 cls.video_models.extend(provider.video_models)
                 cls.video_models.extend([clean_name(model) for model in provider.video_models])

         # Process audio providers
         for provider in [Microsoft_Phi_4_Multimodal, PollinationsAI]:
             if provider.working and provider.get_parent() not in ignored:
                 cls.audio_models.update(provider.audio_models)

         # Update model counts
         cls.models_count.update({model: all_models.count(model) for model in all_models if all_models.count(model) > cls.models_count.get(model, 0)})

         # Deduplicate and store
         cls.models_storage[ignored_key] = list(dict.fromkeys([model if model else cls.default_model for model in all_models]))

         return cls.models_storage[ignored_key]

     @classmethod
@@ -262,6 +277,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
         stream: bool = True,
         media: MediaListType = None,
         ignored: list[str] = [],
+        api_key: Union[str, dict[str, str]] = None,
         **kwargs
     ) -> AsyncResult:
         cls.get_models(ignored=ignored)
@@ -284,7 +300,10 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
         if "tools" in kwargs:
             providers = [PollinationsAI]
         elif "audio" in kwargs or "audio" in kwargs.get("modalities", []):
-            providers = [PollinationsAI, EdgeTTS, gTTS]
+            if kwargs.get("audio", {}).get("language") is None:
+                providers = [PollinationsAI, OpenAIFM, Gemini]
+            else:
+                providers = [PollinationsAI, OpenAIFM, EdgeTTS, gTTS]
         elif has_audio:
             providers = [PollinationsAI, Microsoft_Phi_4_Multimodal, MarkItDown]
         elif has_image:
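A condensed sketch of the new audio routing as a standalone function (the provider names refer to the classes imported at the top of this file):

    def pick_audio_providers(kwargs: dict) -> list:
        audio = kwargs.get("audio", {}) or {}
        if audio.get("language") is None:
            # No target language requested: route to these providers.
            return [PollinationsAI, OpenAIFM, Gemini]
        # A language was requested: keep EdgeTTS and gTTS as options.
        return [PollinationsAI, OpenAIFM, EdgeTTS, gTTS]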
@@ -297,29 +316,30 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
                     model = None
                     providers.append(provider)
         else:
-            for provider in [
-                OpenaiChat, Cloudflare, HarProvider, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox,
-                OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena,
-                HuggingSpace, LambdaChat, CopilotAccount, PollinationsAI, DeepInfraChat, HuggingFace, HuggingFaceMedia,
-            ]:
+            extra_providers = []
+            if isinstance(api_key, dict):
+                for provider in api_key:
+                    if provider in __map__ and __map__[provider] not in MAIN_PROVIERS:
+                        extra_providers.append(__map__[provider])
+            for provider in MAIN_PROVIERS + extra_providers:
                 if provider.working:
                     if not model or model in provider.get_models() or model in provider.model_aliases:
                         providers.append(provider)
             if model in models.__models__:
                 for provider in models.__models__[model][1]:
                     providers.append(provider)

         providers = [provider for provider in providers if provider.working and provider.get_parent() not in ignored]
         providers = list({provider.__name__: provider for provider in providers}.values())

         if len(providers) == 0:
             raise ModelNotFoundError(f"AnyProvider: Model {model} not found in any provider.")

         async for chunk in IterListProvider(providers).create_async_generator(
             model,
             messages,
             stream=stream,
             media=media,
+            api_key=api_key,
             **kwargs
         ):
             yield chunk
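A minimal sketch of the selection logic this hunk introduces, assuming a map from provider names to provider classes (as `__map__` provides in `g4f.Provider`):

    def select_providers(api_key, main_providers: list, provider_map: dict) -> list:
        # Keys of a dict-valued api_key name additional providers to try,
        # as long as they resolve in the map and are not already listed.
        extra = []
        if isinstance(api_key, dict):
            for name in api_key:
                if name in provider_map and provider_map[name] not in main_providers:
                    extra.append(provider_map[name])
        return main_providers + extra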
@@ -126,12 +126,30 @@ class AbstractProvider(BaseProvider):
         )

     @classmethod
-    def get_create_function(cls) -> callable:
-        return cls.create_completion
+    def create_function(cls, *args, **kwargs) -> CreateResult:
+        """
+        Creates a completion using the synchronous method.
+
+        Args:
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            CreateResult: The result of the completion creation.
+        """
+        return cls.create_completion(*args, **kwargs)

     @classmethod
-    def get_async_create_function(cls) -> callable:
-        return cls.create_async
+    def async_create_function(cls, *args, **kwargs) -> AsyncResult:
+        """
+        Creates a completion using the asynchronous method.
+
+        Args:
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            AsyncResult: The result of the completion creation.
+        """
+        return cls.create_async(*args, **kwargs)

     @classmethod
     def get_parameters(cls, as_json: bool = False) -> dict[str, Parameter]:
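Because the defaults are classmethods that delegate to `create_completion` / `create_async`, a subclass only has to implement the underlying method and every call site can use the attribute uniformly. A hedged sketch (`EchoProvider` is hypothetical):

    from g4f.providers.base_provider import AbstractProvider
    from g4f.typing import CreateResult, Messages

    class EchoProvider(AbstractProvider):
        working = True

        @classmethod
        def create_completion(cls, model: str, messages: Messages, stream: bool, **kwargs) -> CreateResult:
            # Echo the last user message back as a single chunk.
            yield messages[-1]["content"]

    # Call sites now use the class attribute directly:
    chunks = EchoProvider.create_function("any-model", [{"role": "user", "content": "ping"}], stream=False)
    print("".join(chunks))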
@@ -264,14 +282,6 @@ class AsyncProvider(AbstractProvider):
         """
         raise NotImplementedError()

-    @classmethod
-    def get_create_function(cls) -> callable:
-        return cls.create_completion
-
-    @classmethod
-    def get_async_create_function(cls) -> callable:
-        return cls.create_async
-
 class AsyncGeneratorProvider(AbstractProvider):
     """
     Provides asynchronous generator functionality for streaming results.
@@ -331,12 +341,17 @@ class AsyncGeneratorProvider(AbstractProvider):
         raise NotImplementedError()

     @classmethod
-    def get_create_function(cls) -> callable:
-        return cls.create_completion
-
-    @classmethod
-    def get_async_create_function(cls) -> callable:
-        return cls.create_async_generator
+    def async_create_function(cls, *args, **kwargs) -> AsyncResult:
+        """
+        Creates a completion using the asynchronous generator method.
+
+        Args:
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            AsyncResult: The result of the completion creation.
+        """
+        return cls.create_async_generator(*args, **kwargs)

 class ProviderModelMixin:
     default_model: str = None
@@ -417,14 +432,6 @@ class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin):
                 return to_sync_generator(auth_result)
         return asyncio.run(auth_result)

-    @classmethod
-    def get_create_function(cls) -> callable:
-        return cls.create_completion
-
-    @classmethod
-    def get_async_create_function(cls) -> callable:
-        return cls.create_async_generator
-
     @classmethod
     def write_cache_file(cls, cache_file: Path, auth_result: AuthResult = None):
         if auth_result is not None:
@@ -26,7 +26,6 @@ class IterListProvider(BaseRetryProvider):
         self.shuffle = shuffle
         self.working = True
         self.last_provider: Type[BaseProvider] = None
-        self.add_api_key = False

     def create_completion(
         self,
@@ -35,6 +34,7 @@ class IterListProvider(BaseRetryProvider):
         stream: bool = False,
         ignore_stream: bool = False,
         ignored: list[str] = [],
+        api_key: str = None,
         **kwargs,
     ) -> CreateResult:
         """
@@ -55,8 +55,11 @@ class IterListProvider(BaseRetryProvider):
             self.last_provider = provider
             debug.log(f"Using {provider.__name__} provider")
             yield ProviderInfo(**provider.get_dict(), model=model if model else getattr(provider, "default_model"))
+            extra_body = kwargs.copy()
+            if isinstance(api_key, dict):
+                extra_body["api_key"] = api_key.get(provider.get_parent())
             try:
-                response = provider.get_create_function()(model, messages, stream=stream, **kwargs)
+                response = provider.create_function(model, messages, stream=stream, **extra_body)
                 for chunk in response:
                     if chunk:
                         yield chunk
@@ -66,7 +69,7 @@ class IterListProvider(BaseRetryProvider):
                 return
             except Exception as e:
                 exceptions[provider.__name__] = e
-                debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
+                debug.error(f"{provider.__name__}:", e)
                 if started:
                     raise e
                 yield e
@@ -92,12 +95,12 @@ class IterListProvider(BaseRetryProvider):
             debug.log(f"Using {provider.__name__} provider" + (f" and {model} model" if model else ""))
             yield ProviderInfo(**provider.get_dict(), model=model if model else getattr(provider, "default_model"))
             extra_body = kwargs.copy()
-            if self.add_api_key or provider.__name__ in ["HuggingFace", "HuggingFaceMedia"]:
-                extra_body["api_key"] = api_key
+            if isinstance(api_key, dict):
+                extra_body["api_key"] = api_key.get(provider.get_parent())
             if conversation is not None and hasattr(conversation, provider.__name__):
                 extra_body["conversation"] = JsonConversation(**getattr(conversation, provider.__name__))
             try:
-                response = provider.get_async_create_function()(model, messages, stream=stream, **extra_body)
+                response = provider.async_create_function(model, messages, stream=stream, **extra_body)
                 if hasattr(response, "__aiter__"):
                     async for chunk in response:
                         if isinstance(chunk, JsonConversation):
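The dict form of `api_key` is resolved against `provider.get_parent()` before each attempt, so every provider in the rotation receives only its own key. A minimal standalone sketch of that resolution:

    def resolve_api_key(api_key, provider):
        # A dict maps a parent-provider name to its key; anything else
        # (a plain string or None) is passed through unchanged.
        if isinstance(api_key, dict):
            return api_key.get(provider.get_parent())
        return api_key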
@@ -118,18 +121,15 @@ class IterListProvider(BaseRetryProvider):
                 return
             except Exception as e:
                 exceptions[provider.__name__] = e
-                debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
+                debug.error(f"{provider.__name__}:", e)
                 if started:
                     raise e
                 yield e

         raise_exceptions(exceptions)

-    def get_create_function(self) -> callable:
-        return self.create_completion
-
-    def get_async_create_function(self) -> callable:
-        return self.create_async_generator
+    create_function = create_completion
+    async_create_function = create_async_generator

     def get_providers(self, stream: bool, ignored: list[str]) -> list[ProviderType]:
         providers = [p for p in self.providers if (p.supports_stream or not stream) and p.__name__ not in ignored]
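Here the accessor methods collapse into plain class-body aliases: `create_completion` and `create_async_generator` are already bound names inside the class body, so the assignment exposes the same function objects under the attribute names every call site now expects. A toy illustration of the pattern:

    class Worker:
        def run(self, job: str) -> str:
            return f"ran {job}"

        # Alias defined in the class body: `execute` and `run` are the
        # same function object, so instances expose both names.
        execute = run

    assert Worker().execute("cleanup") == "ran cleanup"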
@@ -156,7 +156,6 @@ class RetryProvider(IterListProvider):
         super().__init__(providers, shuffle)
         self.single_provider_retry = single_provider_retry
         self.max_retries = max_retries
-        self.add_api_key = True

     def create_completion(
         self,
@@ -185,7 +184,7 @@ class RetryProvider(IterListProvider):
                 try:
                     if debug.logging:
                         print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
-                    response = provider.get_create_function()(model, messages, stream=stream, **kwargs)
+                    response = provider.create_function(model, messages, stream=stream, **kwargs)
                     for chunk in response:
                         yield chunk
                         if is_content(chunk):
|
@ -218,7 +217,7 @@ class RetryProvider(IterListProvider):
|
||||||
for attempt in range(self.max_retries):
|
for attempt in range(self.max_retries):
|
||||||
try:
|
try:
|
||||||
debug.log(f"Using {provider.__name__} provider (attempt {attempt + 1})")
|
debug.log(f"Using {provider.__name__} provider (attempt {attempt + 1})")
|
||||||
response = provider.get_async_create_function()(model, messages, stream=stream, **kwargs)
|
response = provider.async_create_function(model, messages, stream=stream, **kwargs)
|
||||||
if hasattr(response, "__aiter__"):
|
if hasattr(response, "__aiter__"):
|
||||||
async for chunk in response:
|
async for chunk in response:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
|
||||||
|
|
@@ -45,7 +45,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
         finish = None
         chunks = []
         has_usage = False
-        async for chunk in provider.get_async_create_function()(
+        async for chunk in provider.async_create_function(
             model,
             messages,
             stream=stream,
@@ -25,26 +25,8 @@ class BaseProvider(ABC):
     supports_message_history: bool = False
     supports_system_message: bool = False
     params: str
-
-    @abstractmethod
-    def get_create_function() -> callable:
-        """
-        Get the create function for the provider.
-
-        Returns:
-            callable: The create function.
-        """
-        raise NotImplementedError()
-
-    @abstractmethod
-    def get_async_create_function() -> callable:
-        """
-        Get the async create function for the provider.
-
-        Returns:
-            callable: The create function.
-        """
-        raise NotImplementedError()
+    create_function: callable
+    async_create_function: callable

     @classmethod
     def get_dict(cls) -> Dict[str, str]:
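In `types.py` the two abstract getters become bare attribute annotations, so `BaseProvider` now only promises that the attributes exist; provider classes satisfy the contract with classmethods while `IterListProvider` instances satisfy it with bound-method aliases. A small sketch of a caller relying on the annotation alone:

    from g4f.providers.types import ProviderType

    def run_sync(provider: ProviderType, model: str, messages: list, **kwargs):
        # Works for classes with the classmethod default as well as for
        # retry-provider objects whose attribute aliases a bound method.
        return provider.create_function(model, messages, stream=False, **kwargs)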
@@ -246,8 +246,7 @@ async def async_iter_run_tools(
     kwargs.update(extra_kwargs)

     # Generate response
-    create_function = provider.get_async_create_function()
-    response = to_async_iterator(create_function(model=model, messages=messages, **kwargs))
+    response = to_async_iterator(provider.async_create_function(model=model, messages=messages, **kwargs))

     async for chunk in response:
         yield chunk