fix: enhance retry logic and parameter handling in commit and provider code

- Added a `--max-retries` argument to `parse_arguments()` in `commit.py`, defaulting to `MAX_RETRIES`
- Updated `generate_commit_message()` to accept a `max_retries` parameter and iterate over it instead of the `MAX_RETRIES` constant (a sketch of the retry contract follows this list)
- Added a check in `generate_commit_message()` that re-raises the error immediately when `max_retries` is 1
- Passed `args.max_retries` when calling `generate_commit_message()` in `main()`
- In `g4f/Provider/har/__init__.py`, imported `ResponseError` and raised it when the stream contains the high-traffic network-error marker
- In `g4f/Provider/hf_space/Qwen_Qwen_3.py`, renamed the default model and the `qwen3-*` model ids to `qwen-3-*`, and switched system-prompt handling to `get_system_prompt()`
- In `g4f/Provider/needs_auth/LMArenaBeta.py`, tied `working` to `has_nodriver` and changed the callback to wait for the auth cookie and the Cloudflare Turnstile response instead of a fixed five-second delay
- In `g4f/Provider/needs_auth/PuterJS.py`, added an `api_key` parameter to `get_models()` and filtered out the `abuse`, `costly`, and `fake` placeholder models
- In `g4f/Provider/needs_auth/HuggingChat.py`, replaced the regex-based HTML scrape in `get_models()` with a fetch of the `/api/v2/models` JSON endpoint
- In `g4f/gui/server/api.py`, adjusted `get_model_data()` so models starting with `openrouter:` keep their full id as the label
- In `g4f/providers/any_provider.py`, imported and used `ResponseError`; reworked model grouping and `model_aliases`-based model name resolution
- Refined the model-name cleaning logic to handle additional patterns, replacing several regexes to better match version strings
- Renamed the provider lists to `PROVIERS_LIST_1`, `PROVIERS_LIST_2`, and `PROVIERS_LIST_3`, adding new providers and adjusting filtering
- In `g4f/providers/base_provider.py`, made `get_models()` cache the default model in `cls.models`, let `get_model()` resolve aliases independently of the default fill-in, and dropped the `debug.last_model` assignment
- In `g4f/providers/retry_provider.py`, resolved the model alias into a separate `alias` variable that is logged, yielded in `ProviderInfo`, and passed to `create_function()`
- In `g4f/version.py`, extracted the `git describe` lookup from `VersionUtils.current_version` into a module-level `get_git_version()` helper
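
Note: a minimal sketch of the new retry contract; `call_model` is an illustrative stand-in for the real AI call, and `RETRY_DELAY` here is an assumed value mirroring the constant in `commit.py`:

    import time
    from typing import Optional

    MAX_RETRIES = 3
    RETRY_DELAY = 2  # assumed value; commit.py defines its own constant

    def call_model(prompt: str) -> str:  # hypothetical stand-in for the AI call
        raise RuntimeError("transient failure")

    def generate_with_retries(prompt: str, max_retries: int = MAX_RETRIES) -> Optional[str]:
        for attempt in range(max_retries):
            try:
                return call_model(prompt)
            except Exception as e:
                if max_retries == 1:
                    raise e  # fail fast when retries are disabled
                print(f"Error generating commit message (attempt {attempt+1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    time.sleep(RETRY_DELAY)
        return None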
hlohaus 2025-06-13 05:32:55 +02:00
parent 67231e8c40
commit c12227a1cd
11 changed files with 125 additions and 122 deletions

View file: commit.py

@@ -128,6 +128,8 @@ def parse_arguments():
                         help="List available AI models and exit")
     parser.add_argument("--repo", type=str, default=".",
                         help="Git repository path (default: current directory)")
+    parser.add_argument("--max-retries", type=int, default=MAX_RETRIES,
+                        help="Maximum number of retries for AI generation (default: 3)")
     return parser.parse_args()
@@ -288,7 +290,7 @@ def show_spinner(duration: int = None):
         stop_spinner.set()
         raise
-def generate_commit_message(diff_text: str, model: str = DEFAULT_MODEL) -> Optional[str]:
+def generate_commit_message(diff_text: str, model: str = DEFAULT_MODEL, max_retries: int = MAX_RETRIES) -> Optional[str]:
     """Generate a commit message based on the git diff"""
     if not diff_text or diff_text.strip() == "":
         return "No changes staged for commit"
@@ -324,7 +326,7 @@ def generate_commit_message(diff_text: str, model: str = DEFAULT_MODEL) -> Optio
     IMPORTANT: Be 100% factual. Only mention code that was actually changed. Never invent or assume changes not shown in the diff. If unsure about a change's purpose, describe what changed rather than why. Output nothing except for the commit message, and don't surround it in quotes.
     """
-    for attempt in range(MAX_RETRIES):
+    for attempt in range(max_retries):
         try:
             # Start spinner
             spinner = show_spinner()
@@ -352,7 +354,8 @@ def generate_commit_message(diff_text: str, model: str = DEFAULT_MODEL) -> Optio
             spinner.set()
             sys.stdout.write("\r" + " " * 50 + "\r")
             sys.stdout.flush()
+            if max_retries == 1:
+                raise e  # If no retries, raise immediately
             print(f"Error generating commit message (attempt {attempt+1}/{MAX_RETRIES}): {e}")
             if attempt < MAX_RETRIES - 1:
                 print(f"Retrying in {RETRY_DELAY} seconds...")
@@ -464,7 +467,7 @@ def main():
         sys.exit(0)
     print(f"Using model: {args.model}")
-    commit_message = generate_commit_message(diff, args.model)
+    commit_message = generate_commit_message(diff, args.model, args.max_retries)
     if not commit_message:
         print("Failed to generate commit message after multiple attempts.")

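Note: a self-contained sketch of the argparse wiring added above; argparse maps `--max-retries` to `args.max_retries`, which is what `main()` passes through:

    import argparse

    MAX_RETRIES = 3  # assumed default, matching the help text

    parser = argparse.ArgumentParser()
    parser.add_argument("--max-retries", type=int, default=MAX_RETRIES,
                        help="Maximum number of retries for AI generation (default: 3)")
    args = parser.parse_args(["--max-retries", "1"])
    assert args.max_retries == 1  # dashes become underscores in the Namespace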
View file: g4f/Provider/har/__init__.py

@@ -10,6 +10,7 @@ from ...requests import DEFAULT_HEADERS, StreamSession, StreamResponse, FormData
 from ...providers.response import JsonConversation
 from ...tools.media import merge_media
 from ...image import to_bytes, is_accepted_format
+from ...errors import ResponseError
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..helper import get_last_user_message
 from ..openai.har_file import get_headers
@@ -139,6 +140,8 @@ class HarProvider(AsyncGeneratorProvider, ProviderModelMixin):
                 if not line.startswith(b"data: "):
                     continue
                 for content in find_str(json.loads(line[6:]), 3):
+                    if "**NETWORK ERROR DUE TO HIGH TRAFFIC." in content:
+                        raise ResponseError(content)
                     if content == '<span class="cursor"></span> ' or content == 'update':
                         continue
                     if content.endswith(""):

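Note: a sketch of the new in-band error guard with the SSE payload shape simplified (the real provider extracts strings via `find_str`); `ResponseError` here stands in for `g4f.errors.ResponseError`:

    import json

    class ResponseError(Exception):
        """Stand-in for g4f.errors.ResponseError."""

    def iter_contents(lines):
        for line in lines:
            if not line.startswith(b"data: "):
                continue
            payload = json.loads(line[6:])
            for content in payload.get("output", []):  # simplified payload shape
                if "**NETWORK ERROR DUE TO HIGH TRAFFIC." in content:
                    raise ResponseError(content)  # surface the upstream error as an exception
                yield content

    sse = [b'data: {"output": ["**NETWORK ERROR DUE TO HIGH TRAFFIC. Please retry later."]}']
    try:
        list(iter_contents(sse))
    except ResponseError as e:
        print("caught:", e)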
View file: g4f/Provider/hf_space/Qwen_Qwen_3.py

@@ -9,7 +9,7 @@ from ...providers.response import Reasoning, JsonConversation
 from ...requests.raise_for_status import raise_for_status
 from ...errors import ModelNotFoundError
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ..helper import get_last_user_message
+from ..helper import get_last_user_message, get_system_prompt
 from ... import debug
@@ -22,19 +22,19 @@ class Qwen_Qwen_3(AsyncGeneratorProvider, ProviderModelMixin):
     supports_stream = True
     supports_system_message = True
-    default_model = "qwen3-235b-a22b"
+    default_model = "qwen-3-235b"
     models = {
         default_model,
-        "qwen3-32b",
-        "qwen3-30b-a3b",
-        "qwen3-14b",
-        "qwen3-8b",
-        "qwen3-4b",
-        "qwen3-1.7b",
-        "qwen3-0.6b",
+        "qwen-3-32b",
+        "qwen-3-30b-a3b",
+        "qwen-3-14b",
+        "qwen-3-8b",
+        "qwen-3-4b",
+        "qwen-3-1.7b",
+        "qwen-3-0.6b",
     }
     model_aliases = {
-        "qwen-3-235b": default_model,
+        "qwen-3-235b": "qwen3-235b-a22b",
         "qwen-3-30b": "qwen3-30b-a3b",
         "qwen-3-32b": "qwen3-32b",
         "qwen-3-14b": "qwen3-14b",
@@ -76,12 +76,12 @@ class Qwen_Qwen_3(AsyncGeneratorProvider, ProviderModelMixin):
             'Cache-Control': 'no-cache',
         }
-        sys_prompt = "\n".join([message['content'] for message in messages if message['role'] == 'system'])
-        sys_prompt = sys_prompt if sys_prompt else "You are a helpful and harmless assistant."
+        system_prompt = get_system_prompt(messages)
+        system_prompt = system_prompt if system_prompt else "You are a helpful and harmless assistant."
         payload_join = {"data": [
             get_last_user_message(messages),
-            {"thinking_budget": thinking_budget, "model": cls.get_model(model), "sys_prompt": sys_prompt}, None, None],
+            {"thinking_budget": thinking_budget, "model": cls.get_model(model), "sys_prompt": system_prompt}, None, None],
             "event_data": None, "fn_index": 13, "trigger_id": 31, "session_hash": conversation.session_hash
         }

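Note: the diff swaps the inline join for the shared helper; a sketch of the assumed semantics of `get_system_prompt`, inferred from the code it replaces:

    def get_system_prompt(messages) -> str:
        # Assumed to join all system-message contents, like the removed inline code
        return "\n".join(m["content"] for m in messages if m["role"] == "system")

    messages = [
        {"role": "system", "content": "Be brief."},
        {"role": "user", "content": "Hi"},
    ]
    system_prompt = get_system_prompt(messages) or "You are a helpful and harmless assistant."
    print(system_prompt)  # -> Be brief.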
View file: g4f/Provider/needs_auth/LMArenaBeta.py

@@ -78,7 +78,7 @@ class LMArenaBeta(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
     label = "LMArena Beta"
     url = "https://beta.lmarena.ai"
     api_endpoint = "https://beta.lmarena.ai/api/stream/create-evaluation"
-    working = True
+    working = has_nodriver
     default_model = list(text_models.keys())[0]
     models = list(text_models) + list(image_models)
@@ -102,7 +102,8 @@ class LMArenaBeta(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
             async def callback(page):
                 while not await page.evaluate('document.cookie.indexOf("arena-auth-prod-v1") >= 0'):
                     await asyncio.sleep(1)
-                await asyncio.sleep(5)
+                while await page.evaluate('document.querySelector(\'[name="cf-turnstile-response"]\').length > 0'):
+                    await asyncio.sleep(1)
             args = await get_args_from_nodriver(cls.url, proxy=proxy, callback=callback)
         except (RuntimeError, FileNotFoundError) as e:
             debug.log(f"Nodriver is not available: {type(e).__name__}: {e}")

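Note: the fixed `sleep(5)` becomes two polls, one for the auth cookie and one for the Turnstile field to disappear. A runnable sketch with a fake page object; it uses `querySelectorAll(...).length` because `querySelector()` returns a single element (or null) and has no length property:

    import asyncio

    class FakePage:
        """Stand-in for a nodriver page; evaluate() returns scripted results."""
        def __init__(self, results):
            self._results = iter(results)
        async def evaluate(self, script):
            return next(self._results)

    async def callback(page):
        # Wait for the auth cookie to appear...
        while not await page.evaluate('document.cookie.indexOf("arena-auth-prod-v1") >= 0'):
            await asyncio.sleep(0.01)  # the real code sleeps 1s per poll
        # ...then for the Turnstile response field to be removed.
        while await page.evaluate('document.querySelectorAll(\'[name="cf-turnstile-response"]\').length > 0'):
            await asyncio.sleep(0.01)

    asyncio.run(callback(FakePage([False, True, True, False])))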
View file: g4f/Provider/needs_auth/PuterJS.py

@@ -257,14 +257,15 @@ class PuterJS(AsyncGeneratorProvider, ProviderModelMixin):
         "lfm-7b": "openrouter:liquid/lfm-7b",
         "lfm-3b": "openrouter:liquid/lfm-3b",
         "lfm-40b": "openrouter:liquid/lfm-40b",
     }
     @classmethod
-    def get_models(cls) -> list[str]:
+    def get_models(cls, api_key: str = None) -> list[str]:
         if not cls.models:
             try:
                 url = "https://api.puter.com/puterai/chat/models/"
                 cls.models = requests.get(url).json().get("models", [])
+                cls.models = [model for model in cls.models if model not in ["abuse", "costly", "fake"]]
             except Exception as e:
                 debug.log(f"PuterJS: Failed to fetch models from API: {e}")
                 cls.models = list(cls.model_aliases.keys())

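Note: a sketch of the fetch-and-filter step in isolation, assuming the endpoint returns a JSON object with a "models" list as the diff implies:

    import requests

    def get_puter_models() -> list[str]:
        url = "https://api.puter.com/puterai/chat/models/"
        models = requests.get(url).json().get("models", [])
        # Drop the placeholder entries the diff filters out
        return [model for model in models if model not in ("abuse", "costly", "fake")]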
View file: g4f/Provider/needs_auth/HuggingChat.py

@@ -51,14 +51,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
     def get_models(cls):
         if not cls.models:
             try:
-                text = requests.get(cls.url).text
-                text = re.search(r'models:(\[.+?\]),oldModels:', text).group(1)
-                text = re.sub(r',parameters:{[^}]+?}', '', text)
-                text = text.replace('void 0', 'null')
-                def add_quotation_mark(match):
-                    return f'{match.group(1)}"{match.group(2)}":'
-                text = re.sub(r'([{,])([A-Za-z0-9_]+?):', add_quotation_mark, text)
-                models = json.loads(text)
+                models = requests.get(f"{cls.url}/api/v2/models").json().get("json")
                 cls.text_models = [model["id"] for model in models]
                 cls.models = cls.text_models + cls.image_models
                 cls.vision_models = [model["id"] for model in models if model["multimodal"]]

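Note: the regex-based scrape of the HTML page gives way to a JSON endpoint. A sketch of the new discovery path, assuming the response shape {"json": [{"id": ..., "multimodal": ...}, ...]} implied by the diff; the base URL is the provider's `cls.url`:

    import requests

    def get_huggingchat_models(base_url: str):
        models = requests.get(f"{base_url}/api/v2/models").json().get("json")
        text_models = [model["id"] for model in models]
        vision_models = [model["id"] for model in models if model["multimodal"]]
        return text_models, vision_models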
View file: g4f/gui/server/api.py

@@ -43,7 +43,7 @@ class Api:
     def get_model_data(provider: ProviderModelMixin, model: str):
         return {
             "model": model,
-            "label": model.split(":")[-1] if provider.__name__ == "AnyProvider" else model,
+            "label": model.split(":")[-1] if provider.__name__ == "AnyProvider" and not model.startswith("openrouter:") else model,
             "default": model == provider.default_model,
             "vision": model in provider.vision_models,
             "audio": False if provider.audio_models is None else model in provider.audio_models,

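Note: a sketch of the label rule in isolation: AnyProvider models are shortened to the part after the provider prefix, except `openrouter:` ids, which stay intact so the router prefix remains visible:

    def model_label(provider_name: str, model: str) -> str:
        if provider_name == "AnyProvider" and not model.startswith("openrouter:"):
            return model.split(":")[-1]
        return model

    assert model_label("AnyProvider", "PollinationsAI:gpt-4o") == "gpt-4o"
    assert model_label("AnyProvider", "openrouter:liquid/lfm-7b") == "openrouter:liquid/lfm-7b"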
View file: g4f/providers/any_provider.py

@@ -10,22 +10,29 @@ from ..Provider.needs_auth import OpenaiChat, CopilotAccount
 from ..Provider.hf_space import HuggingSpace
 from ..Provider import __map__
 from ..Provider import Cloudflare, Gemini, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI, PuterJS
-from ..Provider import Microsoft_Phi_4_Multimodal, DeepInfraChat, Blackbox, OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena
+from ..Provider import Microsoft_Phi_4_Multimodal, DeepInfraChat, Blackbox, OIVSCodeSer2, OIVSCodeSer0501, TeachAnything
+from ..Provider import Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena, LMArenaBeta
 from ..Provider import EdgeTTS, gTTS, MarkItDown, OpenAIFM
 from ..Provider import HarProvider, HuggingFace, HuggingFaceMedia
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from .. import Provider
 from .. import models
-MAIN_PROVIERS = [
-    OpenaiChat, Cloudflare, HarProvider, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox,
-    OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena,
-    HuggingSpace, LambdaChat, CopilotAccount, PollinationsAI, DeepInfraChat, HuggingFace, HuggingFaceMedia
+PROVIERS_LIST_1 = [
+    OpenaiChat, PollinationsAI, Cloudflare, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox, OpenAIFM,
+    OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs,
+    HarProvider, LegacyLMArena, LMArenaBeta, LambdaChat, CopilotAccount, DeepInfraChat,
+    HuggingSpace, HuggingFace, HuggingFaceMedia, PuterJS, Together
 ]
-SPECIAL_PROVIDERS = [OpenaiChat, CopilotAccount, PollinationsAI, HuggingSpace, Cloudflare, PerplexityLabs, Gemini, Grok, LegacyLMArena, ARTA]
+PROVIERS_LIST_2 = [
+    OpenaiChat, CopilotAccount, PollinationsAI, PerplexityLabs, Gemini, Grok, ARTA
+]
-SPECIAL_PROVIDERS2 = [HarProvider, LambdaChat, DeepInfraChat, HuggingFace, HuggingFaceMedia, PuterJS]
+PROVIERS_LIST_3 = [
+    HarProvider, LambdaChat, DeepInfraChat, HuggingFace, HuggingFaceMedia, LegacyLMArena, LMArenaBeta,
+    PuterJS, Together, Cloudflare, HuggingSpace
+]
 LABELS = {
     "default": "Default",
@@ -33,15 +40,21 @@ LABELS = {
     "llama": "Meta: LLaMA",
     "deepseek": "DeepSeek",
     "qwen": "Alibaba: Qwen",
-    "google": "Google: Gemini / Gemma / Bard",
+    "google": "Google: Gemini / Gemma",
     "grok": "xAI: Grok",
     "claude": "Anthropic: Claude",
     "command": "Cohere: Command",
     "phi": "Microsoft: Phi / WizardLM",
     "mistral": "Mistral",
     "PollinationsAI": "Pollinations AI",
+    "ARTA": "ARTA",
+    "voices": "Voices",
     "perplexity": "Perplexity Labs",
     "openrouter": "OpenRouter",
+    "glm": "GLM",
+    "tulu": "Tulu",
+    "reka": "Reka",
+    "hermes": "Hermes",
     "video": "Video Generation",
     "image": "Image Generation",
     "other": "Other Models",
@@ -65,31 +78,35 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
                 continue # Already added
             added = False
-            # Check for PollinationsAI models (with prefix)
-            if model.startswith("PollinationsAI:"):
-                groups["PollinationsAI"].append(model)
+            # Check for models with prefix
+            start = model.split(":")[0]
+            if start in ("PollinationsAI", "ARTA", "openrouter"):
+                submodel = model.split(":", maxsplit=1)[1]
+                if submodel in OpenAIFM.voices or submodel in PollinationsAI.audio_models[PollinationsAI.default_audio_model]:
+                    groups["voices"].append(submodel)
+                else:
+                    groups[start].append(model)
                 added = True
             # Check for Mistral company models specifically
             elif model.startswith("mistral") and not any(x in model for x in ["dolphin", "nous", "openhermes"]):
                 groups["mistral"].append(model)
                 added = True
-            elif model.startswith(("mistralai/", "mixtral-", "pixtral-", "ministral-", "codestral-")):
+            elif model.startswith(("pixtral-", "ministral-", "codestral")) or "mistral" in model or "mixtral" in model:
                 groups["mistral"].append(model)
                 added = True
             # Check for Qwen models
-            elif model.startswith(("qwen", "Qwen/", "qwq", "qvq")):
+            elif model.startswith(("qwen", "Qwen", "qwq", "qvq")):
                 groups["qwen"].append(model)
                 added = True
             # Check for Microsoft Phi models
-            elif model.startswith(("phi-", "microsoft/")):
+            elif model.startswith(("phi-", "microsoft/")) or "wizardlm" in model.lower():
                 groups["phi"].append(model)
                 added = True
             # Check for Meta LLaMA models
             elif model.startswith(("llama-", "meta-llama/", "llama2-", "llama3")):
                 groups["llama"].append(model)
                 added = True
-            elif model == "meta-ai":
+            elif model == "meta-ai" or model.startswith("codellama-"):
                 groups["llama"].append(model)
                 added = True
             # Check for Google models
@@ -100,14 +117,6 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             elif model.startswith(("command-", "CohereForAI/", "c4ai-command")):
                 groups["command"].append(model)
                 added = True
-            # Check for Claude models
-            elif model.startswith("claude-"):
-                groups["claude"].append(model)
-                added = True
-            # Check for Grok models
-            elif model.startswith("grok-"):
-                groups["grok"].append(model)
-                added = True
             # Check for DeepSeek models
             elif model.startswith(("deepseek-", "janus-")):
                 groups["deepseek"].append(model)
@@ -116,34 +125,27 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             elif model.startswith(("sonar", "sonar-", "pplx-")) or model == "r1-1776":
                 groups["perplexity"].append(model)
                 added = True
-            # Check for OpenAI models
-            elif model.startswith(("gpt-", "chatgpt-", "o1", "o1-", "o3-", "o4-")) or model in ("auto", "dall-e-3", "searchgpt"):
-                groups["openai"].append(model)
+            # Check for image models - UPDATED to include flux check
+            elif model in cls.image_models:
+                groups["image"].append(model)
                 added = True
-            # Check for openrouter models
-            elif model.startswith(("openrouter:")):
-                groups["openrouter"].append(model)
+            # Check for OpenAI models
+            elif model.startswith(("gpt-", "chatgpt-", "o1", "o1-", "o3-", "o4-")) or model in ("auto", "searchgpt"):
+                groups["openai"].append(model)
                 added = True
             # Check for video models
             elif model in cls.video_models:
                 groups["video"].append(model)
                 added = True
-            # Check for image models - UPDATED to include flux check
-            elif model in cls.image_models or "flux" in model.lower() or "stable-diffusion" in model.lower() or "sdxl" in model.lower() or "gpt-image" in model.lower():
-                groups["image"].append(model)
-                added = True
+            if not added:
+                for group in LABELS.keys():
+                    if model == group or group in model:
+                        groups[group].append(model)
+                        added = True
+                        break
             # If not categorized, check for special cases then put in other
             if not added:
-                # CodeLlama is Meta's model
-                if model.startswith("codellama-"):
-                    groups["llama"].append(model)
-                # WizardLM is Microsoft's
-                elif "wizardlm" in model.lower():
-                    groups["phi"].append(model)
-                else:
-                    groups["other"].append(model)
+                groups["other"].append(model)
         return [
             {"group": LABELS[group], "models": names} for group, names in groups.items()
         ]
@@ -174,31 +176,19 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
         all_models = [cls.default_model] + list(model_with_providers.keys())
         # Process special providers
-        for provider in SPECIAL_PROVIDERS:
+        for provider in PROVIERS_LIST_2:
             provider: ProviderType = provider
             if not provider.working or provider.get_parent() in ignored:
                 continue
             if provider == CopilotAccount:
                 all_models.extend(list(provider.model_aliases.keys()))
-            elif provider == PollinationsAI:
+            elif provider in [PollinationsAI, ARTA]:
                 all_models.extend([f"{provider.__name__}:{model}" for model in provider.get_models() if model not in all_models])
                 cls.audio_models.update({f"{provider.__name__}:{model}": [] for model in provider.get_models() if model in provider.audio_models})
                 cls.image_models.extend([f"{provider.__name__}:{model}" for model in provider.get_models() if model in provider.image_models])
                 cls.vision_models.extend([f"{provider.__name__}:{model}" for model in provider.get_models() if model in provider.vision_models])
-                all_models.extend(list(provider.model_aliases.keys()))
-            elif provider == LegacyLMArena:
-                # Add models from LegacyLMArena
-                provider_models = provider.get_models()
-                all_models.extend(provider_models)
-                # Also add model aliases
-                all_models.extend(list(provider.model_aliases.keys()))
-                # Add vision models
-                cls.vision_models.extend(provider.vision_models)
-            elif provider == ARTA:
-                # Add all ARTA models as image models
-                arta_models = provider.get_models()
-                all_models.extend(arta_models)
-                cls.image_models.extend(arta_models)
+                if provider == PollinationsAI:
+                    all_models.extend(list(provider.model_aliases.keys()))
             else:
                 all_models.extend(provider.get_models())
@@ -215,9 +205,9 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             name = name.split("/")[-1].split(":")[0].lower()
             # Date patterns
             name = re.sub(r'-\d{4}-\d{2}-\d{2}', '', name)
-            name = re.sub(r'-\d{8}', '', name)
-            name = re.sub(r'-\d{4}', '', name)
+            name = re.sub(r'-\d{3,8}', '', name)
             name = re.sub(r'-\d{2}-\d{2}', '', name)
+            name = re.sub(r'-[0-9a-f]{8}$', '', name)
             # Version patterns
             name = re.sub(r'-(instruct|chat|preview|experimental|v\d+|fp8|bf16|hf)$', '', name)
             # Other replacements
@@ -226,23 +216,28 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             name = name.replace("meta-llama-", "llama-")
             name = name.replace("llama3", "llama-3")
             name = name.replace("flux.1-", "flux-")
+            name = name.replace("-free", "")
+            name = name.replace("qwen1-", "qwen-1")
+            name = name.replace("qwen2-", "qwen-2")
+            name = name.replace("qwen3-", "qwen-3")
+            name = name.replace("stable-diffusion-3.5-large", "sd-3.5-large")
             return name
         # Process HAR providers
-        for provider in SPECIAL_PROVIDERS2:
+        for provider in PROVIERS_LIST_3:
             if not provider.working or provider.get_parent() in ignored:
                 continue
             new_models = provider.get_models()
             if provider == HuggingFaceMedia:
                 new_models = provider.video_models
-            # Add original models too, not just cleaned names
-            all_models.extend(new_models)
-            model_map = {model if model.startswith("openrouter:") else clean_name(model): model for model in new_models}
-            if not provider.model_aliases:
-                provider.model_aliases = {}
-            provider.model_aliases.update(model_map)
+            model_map = {}
+            for model in new_models:
+                clean_value = model if model.startswith("openrouter:") else clean_name(model)
+                if clean_value not in model_map:
+                    model_map[clean_value] = model
+            if provider.model_aliases:
+                model_map.update(provider.model_aliases)
+            provider.model_aliases = model_map
             all_models.extend(list(model_map.keys()))
         # Update special model lists with both original and cleaned names
@@ -262,7 +257,10 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             cls.audio_models.update(provider.audio_models)
         # Update model counts
-        cls.models_count.update({model: all_models.count(model) for model in all_models if all_models.count(model) > cls.models_count.get(model, 0)})
+        for model in all_models:
+            count = all_models.count(model)
+            if count > cls.models_count.get(model, 0):
+                cls.models_count.update({model: count})
         # Deduplicate and store
         cls.models_storage[ignored_key] = list(dict.fromkeys([model if model else cls.default_model for model in all_models]))
@@ -320,11 +318,15 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
         if isinstance(api_key, dict):
             for provider in api_key:
                 if api_key.get(provider):
-                    if provider in __map__ and __map__[provider] not in MAIN_PROVIERS:
+                    if provider in __map__ and __map__[provider] not in PROVIERS_LIST_1:
                         extra_providers.append(__map__[provider])
-        for provider in MAIN_PROVIERS + extra_providers:
+        for provider in PROVIERS_LIST_1 + extra_providers:
             if provider.working:
-                if not model or model in provider.get_models() or model in provider.model_aliases:
+                provider_api_key = api_key
+                if isinstance(api_key, dict):
+                    provider_api_key = api_key.get(provider.get_parent())
+                provider_models = provider.get_models(api_key=provider_api_key) if provider_api_key else provider.get_models()
+                if not model or model in provider_models or provider.model_aliases and model in provider.model_aliases:
                     providers.append(provider)
         if model in models.__models__:
             for provider in models.__models__[model][1]:
@@ -334,7 +336,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
         if len(providers) == 0:
             raise ModelNotFoundError(f"AnyProvider: Model {model} not found in any provider.")
         async for chunk in IterListProvider(providers).create_async_generator(
             model,
             messages,

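Note: a self-contained sketch of the tightened clean_name pipeline assembled from the hunks above; the regexes and replacements match the diff, while the surrounding class is omitted:

    import re

    def clean_name(name: str) -> str:
        name = name.split("/")[-1].split(":")[0].lower()
        # Date / build-number patterns
        name = re.sub(r'-\d{4}-\d{2}-\d{2}', '', name)
        name = re.sub(r'-\d{3,8}', '', name)
        name = re.sub(r'-\d{2}-\d{2}', '', name)
        name = re.sub(r'-[0-9a-f]{8}$', '', name)
        # Version suffixes
        name = re.sub(r'-(instruct|chat|preview|experimental|v\d+|fp8|bf16|hf)$', '', name)
        # Vendor-specific normalizations
        name = name.replace("meta-llama-", "llama-")
        name = name.replace("llama3", "llama-3")
        name = name.replace("flux.1-", "flux-")
        name = name.replace("-free", "")
        name = name.replace("qwen1-", "qwen-1")
        name = name.replace("qwen2-", "qwen-2")
        name = name.replace("qwen3-", "qwen-3")
        name = name.replace("stable-diffusion-3.5-large", "sd-3.5-large")
        return name

    print(clean_name("meta-llama/Llama-3.3-70B-Instruct"))  # -> llama-3.3-70b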
View file: g4f/providers/base_provider.py

@@ -367,20 +367,19 @@ class ProviderModelMixin:
     @classmethod
     def get_models(cls, **kwargs) -> list[str]:
         if not cls.models and cls.default_model is not None:
-            return [cls.default_model]
+            cls.models = [cls.default_model]
         return cls.models
     @classmethod
     def get_model(cls, model: str, **kwargs) -> str:
         if not model and cls.default_model is not None:
             model = cls.default_model
-        elif model in cls.model_aliases:
+        if model in cls.model_aliases:
             model = cls.model_aliases[model]
         else:
             if model not in cls.get_models(**kwargs) and cls.models:
-                raise ModelNotFoundError(f"Model is not supported: {model} in: {cls.__name__} Valid models: {cls.models}")
+                raise ModelNotFoundError(f"Model not found: {model} in: {cls.__name__} Valid models: {cls.models}")
         cls.last_model = model
-        debug.last_model = model
         return model
 class RaiseErrorMixin():

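Note: a sketch of the revised resolution order in get_model: default fill-in, then alias lookup (now an `if` rather than `elif`, so an alias resolves even when the model was just defaulted), then validation; the class and model names here are illustrative:

    class ModelNotFoundError(Exception):
        pass

    class MixinSketch:
        default_model = "base-model"
        model_aliases = {"alias-model": "base-model"}
        models = ["base-model"]

        @classmethod
        def get_model(cls, model: str) -> str:
            if not model and cls.default_model is not None:
                model = cls.default_model
            if model in cls.model_aliases:  # `if`, not `elif`
                model = cls.model_aliases[model]
            else:
                if model not in cls.models and cls.models:
                    raise ModelNotFoundError(
                        f"Model not found: {model} in: {cls.__name__} Valid models: {cls.models}")
            return model

    assert MixinSketch.get_model("alias-model") == "base-model"
    assert MixinSketch.get_model("") == "base-model"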
View file: g4f/providers/retry_provider.py

@@ -55,16 +55,16 @@ class IterListProvider(BaseRetryProvider):
             self.last_provider = provider
             if not model:
                 model = getattr(provider, "default_model", None)
-            model = provider.model_aliases.get(model, model) if hasattr(provider, "model_aliases") else model
-            debug.log(f"Using {provider.__name__} provider with model {model}")
-            yield ProviderInfo(**provider.get_dict(), model=model)
+            alias = provider.model_aliases.get(model, model) if hasattr(provider, "model_aliases") else model
+            debug.log(f"Using {provider.__name__} provider with model {alias}")
+            yield ProviderInfo(**provider.get_dict(), model=alias)
             extra_body = kwargs.copy()
             if isinstance(api_key, dict):
                 api_key = api_key.get(provider.get_parent())
             if api_key:
                 extra_body["api_key"] = api_key
             try:
-                response = provider.create_function(model, messages, stream=stream, **extra_body)
+                response = provider.create_function(alias, messages, stream=stream, **extra_body)
                 for chunk in response:
                     if chunk:
                         yield chunk

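Note: a sketch of the alias handling: the resolved alias is what gets logged, reported in ProviderInfo, and passed to the provider, while `model` keeps the originally requested name:

    def resolve_alias(provider, model):
        # The hasattr guard mirrors the diff; a provider may lack model_aliases entirely
        return provider.model_aliases.get(model, model) if hasattr(provider, "model_aliases") else model

    class DummyProvider:
        model_aliases = {"qwen-3-235b": "qwen3-235b-a22b"}

    assert resolve_alias(DummyProvider, "qwen-3-235b") == "qwen3-235b-a22b"
    assert resolve_alias(DummyProvider, "unmapped") == "unmapped"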
View file: g4f/version.py

@@ -48,6 +48,14 @@ def get_github_version(repo: str) -> str:
     except requests.RequestException as e:
         raise VersionNotFoundError(f"Failed to get GitHub release version: {e}")
+def get_git_version() -> str:
+    # Read from git repository
+    try:
+        command = ["git", "describe", "--tags", "--abbrev=0"]
+        return check_output(command, text=True, stderr=PIPE).strip()
+    except CalledProcessError:
+        return None
 class VersionUtils:
     """
     Utility class for managing and comparing package versions of 'g4f'.
@@ -78,14 +86,7 @@ class VersionUtils:
         if version:
             return version
-        # Read from git repository
-        try:
-            command = ["git", "describe", "--tags", "--abbrev=0"]
-            return check_output(command, text=True, stderr=PIPE).strip()
-        except CalledProcessError:
-            pass
-
-        return None
+        return get_git_version()
     @property
     def latest_version(self) -> str:
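
Note: a sketch of the resulting lookup order, with the package-metadata step stubbed out; the extra FileNotFoundError catch (missing git binary) is an addition for the sketch, not part of the diff:

    from subprocess import check_output, CalledProcessError, PIPE

    def get_git_version():
        try:
            command = ["git", "describe", "--tags", "--abbrev=0"]
            return check_output(command, text=True, stderr=PIPE).strip()
        except (CalledProcessError, FileNotFoundError):
            # FileNotFoundError covers a missing git binary (sketch-only addition)
            return None

    version = None  # stand-in for the package-metadata lookup that runs first
    print(version if version else get_git_version())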