mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
fix: enhance retry logic and parameter handling in commit and provider code
- Added `--max-retries` argument to `parse_arguments()` in `commit.py`, defaulting to `MAX_RETRIES`
- Updated `generate_commit_message()` to accept a `max_retries` parameter and iterate accordingly
- Added a check in `generate_commit_message()` to raise immediately if `max_retries` is 1
- Passed `args.max_retries` when calling `generate_commit_message()` in `main()`
- In `g4f/Provider/har/__init__.py`, imported `ResponseError` and added a check that raises `ResponseError` on a network-error marker in streamed content
- In `g4f/Provider/hf_space/Qwen_Qwen_3.py`, changed the default model string and updated system prompt handling to use `get_system_prompt()`
- In `g4f/Provider/needs_auth/LMArenaBeta.py`, modified the callback to wait for the auth cookie and the turnstile response
- In `g4f/Provider/needs_auth/PuterJS.py`, adjusted `get_models()` to filter out certain models
- In `g4f/gui/server/api.py`, adjusted `get_model_data()` to handle models starting with "openrouter:"
- In `g4f/providers/any_provider.py`, imported and used `ResponseError`; added logic to process `model_aliases` with updated model name resolution
- Refined the model name cleaning logic to handle additional patterns, replacing several regex patterns to better match version strings
- Updated the provider lists `PROVIERS_LIST_1`, `PROVIERS_LIST_2`, and `PROVIERS_LIST_3` and their usage to include new providers and adjust filtering
- In `g4f/version.py`, added a `get_git_version()` function that retrieves the version with the `git describe` command instead of relying only on `get_github_version()`, improving robustness
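A minimal usage sketch of the new parameter, mirroring the call in `main()` shown in the diff below (`diff_text` stands in for the real staged diff):

    # One attempt only: errors are raised immediately instead of retried
    commit_message = generate_commit_message(diff_text, model=DEFAULT_MODEL, max_retries=1)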
This commit is contained in:
parent 67231e8c40
commit c12227a1cd
11 changed files with 125 additions and 122 deletions
commit.py
@@ -128,6 +128,8 @@ def parse_arguments():
                         help="List available AI models and exit")
     parser.add_argument("--repo", type=str, default=".",
                         help="Git repository path (default: current directory)")
+    parser.add_argument("--max-retries", type=int, default=MAX_RETRIES,
+                        help="Maximum number of retries for AI generation (default: 3)")

     return parser.parse_args()
@@ -288,7 +290,7 @@ def show_spinner(duration: int = None):
         stop_spinner.set()
         raise

-def generate_commit_message(diff_text: str, model: str = DEFAULT_MODEL) -> Optional[str]:
+def generate_commit_message(diff_text: str, model: str = DEFAULT_MODEL, max_retries: int = MAX_RETRIES) -> Optional[str]:
     """Generate a commit message based on the git diff"""
     if not diff_text or diff_text.strip() == "":
         return "No changes staged for commit"
@@ -324,7 +326,7 @@ def generate_commit_message(diff_text: str, model: str = DEFAULT_MODEL) -> Optional[str]:
 IMPORTANT: Be 100% factual. Only mention code that was actually changed. Never invent or assume changes not shown in the diff. If unsure about a change's purpose, describe what changed rather than why. Output nothing except for the commit message, and don't surround it in quotes.
 """

-    for attempt in range(MAX_RETRIES):
+    for attempt in range(max_retries):
         try:
             # Start spinner
             spinner = show_spinner()
@@ -352,7 +354,8 @@ def generate_commit_message(diff_text: str, model: str = DEFAULT_MODEL) -> Optional[str]:
             spinner.set()
             sys.stdout.write("\r" + " " * 50 + "\r")
             sys.stdout.flush()

+            if max_retries == 1:
+                raise e  # If no retries, raise immediately
             print(f"Error generating commit message (attempt {attempt+1}/{MAX_RETRIES}): {e}")
             if attempt < MAX_RETRIES - 1:
                 print(f"Retrying in {RETRY_DELAY} seconds...")
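A condensed sketch of the retry pattern this hunk introduces. `MAX_RETRIES` and `RETRY_DELAY` are the constants from commit.py; `ai_call` is a hypothetical stand-in for the generation request, and the sketch uses `max_retries` in the log strings where the diff keeps `MAX_RETRIES`:

    import time

    def with_retries(ai_call, max_retries: int = MAX_RETRIES):
        for attempt in range(max_retries):
            try:
                return ai_call()
            except Exception as e:
                if max_retries == 1:
                    raise  # no retries requested: surface the first error immediately
                print(f"Error (attempt {attempt+1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    print(f"Retrying in {RETRY_DELAY} seconds...")
                    time.sleep(RETRY_DELAY)
        return None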
@@ -464,7 +467,7 @@ def main():
         sys.exit(0)

     print(f"Using model: {args.model}")
-    commit_message = generate_commit_message(diff, args.model)
+    commit_message = generate_commit_message(diff, args.model, args.max_retries)

     if not commit_message:
         print("Failed to generate commit message after multiple attempts.")
g4f/Provider/har/__init__.py
@@ -10,6 +10,7 @@ from ...requests import DEFAULT_HEADERS, StreamSession, StreamResponse, FormData
 from ...providers.response import JsonConversation
 from ...tools.media import merge_media
 from ...image import to_bytes, is_accepted_format
+from ...errors import ResponseError
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..helper import get_last_user_message
 from ..openai.har_file import get_headers
@@ -139,6 +140,8 @@ class HarProvider(AsyncGeneratorProvider, ProviderModelMixin):
                     if not line.startswith(b"data: "):
                         continue
                     for content in find_str(json.loads(line[6:]), 3):
+                        if "**NETWORK ERROR DUE TO HIGH TRAFFIC." in content:
+                            raise ResponseError(content)
                         if content == '<span class="cursor"></span> ' or content == 'update':
                             continue
                         if content.endswith("▌"):
g4f/Provider/hf_space/Qwen_Qwen_3.py
@@ -9,7 +9,7 @@ from ...providers.response import Reasoning, JsonConversation
 from ...requests.raise_for_status import raise_for_status
 from ...errors import ModelNotFoundError
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ..helper import get_last_user_message
+from ..helper import get_last_user_message, get_system_prompt
 from ... import debug

@@ -22,19 +22,19 @@ class Qwen_Qwen_3(AsyncGeneratorProvider, ProviderModelMixin):
     supports_stream = True
     supports_system_message = True

-    default_model = "qwen3-235b-a22b"
+    default_model = "qwen-3-235b"
     models = {
         default_model,
-        "qwen3-32b",
-        "qwen3-30b-a3b",
-        "qwen3-14b",
-        "qwen3-8b",
-        "qwen3-4b",
-        "qwen3-1.7b",
-        "qwen3-0.6b",
+        "qwen-3-32b",
+        "qwen-3-30b-a3b",
+        "qwen-3-14b",
+        "qwen-3-8b",
+        "qwen-3-4b",
+        "qwen-3-1.7b",
+        "qwen-3-0.6b",
     }
     model_aliases = {
-        "qwen-3-235b": default_model,
+        "qwen-3-235b": "qwen3-235b-a22b",
         "qwen-3-30b": "qwen3-30b-a3b",
         "qwen-3-32b": "qwen3-32b",
         "qwen-3-14b": "qwen3-14b",
@@ -76,12 +76,12 @@ class Qwen_Qwen_3(AsyncGeneratorProvider, ProviderModelMixin):
             'Cache-Control': 'no-cache',
         }

-        sys_prompt = "\n".join([message['content'] for message in messages if message['role'] == 'system'])
-        sys_prompt = sys_prompt if sys_prompt else "You are a helpful and harmless assistant."
+        system_prompt = get_system_prompt(messages)
+        system_prompt = system_prompt if system_prompt else "You are a helpful and harmless assistant."

         payload_join = {"data": [
             get_last_user_message(messages),
-            {"thinking_budget": thinking_budget, "model": cls.get_model(model), "sys_prompt": sys_prompt}, None, None],
+            {"thinking_budget": thinking_budget, "model": cls.get_model(model), "sys_prompt": system_prompt}, None, None],
             "event_data": None, "fn_index": 13, "trigger_id": 31, "session_hash": conversation.session_hash
         }
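The inline system-prompt join above is replaced by the `get_system_prompt` helper; a sketch of the equivalent behavior, inferred from the removed lines (the helper's actual implementation may differ):

    def get_system_prompt(messages) -> str:
        # Same effect as the removed inline expression
        return "\n".join(message["content"] for message in messages if message["role"] == "system")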
g4f/Provider/needs_auth/LMArenaBeta.py
@@ -78,7 +78,7 @@ class LMArenaBeta(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
     label = "LMArena Beta"
     url = "https://beta.lmarena.ai"
     api_endpoint = "https://beta.lmarena.ai/api/stream/create-evaluation"
-    working = True
+    working = has_nodriver

     default_model = list(text_models.keys())[0]
     models = list(text_models) + list(image_models)
@@ -102,7 +102,8 @@ class LMArenaBeta(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
             async def callback(page):
                 while not await page.evaluate('document.cookie.indexOf("arena-auth-prod-v1") >= 0'):
                     await asyncio.sleep(1)
-                await asyncio.sleep(5)
+                while await page.evaluate('document.querySelector(\'[name="cf-turnstile-response"]\').length > 0'):
+                    await asyncio.sleep(1)
             args = await get_args_from_nodriver(cls.url, proxy=proxy, callback=callback)
         except (RuntimeError, FileNotFoundError) as e:
             debug.log(f"Nodriver is not available: {type(e).__name__}: {e}")
g4f/Provider/needs_auth/PuterJS.py
@@ -257,14 +257,15 @@ class PuterJS(AsyncGeneratorProvider, ProviderModelMixin):
         "lfm-7b": "openrouter:liquid/lfm-7b",
         "lfm-3b": "openrouter:liquid/lfm-3b",
         "lfm-40b": "openrouter:liquid/lfm-40b",
     }

     @classmethod
-    def get_models(cls) -> list[str]:
+    def get_models(cls, api_key: str = None) -> list[str]:
         if not cls.models:
             try:
                 url = "https://api.puter.com/puterai/chat/models/"
                 cls.models = requests.get(url).json().get("models", [])
+                cls.models = [model for model in cls.models if model not in ["abuse", "costly", "fake"]]
             except Exception as e:
                 debug.log(f"PuterJS: Failed to fetch models from API: {e}")
                 cls.models = list(cls.model_aliases.keys())
g4f/Provider/needs_auth/HuggingChat.py
@@ -51,14 +51,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
     def get_models(cls):
         if not cls.models:
             try:
-                text = requests.get(cls.url).text
-                text = re.search(r'models:(\[.+?\]),oldModels:', text).group(1)
-                text = re.sub(r',parameters:{[^}]+?}', '', text)
-                text = text.replace('void 0', 'null')
-                def add_quotation_mark(match):
-                    return f'{match.group(1)}"{match.group(2)}":'
-                text = re.sub(r'([{,])([A-Za-z0-9_]+?):', add_quotation_mark, text)
-                models = json.loads(text)
+                models = requests.get(f"{cls.url}/api/v2/models").json().get("json")
                 cls.text_models = [model["id"] for model in models]
                 cls.models = cls.text_models + cls.image_models
                 cls.vision_models = [model["id"] for model in models if model["multimodal"]]
g4f/gui/server/api.py
@@ -43,7 +43,7 @@ class Api:
     def get_model_data(provider: ProviderModelMixin, model: str):
         return {
             "model": model,
-            "label": model.split(":")[-1] if provider.__name__ == "AnyProvider" else model,
+            "label": model.split(":")[-1] if provider.__name__ == "AnyProvider" and not model.startswith("openrouter:") else model,
             "default": model == provider.default_model,
             "vision": model in provider.vision_models,
             "audio": False if provider.audio_models is None else model in provider.audio_models,
g4f/providers/any_provider.py
@@ -10,22 +10,29 @@ from ..Provider.needs_auth import OpenaiChat, CopilotAccount
 from ..Provider.hf_space import HuggingSpace
 from ..Provider import __map__
 from ..Provider import Cloudflare, Gemini, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI, PuterJS
-from ..Provider import Microsoft_Phi_4_Multimodal, DeepInfraChat, Blackbox, OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena
+from ..Provider import Microsoft_Phi_4_Multimodal, DeepInfraChat, Blackbox, OIVSCodeSer2, OIVSCodeSer0501, TeachAnything
+from ..Provider import Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena, LMArenaBeta
+from ..Provider import EdgeTTS, gTTS, MarkItDown, OpenAIFM
 from ..Provider import HarProvider, HuggingFace, HuggingFaceMedia
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from .. import Provider
 from .. import models

-MAIN_PROVIERS = [
-    OpenaiChat, Cloudflare, HarProvider, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox,
-    OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena,
-    HuggingSpace, LambdaChat, CopilotAccount, PollinationsAI, DeepInfraChat, HuggingFace, HuggingFaceMedia
+PROVIERS_LIST_1 = [
+    OpenaiChat, PollinationsAI, Cloudflare, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox, OpenAIFM,
+    OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs,
+    HarProvider, LegacyLMArena, LMArenaBeta, LambdaChat, CopilotAccount, DeepInfraChat,
+    HuggingSpace, HuggingFace, HuggingFaceMedia, PuterJS, Together
 ]

-SPECIAL_PROVIDERS = [OpenaiChat, CopilotAccount, PollinationsAI, HuggingSpace, Cloudflare, PerplexityLabs, Gemini, Grok, LegacyLMArena, ARTA]
+PROVIERS_LIST_2 = [
+    OpenaiChat, CopilotAccount, PollinationsAI, PerplexityLabs, Gemini, Grok, ARTA
+]

-SPECIAL_PROVIDERS2 = [HarProvider, LambdaChat, DeepInfraChat, HuggingFace, HuggingFaceMedia, PuterJS]
+PROVIERS_LIST_3 = [
+    HarProvider, LambdaChat, DeepInfraChat, HuggingFace, HuggingFaceMedia, LegacyLMArena, LMArenaBeta,
+    PuterJS, Together, Cloudflare, HuggingSpace
+]

 LABELS = {
     "default": "Default",
@@ -33,15 +40,21 @@ LABELS = {
     "llama": "Meta: LLaMA",
     "deepseek": "DeepSeek",
     "qwen": "Alibaba: Qwen",
-    "google": "Google: Gemini / Gemma / Bard",
+    "google": "Google: Gemini / Gemma",
     "grok": "xAI: Grok",
     "claude": "Anthropic: Claude",
     "command": "Cohere: Command",
     "phi": "Microsoft: Phi / WizardLM",
     "mistral": "Mistral",
     "PollinationsAI": "Pollinations AI",
+    "ARTA": "ARTA",
+    "voices": "Voices",
     "perplexity": "Perplexity Labs",
+    "openrouter": "OpenRouter",
+    "glm": "GLM",
+    "tulu": "Tulu",
+    "reka": "Reka",
+    "hermes": "Hermes",
     "video": "Video Generation",
     "image": "Image Generation",
     "other": "Other Models",
@@ -65,31 +78,35 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
                 continue # Already added

             added = False

-            # Check for PollinationsAI models (with prefix)
-            if model.startswith("PollinationsAI:"):
-                groups["PollinationsAI"].append(model)
+            # Check for models with prefix
+            start = model.split(":")[0]
+            if start in ("PollinationsAI", "ARTA", "openrouter"):
+                submodel = model.split(":", maxsplit=1)[1]
+                if submodel in OpenAIFM.voices or submodel in PollinationsAI.audio_models[PollinationsAI.default_audio_model]:
+                    groups["voices"].append(submodel)
+                else:
+                    groups[start].append(model)
                 added = True
             # Check for Mistral company models specifically
             elif model.startswith("mistral") and not any(x in model for x in ["dolphin", "nous", "openhermes"]):
                 groups["mistral"].append(model)
                 added = True
-            elif model.startswith(("mistralai/", "mixtral-", "pixtral-", "ministral-", "codestral-")):
+            elif model.startswith(("pixtral-", "ministral-", "codestral")) or "mistral" in model or "mixtral" in model:
                 groups["mistral"].append(model)
                 added = True
             # Check for Qwen models
-            elif model.startswith(("qwen", "Qwen/", "qwq", "qvq")):
+            elif model.startswith(("qwen", "Qwen", "qwq", "qvq")):
                 groups["qwen"].append(model)
                 added = True
             # Check for Microsoft Phi models
-            elif model.startswith(("phi-", "microsoft/")):
+            elif model.startswith(("phi-", "microsoft/")) or "wizardlm" in model.lower():
                 groups["phi"].append(model)
                 added = True
             # Check for Meta LLaMA models
             elif model.startswith(("llama-", "meta-llama/", "llama2-", "llama3")):
                 groups["llama"].append(model)
                 added = True
-            elif model == "meta-ai":
+            elif model == "meta-ai" or model.startswith("codellama-"):
                 groups["llama"].append(model)
                 added = True
             # Check for Google models
@@ -100,14 +117,6 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             elif model.startswith(("command-", "CohereForAI/", "c4ai-command")):
                 groups["command"].append(model)
                 added = True
-            # Check for Claude models
-            elif model.startswith("claude-"):
-                groups["claude"].append(model)
-                added = True
-            # Check for Grok models
-            elif model.startswith("grok-"):
-                groups["grok"].append(model)
-                added = True
             # Check for DeepSeek models
             elif model.startswith(("deepseek-", "janus-")):
                 groups["deepseek"].append(model)
@@ -116,34 +125,27 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             elif model.startswith(("sonar", "sonar-", "pplx-")) or model == "r1-1776":
                 groups["perplexity"].append(model)
                 added = True
-            # Check for OpenAI models
-            elif model.startswith(("gpt-", "chatgpt-", "o1", "o1-", "o3-", "o4-")) or model in ("auto", "dall-e-3", "searchgpt"):
-                groups["openai"].append(model)
-            # Check for image models - UPDATED to include flux check
-            elif model in cls.image_models:
-                groups["image"].append(model)
+            # Check for openrouter models
+            elif model.startswith(("openrouter:")):
+                groups["openrouter"].append(model)
+            # Check for OpenAI models
+            elif model.startswith(("gpt-", "chatgpt-", "o1", "o1-", "o3-", "o4-")) or model in ("auto", "searchgpt"):
+                groups["openai"].append(model)
                 added = True
             # Check for video models
             elif model in cls.video_models:
                 groups["video"].append(model)
                 added = True
+            # Check for image models - UPDATED to include flux check
+            elif model in cls.image_models or "flux" in model.lower() or "stable-diffusion" in model.lower() or "sdxl" in model.lower() or "gpt-image" in model.lower():
+                groups["image"].append(model)
+            if not added:
+                for group in LABELS.keys():
+                    if model == group or group in model:
+                        groups[group].append(model)
+                        added = True
+                        break
-            # If not categorized, check for special cases then put in other
-            if not added:
-                # CodeLlama is Meta's model
-                if model.startswith("codellama-"):
-                    groups["llama"].append(model)
-                # WizardLM is Microsoft's
-                elif "wizardlm" in model.lower():
-                    groups["phi"].append(model)
                 else:
                     groups["other"].append(model)

         return [
             {"group": LABELS[group], "models": names} for group, names in groups.items()
         ]
@@ -174,31 +176,19 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
         all_models = [cls.default_model] + list(model_with_providers.keys())

         # Process special providers
-        for provider in SPECIAL_PROVIDERS:
+        for provider in PROVIERS_LIST_2:
             provider: ProviderType = provider
             if not provider.working or provider.get_parent() in ignored:
                 continue
             if provider == CopilotAccount:
                 all_models.extend(list(provider.model_aliases.keys()))
-            elif provider == PollinationsAI:
+            elif provider in [PollinationsAI, ARTA]:
                 all_models.extend([f"{provider.__name__}:{model}" for model in provider.get_models() if model not in all_models])
                 cls.audio_models.update({f"{provider.__name__}:{model}": [] for model in provider.get_models() if model in provider.audio_models})
                 cls.image_models.extend([f"{provider.__name__}:{model}" for model in provider.get_models() if model in provider.image_models])
                 cls.vision_models.extend([f"{provider.__name__}:{model}" for model in provider.get_models() if model in provider.vision_models])
+                if provider == PollinationsAI:
+                    all_models.extend(list(provider.model_aliases.keys()))
-            elif provider == LegacyLMArena:
-                # Add models from LegacyLMArena
-                provider_models = provider.get_models()
-                all_models.extend(provider_models)
-                # Also add model aliases
-                all_models.extend(list(provider.model_aliases.keys()))
-                # Add vision models
-                cls.vision_models.extend(provider.vision_models)
-            elif provider == ARTA:
-                # Add all ARTA models as image models
-                arta_models = provider.get_models()
-                all_models.extend(arta_models)
-                cls.image_models.extend(arta_models)
             else:
                 all_models.extend(provider.get_models())
@@ -215,9 +205,9 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             name = name.split("/")[-1].split(":")[0].lower()
             # Date patterns
             name = re.sub(r'-\d{4}-\d{2}-\d{2}', '', name)
-            name = re.sub(r'-\d{8}', '', name)
-            name = re.sub(r'-\d{4}', '', name)
+            name = re.sub(r'-\d{3,8}', '', name)
+            name = re.sub(r'-\d{2}-\d{2}', '', name)
             name = re.sub(r'-[0-9a-f]{8}$', '', name)
             # Version patterns
             name = re.sub(r'-(instruct|chat|preview|experimental|v\d+|fp8|bf16|hf)$', '', name)
             # Other replacements
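Two hand-traced examples of the cleaning pipeline (using the patterns above plus the replacements in the next hunk; illustrative only):

    # "gpt-4o-2024-05-13"
    #   r'-\d{4}-\d{2}-\d{2}' strips the date suffix   -> "gpt-4o"
    # "meta-llama/Meta-Llama-3.1-70B-Instruct"
    #   split("/")[-1] + lower()                       -> "meta-llama-3.1-70b-instruct"
    #   version pattern strips "-instruct"             -> "meta-llama-3.1-70b"
    #   replace("meta-llama-", "llama-")               -> "llama-3.1-70b"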
@@ -226,23 +216,28 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             name = name.replace("meta-llama-", "llama-")
             name = name.replace("llama3", "llama-3")
             name = name.replace("flux.1-", "flux-")
+            name = name.replace("-free", "")
+            name = name.replace("qwen1-", "qwen-1")
+            name = name.replace("qwen2-", "qwen-2")
+            name = name.replace("qwen3-", "qwen-3")
+            name = name.replace("stable-diffusion-3.5-large", "sd-3.5-large")
             return name

         # Process HAR providers
-        for provider in SPECIAL_PROVIDERS2:
+        for provider in PROVIERS_LIST_3:
             if not provider.working or provider.get_parent() in ignored:
                 continue
             new_models = provider.get_models()
             if provider == HuggingFaceMedia:
                 new_models = provider.video_models

             # Add original models too, not just cleaned names
             all_models.extend(new_models)

-            model_map = {model if model.startswith("openrouter:") else clean_name(model): model for model in new_models}
-            if not provider.model_aliases:
-                provider.model_aliases = {}
-            provider.model_aliases.update(model_map)
+            model_map = {}
+            for model in new_models:
+                clean_value = model if model.startswith("openrouter:") else clean_name(model)
+                if clean_value not in model_map:
+                    model_map[clean_value] = model
+            if provider.model_aliases:
+                model_map.update(provider.model_aliases)
+            provider.model_aliases = model_map
+            all_models.extend(list(model_map.keys()))

             # Update special model lists with both original and cleaned names
@@ -262,7 +257,10 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
         cls.audio_models.update(provider.audio_models)

         # Update model counts
-        cls.models_count.update({model: all_models.count(model) for model in all_models if all_models.count(model) > cls.models_count.get(model, 0)})
+        for model in all_models:
+            count = all_models.count(model)
+            if count > cls.models_count.get(model, 0):
+                cls.models_count.update({model: count})

         # Deduplicate and store
         cls.models_storage[ignored_key] = list(dict.fromkeys([model if model else cls.default_model for model in all_models]))
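The rewritten loop keeps the old semantics (record the highest occurrence count per model) but still calls all_models.count() once per element. A one-pass alternative sketch with collections.Counter, inside the same classmethod (an alternative, not what the commit ships):

    from collections import Counter

    for model, count in Counter(all_models).items():
        if count > cls.models_count.get(model, 0):
            cls.models_count[model] = count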
@@ -320,11 +318,15 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
         if isinstance(api_key, dict):
             for provider in api_key:
                 if api_key.get(provider):
-                    if provider in __map__ and __map__[provider] not in MAIN_PROVIERS:
+                    if provider in __map__ and __map__[provider] not in PROVIERS_LIST_1:
                         extra_providers.append(__map__[provider])
-        for provider in MAIN_PROVIERS + extra_providers:
+        for provider in PROVIERS_LIST_1 + extra_providers:
             if provider.working:
-                if not model or model in provider.get_models() or model in provider.model_aliases:
+                provider_api_key = api_key
+                if isinstance(api_key, dict):
+                    provider_api_key = api_key.get(provider.get_parent())
+                provider_models = provider.get_models(api_key=provider_api_key) if provider_api_key else provider.get_models()
+                if not model or model in provider_models or provider.model_aliases and model in provider.model_aliases:
                     providers.append(provider)
         if model in models.__models__:
             for provider in models.__models__[model][1]:
g4f/providers/base_provider.py
@@ -367,20 +367,19 @@ class ProviderModelMixin:
     @classmethod
     def get_models(cls, **kwargs) -> list[str]:
         if not cls.models and cls.default_model is not None:
-            return [cls.default_model]
+            cls.models = [cls.default_model]
         return cls.models

     @classmethod
     def get_model(cls, model: str, **kwargs) -> str:
         if not model and cls.default_model is not None:
             model = cls.default_model
-        elif model in cls.model_aliases:
+        if model in cls.model_aliases:
             model = cls.model_aliases[model]
         else:
             if model not in cls.get_models(**kwargs) and cls.models:
-                raise ModelNotFoundError(f"Model is not supported: {model} in: {cls.__name__} Valid models: {cls.models}")
+                raise ModelNotFoundError(f"Model not found: {model} in: {cls.__name__} Valid models: {cls.models}")
         cls.last_model = model
         debug.last_model = model
         return model

 class RaiseErrorMixin():
g4f/providers/retry_provider.py
@@ -55,16 +55,16 @@ class IterListProvider(BaseRetryProvider):
                 self.last_provider = provider
                 if not model:
                     model = getattr(provider, "default_model", None)
-                model = provider.model_aliases.get(model, model) if hasattr(provider, "model_aliases") else model
-                debug.log(f"Using {provider.__name__} provider with model {model}")
-                yield ProviderInfo(**provider.get_dict(), model=model)
+                alias = provider.model_aliases.get(model, model) if hasattr(provider, "model_aliases") else model
+                debug.log(f"Using {provider.__name__} provider with model {alias}")
+                yield ProviderInfo(**provider.get_dict(), model=alias)
                 extra_body = kwargs.copy()
                 if isinstance(api_key, dict):
                     api_key = api_key.get(provider.get_parent())
                 if api_key:
                     extra_body["api_key"] = api_key
                 try:
-                    response = provider.create_function(model, messages, stream=stream, **extra_body)
+                    response = provider.create_function(alias, messages, stream=stream, **extra_body)
                     for chunk in response:
                         if chunk:
                             yield chunk
g4f/version.py
@@ -48,6 +48,14 @@ def get_github_version(repo: str) -> str:
     except requests.RequestException as e:
         raise VersionNotFoundError(f"Failed to get GitHub release version: {e}")

+def get_git_version() -> str:
+    # Read from git repository
+    try:
+        command = ["git", "describe", "--tags", "--abbrev=0"]
+        return check_output(command, text=True, stderr=PIPE).strip()
+    except CalledProcessError:
+        return None
+
 class VersionUtils:
     """
     Utility class for managing and comparing package versions of 'g4f'.
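A usage sketch for the new helper (behavior as defined above; the example tag value is hypothetical):

    version = get_git_version()  # e.g. "0.5.2" in a tagged checkout
    if version is None:
        # not a git checkout, or no tags reachable: callers fall back to
        # other sources, as current_version does in the next hunk
        ...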
@@ -78,14 +86,7 @@ class VersionUtils:
         if version:
             return version

-        # Read from git repository
-        try:
-            command = ["git", "describe", "--tags", "--abbrev=0"]
-            return check_output(command, text=True, stderr=PIPE).strip()
-        except CalledProcessError:
-            pass
-
-        return None
+        return get_git_version()

     @property
     def latest_version(self) -> str: