Mirror of https://github.com/xtekky/gpt4free.git, synced 2025-12-06 02:30:41 -08:00
fix: enhance retry logic and parameter handling in commit and provider code
- Added `--max-retries` argument to `parse_arguments()` in `commit.py` with default `MAX_RETRIES`
- Updated `generate_commit_message()` to accept a `max_retries` parameter and iterate accordingly (a sketch of this flow follows the file summary below)
- Included a check in `generate_commit_message()` to raise immediately if `max_retries` is 1
- Passed `args.max_retries` when calling `generate_commit_message()` in `main()`
- In `g4f/Provider/har/__init__.py`, imported `ResponseError` and added a check that raises `ResponseError` on a network error
- In `g4f/Provider/hf_space/Qwen_Qwen_3.py`, changed the default model string and updated system prompt handling to use `get_system_prompt()`
- In `g4f/Provider/needs_auth/LMArenaBeta.py`, modified the callback to wait for the auth cookie and the Turnstile response
- In `g4f/Provider/needs_auth/PuterJS.py`, adjusted `get_models()` to filter out certain models
- In `g4f/gui/server/api.py`, adjusted `get_model_data()` to handle models starting with "openrouter:"
- In `g4f/providers/any_provider.py`, imported and used `ResponseError`; added logic to process `model_aliases` with updated model name resolution
- Refined the model name cleaning logic to handle additional patterns and replaced several regex patterns to better match version strings
- Updated the provider lists `PROVIERS_LIST_1`, `PROVIERS_LIST_2`, and `PROVIERS_LIST_3` and their usage to include new providers and adjust filtering
- In `g4f/version.py`, added a `get_git_version()` function that retrieves the version with the `git describe` command instead of relying only on `get_github_version()`, increasing robustness
parent 67231e8c40
commit c12227a1cd
11 changed files with 125 additions and 122 deletions
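For orientation, here is a minimal sketch of the retry flow this commit introduces in `commit.py`. The `call_model` helper, the model name, and the constant values are illustrative stand-ins, not the repository's actual code:

```python
import time
from typing import Optional

MAX_RETRIES = 3   # illustrative; commit.py defines its own constants
RETRY_DELAY = 2   # seconds between attempts, illustrative

def call_model(diff_text: str, model: str) -> str:
    """Hypothetical stand-in for the actual AI request in commit.py."""
    raise NotImplementedError

def generate_commit_message(diff_text: str, model: str = "gpt-4o",
                            max_retries: int = MAX_RETRIES) -> Optional[str]:
    """Retry the AI call up to max_retries times; raise at once when retries are disabled."""
    if not diff_text or diff_text.strip() == "":
        return "No changes staged for commit"
    for attempt in range(max_retries):
        try:
            return call_model(diff_text, model)
        except Exception as e:
            if max_retries == 1:
                raise e  # no retries requested, surface the error immediately
            print(f"Error generating commit message (attempt {attempt+1}/{max_retries}): {e}")
            if attempt < max_retries - 1:
                print(f"Retrying in {RETRY_DELAY} seconds...")
                time.sleep(RETRY_DELAY)
    return None
```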
```diff
@@ -128,6 +128,8 @@ def parse_arguments():
                         help="List available AI models and exit")
     parser.add_argument("--repo", type=str, default=".",
                         help="Git repository path (default: current directory)")
+    parser.add_argument("--max-retries", type=int, default=MAX_RETRIES,
+                        help="Maximum number of retries for AI generation (default: 3)")
 
     return parser.parse_args()
 
```
```diff
@@ -288,7 +290,7 @@ def show_spinner(duration: int = None):
         stop_spinner.set()
         raise
 
-def generate_commit_message(diff_text: str, model: str = DEFAULT_MODEL) -> Optional[str]:
+def generate_commit_message(diff_text: str, model: str = DEFAULT_MODEL, max_retries: int = MAX_RETRIES) -> Optional[str]:
     """Generate a commit message based on the git diff"""
     if not diff_text or diff_text.strip() == "":
        return "No changes staged for commit"
```
```diff
@@ -324,7 +326,7 @@ def generate_commit_message(diff_text: str, model: str = DEFAULT_MODEL) -> Optio
     IMPORTANT: Be 100% factual. Only mention code that was actually changed. Never invent or assume changes not shown in the diff. If unsure about a change's purpose, describe what changed rather than why. Output nothing except for the commit message, and don't surround it in quotes.
     """
 
-    for attempt in range(MAX_RETRIES):
+    for attempt in range(max_retries):
         try:
             # Start spinner
             spinner = show_spinner()
```
@ -352,7 +354,8 @@ def generate_commit_message(diff_text: str, model: str = DEFAULT_MODEL) -> Optio
|
||||||
spinner.set()
|
spinner.set()
|
||||||
sys.stdout.write("\r" + " " * 50 + "\r")
|
sys.stdout.write("\r" + " " * 50 + "\r")
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
if max_retries == 1:
|
||||||
|
raise e # If no retries, raise immediately
|
||||||
print(f"Error generating commit message (attempt {attempt+1}/{MAX_RETRIES}): {e}")
|
print(f"Error generating commit message (attempt {attempt+1}/{MAX_RETRIES}): {e}")
|
||||||
if attempt < MAX_RETRIES - 1:
|
if attempt < MAX_RETRIES - 1:
|
||||||
print(f"Retrying in {RETRY_DELAY} seconds...")
|
print(f"Retrying in {RETRY_DELAY} seconds...")
|
||||||
|
|
```diff
@@ -464,7 +467,7 @@ def main():
         sys.exit(0)
 
     print(f"Using model: {args.model}")
-    commit_message = generate_commit_message(diff, args.model)
+    commit_message = generate_commit_message(diff, args.model, args.max_retries)
 
     if not commit_message:
         print("Failed to generate commit message after multiple attempts.")
```
```diff
@@ -10,6 +10,7 @@ from ...requests import DEFAULT_HEADERS, StreamSession, StreamResponse, FormData
 from ...providers.response import JsonConversation
 from ...tools.media import merge_media
 from ...image import to_bytes, is_accepted_format
+from ...errors import ResponseError
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..helper import get_last_user_message
 from ..openai.har_file import get_headers
```
```diff
@@ -139,6 +140,8 @@ class HarProvider(AsyncGeneratorProvider, ProviderModelMixin):
                 if not line.startswith(b"data: "):
                     continue
                 for content in find_str(json.loads(line[6:]), 3):
+                    if "**NETWORK ERROR DUE TO HIGH TRAFFIC." in content:
+                        raise ResponseError(content)
                     if content == '<span class="cursor"></span> ' or content == 'update':
                         continue
                     if content.endswith("▌"):
```
```diff
@@ -9,7 +9,7 @@ from ...providers.response import Reasoning, JsonConversation
 from ...requests.raise_for_status import raise_for_status
 from ...errors import ModelNotFoundError
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ..helper import get_last_user_message
+from ..helper import get_last_user_message, get_system_prompt
 from ... import debug
 
 
```
```diff
@@ -22,19 +22,19 @@ class Qwen_Qwen_3(AsyncGeneratorProvider, ProviderModelMixin):
     supports_stream = True
     supports_system_message = True
 
-    default_model = "qwen3-235b-a22b"
+    default_model = "qwen-3-235b"
     models = {
         default_model,
-        "qwen3-32b",
-        "qwen3-30b-a3b",
-        "qwen3-14b",
-        "qwen3-8b",
-        "qwen3-4b",
-        "qwen3-1.7b",
-        "qwen3-0.6b",
+        "qwen-3-32b",
+        "qwen-3-30b-a3b",
+        "qwen-3-14b",
+        "qwen-3-8b",
+        "qwen-3-4b",
+        "qwen-3-1.7b",
+        "qwen-3-0.6b",
     }
     model_aliases = {
-        "qwen-3-235b": default_model,
+        "qwen-3-235b": "qwen3-235b-a22b",
         "qwen-3-30b": "qwen3-30b-a3b",
         "qwen-3-32b": "qwen3-32b",
         "qwen-3-14b": "qwen3-14b",
```
```diff
@@ -76,12 +76,12 @@ class Qwen_Qwen_3(AsyncGeneratorProvider, ProviderModelMixin):
             'Cache-Control': 'no-cache',
         }
 
-        sys_prompt = "\n".join([message['content'] for message in messages if message['role'] == 'system'])
-        sys_prompt = sys_prompt if sys_prompt else "You are a helpful and harmless assistant."
+        system_prompt = get_system_prompt(messages)
+        system_prompt = system_prompt if system_prompt else "You are a helpful and harmless assistant."
 
         payload_join = {"data": [
             get_last_user_message(messages),
-            {"thinking_budget": thinking_budget, "model": cls.get_model(model), "sys_prompt": sys_prompt}, None, None],
+            {"thinking_budget": thinking_budget, "model": cls.get_model(model), "sys_prompt": system_prompt}, None, None],
             "event_data": None, "fn_index": 13, "trigger_id": 31, "session_hash": conversation.session_hash
         }
 
```
```diff
@@ -78,7 +78,7 @@ class LMArenaBeta(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
     label = "LMArena Beta"
     url = "https://beta.lmarena.ai"
     api_endpoint = "https://beta.lmarena.ai/api/stream/create-evaluation"
-    working = True
+    working = has_nodriver
 
     default_model = list(text_models.keys())[0]
     models = list(text_models) + list(image_models)
```
```diff
@@ -102,7 +102,8 @@ class LMArenaBeta(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
             async def callback(page):
                 while not await page.evaluate('document.cookie.indexOf("arena-auth-prod-v1") >= 0'):
                     await asyncio.sleep(1)
-                await asyncio.sleep(5)
+                while await page.evaluate('document.querySelector(\'[name="cf-turnstile-response"]\').length > 0') :
+                    await asyncio.sleep(1)
             args = await get_args_from_nodriver(cls.url, proxy=proxy, callback=callback)
         except (RuntimeError, FileNotFoundError) as e:
             debug.log(f"Nodriver is not available: {type(e).__name__}: {e}")
```
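The reworked callback waits for the auth cookie and then polls until no Cloudflare Turnstile response field remains, instead of sleeping a fixed five seconds. A standalone sketch of the same polling pattern, assuming a nodriver-style page object with an async `evaluate` method (the committed code uses `querySelector(...).length`; `querySelectorAll` is used here on the assumption that an element count is the intent):

```python
import asyncio

async def wait_for_auth(page, cookie_name: str = "arena-auth-prod-v1") -> None:
    """Poll until the auth cookie exists and the Turnstile widget has been consumed."""
    # Wait for the session cookie to appear.
    while not await page.evaluate(f'document.cookie.indexOf("{cookie_name}") >= 0'):
        await asyncio.sleep(1)
    # Wait until no Turnstile response field is left on the page.
    while await page.evaluate('document.querySelectorAll(\'[name="cf-turnstile-response"]\').length > 0'):
        await asyncio.sleep(1)
```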
```diff
@@ -257,14 +257,15 @@ class PuterJS(AsyncGeneratorProvider, ProviderModelMixin):
         "lfm-7b": "openrouter:liquid/lfm-7b",
         "lfm-3b": "openrouter:liquid/lfm-3b",
         "lfm-40b": "openrouter:liquid/lfm-40b",
+    }
 
-    }
     @classmethod
-    def get_models(cls) -> list[str]:
+    def get_models(cls, api_key: str = None) -> list[str]:
         if not cls.models:
             try:
                 url = "https://api.puter.com/puterai/chat/models/"
                 cls.models = requests.get(url).json().get("models", [])
+                cls.models = [model for model in cls.models if model not in ["abuse", "costly", "fake"]]
             except Exception as e:
                 debug.log(f"PuterJS: Failed to fetch models from API: {e}")
                 cls.models = list(cls.model_aliases.keys())
```
```diff
@@ -51,14 +51,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
     def get_models(cls):
         if not cls.models:
             try:
-                text = requests.get(cls.url).text
-                text = re.search(r'models:(\[.+?\]),oldModels:', text).group(1)
-                text = re.sub(r',parameters:{[^}]+?}', '', text)
-                text = text.replace('void 0', 'null')
-                def add_quotation_mark(match):
-                    return f'{match.group(1)}"{match.group(2)}":'
-                text = re.sub(r'([{,])([A-Za-z0-9_]+?):', add_quotation_mark, text)
-                models = json.loads(text)
+                models = requests.get(f"{cls.url}/api/v2/models").json().get("json")
                 cls.text_models = [model["id"] for model in models]
                 cls.models = cls.text_models + cls.image_models
                 cls.vision_models = [model["id"] for model in models if model["multimodal"]]
```
```diff
@@ -43,7 +43,7 @@ class Api:
     def get_model_data(provider: ProviderModelMixin, model: str):
         return {
             "model": model,
-            "label": model.split(":")[-1] if provider.__name__ == "AnyProvider" else model,
+            "label": model.split(":")[-1] if provider.__name__ == "AnyProvider" and not model.startswith("openrouter:") else model,
             "default": model == provider.default_model,
             "vision": model in provider.vision_models,
             "audio": False if provider.audio_models is None else model in provider.audio_models,
```
```diff
@@ -10,22 +10,29 @@ from ..Provider.needs_auth import OpenaiChat, CopilotAccount
 from ..Provider.hf_space import HuggingSpace
 from ..Provider import __map__
 from ..Provider import Cloudflare, Gemini, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI, PuterJS
-from ..Provider import Microsoft_Phi_4_Multimodal, DeepInfraChat, Blackbox, OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena
+from ..Provider import Microsoft_Phi_4_Multimodal, DeepInfraChat, Blackbox, OIVSCodeSer2, OIVSCodeSer0501, TeachAnything
+from ..Provider import Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena, LMArenaBeta
 from ..Provider import EdgeTTS, gTTS, MarkItDown, OpenAIFM
 from ..Provider import HarProvider, HuggingFace, HuggingFaceMedia
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from .. import Provider
 from .. import models
 
-MAIN_PROVIERS = [
-    OpenaiChat, Cloudflare, HarProvider, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox,
-    OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs, LegacyLMArena,
-    HuggingSpace, LambdaChat, CopilotAccount, PollinationsAI, DeepInfraChat, HuggingFace, HuggingFaceMedia
+PROVIERS_LIST_1 = [
+    OpenaiChat, PollinationsAI, Cloudflare, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox, OpenAIFM,
+    OIVSCodeSer2, OIVSCodeSer0501, TeachAnything, Together, WeWordle, Yqcloud, Chatai, Free2GPT, ARTA, ImageLabs,
+    HarProvider, LegacyLMArena, LMArenaBeta, LambdaChat, CopilotAccount, DeepInfraChat,
+    HuggingSpace, HuggingFace, HuggingFaceMedia, PuterJS, Together
 ]
 
-SPECIAL_PROVIDERS = [OpenaiChat, CopilotAccount, PollinationsAI, HuggingSpace, Cloudflare, PerplexityLabs, Gemini, Grok, LegacyLMArena, ARTA]
+PROVIERS_LIST_2 = [
+    OpenaiChat, CopilotAccount, PollinationsAI, PerplexityLabs, Gemini, Grok, ARTA
+]
 
-SPECIAL_PROVIDERS2 = [HarProvider, LambdaChat, DeepInfraChat, HuggingFace, HuggingFaceMedia, PuterJS]
+PROVIERS_LIST_3 = [
+    HarProvider, LambdaChat, DeepInfraChat, HuggingFace, HuggingFaceMedia, LegacyLMArena, LMArenaBeta,
+    PuterJS, Together, Cloudflare, HuggingSpace
+]
 
 LABELS = {
     "default": "Default",
```
```diff
@@ -33,15 +40,21 @@ LABELS = {
     "llama": "Meta: LLaMA",
     "deepseek": "DeepSeek",
     "qwen": "Alibaba: Qwen",
-    "google": "Google: Gemini / Gemma / Bard",
+    "google": "Google: Gemini / Gemma",
     "grok": "xAI: Grok",
     "claude": "Anthropic: Claude",
     "command": "Cohere: Command",
     "phi": "Microsoft: Phi / WizardLM",
     "mistral": "Mistral",
     "PollinationsAI": "Pollinations AI",
+    "ARTA": "ARTA",
+    "voices": "Voices",
     "perplexity": "Perplexity Labs",
     "openrouter": "OpenRouter",
+    "glm": "GLM",
+    "tulu": "Tulu",
+    "reka": "Reka",
+    "hermes": "Hermes",
     "video": "Video Generation",
     "image": "Image Generation",
     "other": "Other Models",
```
```diff
@@ -65,31 +78,35 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
                 continue # Already added
 
             added = False
-            # Check for PollinationsAI models (with prefix)
-            if model.startswith("PollinationsAI:"):
-                groups["PollinationsAI"].append(model)
+            # Check for models with prefix
+            start = model.split(":")[0]
+            if start in ("PollinationsAI", "ARTA", "openrouter"):
+                submodel = model.split(":", maxsplit=1)[1]
+                if submodel in OpenAIFM.voices or submodel in PollinationsAI.audio_models[PollinationsAI.default_audio_model]:
+                    groups["voices"].append(submodel)
+                else:
+                    groups[start].append(model)
                 added = True
             # Check for Mistral company models specifically
             elif model.startswith("mistral") and not any(x in model for x in ["dolphin", "nous", "openhermes"]):
                 groups["mistral"].append(model)
                 added = True
-            elif model.startswith(("mistralai/", "mixtral-", "pixtral-", "ministral-", "codestral-")):
+            elif model.startswith(("pixtral-", "ministral-", "codestral")) or "mistral" in model or "mixtral" in model:
                 groups["mistral"].append(model)
                 added = True
             # Check for Qwen models
-            elif model.startswith(("qwen", "Qwen/", "qwq", "qvq")):
+            elif model.startswith(("qwen", "Qwen", "qwq", "qvq")):
                 groups["qwen"].append(model)
                 added = True
             # Check for Microsoft Phi models
-            elif model.startswith(("phi-", "microsoft/")):
+            elif model.startswith(("phi-", "microsoft/")) or "wizardlm" in model.lower():
                 groups["phi"].append(model)
                 added = True
             # Check for Meta LLaMA models
             elif model.startswith(("llama-", "meta-llama/", "llama2-", "llama3")):
                 groups["llama"].append(model)
                 added = True
-            elif model == "meta-ai":
+            elif model == "meta-ai" or model.startswith("codellama-"):
                 groups["llama"].append(model)
                 added = True
             # Check for Google models
```
```diff
@@ -100,14 +117,6 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             elif model.startswith(("command-", "CohereForAI/", "c4ai-command")):
                 groups["command"].append(model)
                 added = True
-            # Check for Claude models
-            elif model.startswith("claude-"):
-                groups["claude"].append(model)
-                added = True
-            # Check for Grok models
-            elif model.startswith("grok-"):
-                groups["grok"].append(model)
-                added = True
             # Check for DeepSeek models
             elif model.startswith(("deepseek-", "janus-")):
                 groups["deepseek"].append(model)
```
```diff
@@ -116,34 +125,27 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             elif model.startswith(("sonar", "sonar-", "pplx-")) or model == "r1-1776":
                 groups["perplexity"].append(model)
                 added = True
-            # Check for OpenAI models
-            elif model.startswith(("gpt-", "chatgpt-", "o1", "o1-", "o3-", "o4-")) or model in ("auto", "dall-e-3", "searchgpt"):
-                groups["openai"].append(model)
+            # Check for image models - UPDATED to include flux check
+            elif model in cls.image_models:
+                groups["image"].append(model)
                 added = True
-            # Check for openrouter models
-            elif model.startswith(("openrouter:")):
-                groups["openrouter"].append(model)
+            # Check for OpenAI models
+            elif model.startswith(("gpt-", "chatgpt-", "o1", "o1-", "o3-", "o4-")) or model in ("auto", "searchgpt"):
+                groups["openai"].append(model)
                 added = True
             # Check for video models
             elif model in cls.video_models:
                 groups["video"].append(model)
                 added = True
-            # Check for image models - UPDATED to include flux check
-            elif model in cls.image_models or "flux" in model.lower() or "stable-diffusion" in model.lower() or "sdxl" in model.lower() or "gpt-image" in model.lower():
-                groups["image"].append(model)
-                added = True
+            if not added:
+                for group in LABELS.keys():
+                    if model == group or group in model:
+                        groups[group].append(model)
+                        added = True
+                        break
             # If not categorized, check for special cases then put in other
             if not added:
-                # CodeLlama is Meta's model
-                if model.startswith("codellama-"):
-                    groups["llama"].append(model)
-                # WizardLM is Microsoft's
-                elif "wizardlm" in model.lower():
-                    groups["phi"].append(model)
-                else:
-                    groups["other"].append(model)
+                groups["other"].append(model)
 
         return [
             {"group": LABELS[group], "models": names} for group, names in groups.items()
         ]
```
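The removed per-model special cases (Claude, Grok, CodeLlama, WizardLM, the flux/sdxl substring checks) are largely subsumed by the new generic fallback, which scans the `LABELS` keys for a substring match before filing the model under "other". A self-contained sketch of that fallback, with a trimmed label set for illustration:

```python
from collections import defaultdict

LABELS = {"claude": "Anthropic: Claude", "grok": "xAI: Grok",
          "glm": "GLM", "other": "Other Models"}  # trimmed for illustration

def group_model(model: str, groups: dict) -> None:
    added = False
    for group in LABELS.keys():
        if model == group or group in model:  # exact key or substring match
            groups[group].append(model)
            added = True
            break
    if not added:
        groups["other"].append(model)

groups = defaultdict(list)
for name in ("claude-3.5-sonnet", "grok-2", "glm-4", "mystery-model"):
    group_model(name, groups)
print(dict(groups))
# {'claude': ['claude-3.5-sonnet'], 'grok': ['grok-2'], 'glm': ['glm-4'], 'other': ['mystery-model']}
```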
```diff
@@ -174,31 +176,19 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
         all_models = [cls.default_model] + list(model_with_providers.keys())
 
         # Process special providers
-        for provider in SPECIAL_PROVIDERS:
+        for provider in PROVIERS_LIST_2:
             provider: ProviderType = provider
             if not provider.working or provider.get_parent() in ignored:
                 continue
             if provider == CopilotAccount:
                 all_models.extend(list(provider.model_aliases.keys()))
-            elif provider == PollinationsAI:
+            elif provider in [PollinationsAI, ARTA]:
                 all_models.extend([f"{provider.__name__}:{model}" for model in provider.get_models() if model not in all_models])
                 cls.audio_models.update({f"{provider.__name__}:{model}": [] for model in provider.get_models() if model in provider.audio_models})
                 cls.image_models.extend([f"{provider.__name__}:{model}" for model in provider.get_models() if model in provider.image_models])
                 cls.vision_models.extend([f"{provider.__name__}:{model}" for model in provider.get_models() if model in provider.vision_models])
-                all_models.extend(list(provider.model_aliases.keys()))
-            elif provider == LegacyLMArena:
-                # Add models from LegacyLMArena
-                provider_models = provider.get_models()
-                all_models.extend(provider_models)
-                # Also add model aliases
-                all_models.extend(list(provider.model_aliases.keys()))
-                # Add vision models
-                cls.vision_models.extend(provider.vision_models)
-            elif provider == ARTA:
-                # Add all ARTA models as image models
-                arta_models = provider.get_models()
-                all_models.extend(arta_models)
-                cls.image_models.extend(arta_models)
+                if provider == PollinationsAI:
+                    all_models.extend(list(provider.model_aliases.keys()))
             else:
                 all_models.extend(provider.get_models())
```
```diff
@@ -215,9 +205,9 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             name = name.split("/")[-1].split(":")[0].lower()
             # Date patterns
             name = re.sub(r'-\d{4}-\d{2}-\d{2}', '', name)
-            name = re.sub(r'-\d{8}', '', name)
-            name = re.sub(r'-\d{4}', '', name)
+            name = re.sub(r'-\d{3,8}', '', name)
             name = re.sub(r'-\d{2}-\d{2}', '', name)
+            name = re.sub(r'-[0-9a-f]{8}$', '', name)
             # Version patterns
             name = re.sub(r'-(instruct|chat|preview|experimental|v\d+|fp8|bf16|hf)$', '', name)
             # Other replacements
```
```diff
@@ -226,23 +216,28 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             name = name.replace("meta-llama-", "llama-")
             name = name.replace("llama3", "llama-3")
             name = name.replace("flux.1-", "flux-")
+            name = name.replace("-free", "")
+            name = name.replace("qwen1-", "qwen-1")
+            name = name.replace("qwen2-", "qwen-2")
+            name = name.replace("qwen3-", "qwen-3")
+            name = name.replace("stable-diffusion-3.5-large", "sd-3.5-large")
             return name
 
         # Process HAR providers
-        for provider in SPECIAL_PROVIDERS2:
+        for provider in PROVIERS_LIST_3:
             if not provider.working or provider.get_parent() in ignored:
                 continue
             new_models = provider.get_models()
             if provider == HuggingFaceMedia:
                 new_models = provider.video_models
-            # Add original models too, not just cleaned names
-            all_models.extend(new_models)
-            model_map = {model if model.startswith("openrouter:") else clean_name(model): model for model in new_models}
-            if not provider.model_aliases:
-                provider.model_aliases = {}
-            provider.model_aliases.update(model_map)
+            model_map = {}
+            for model in new_models:
+                clean_value = model if model.startswith("openrouter:") else clean_name(model)
+                if clean_value not in model_map:
+                    model_map[clean_value] = model
+            if provider.model_aliases:
+                model_map.update(provider.model_aliases)
+            provider.model_aliases = model_map
             all_models.extend(list(model_map.keys()))
 
         # Update special model lists with both original and cleaned names
```
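The extended `clean_name` normalization now also strips a `-free` suffix, hyphenates Qwen version prefixes, and shortens the Stable Diffusion 3.5 name. A worked example of the cleaning chain, re-implemented here for illustration from the patterns shown above (the real method lives inside `AnyProvider`):

```python
import re

def clean_name(name: str) -> str:
    """Illustrative re-implementation of the cleaning chain from this commit."""
    name = name.split("/")[-1].split(":")[0].lower()
    # Date patterns
    name = re.sub(r'-\d{4}-\d{2}-\d{2}', '', name)
    name = re.sub(r'-\d{3,8}', '', name)
    name = re.sub(r'-\d{2}-\d{2}', '', name)
    name = re.sub(r'-[0-9a-f]{8}$', '', name)
    # Version patterns
    name = re.sub(r'-(instruct|chat|preview|experimental|v\d+|fp8|bf16|hf)$', '', name)
    # Other replacements
    name = name.replace("meta-llama-", "llama-")
    name = name.replace("llama3", "llama-3")
    name = name.replace("flux.1-", "flux-")
    name = name.replace("-free", "")
    name = name.replace("qwen1-", "qwen-1")
    name = name.replace("qwen2-", "qwen-2")
    name = name.replace("qwen3-", "qwen-3")
    name = name.replace("stable-diffusion-3.5-large", "sd-3.5-large")
    return name

print(clean_name("gpt-4o-2024-08-06"))                       # gpt-4o
print(clean_name("meta-llama/Meta-Llama-3.1-8B-Instruct"))   # llama-3.1-8b
print(clean_name("stabilityai/stable-diffusion-3.5-large"))  # sd-3.5-large
```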
```diff
@@ -262,7 +257,10 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
             cls.audio_models.update(provider.audio_models)
 
         # Update model counts
-        cls.models_count.update({model: all_models.count(model) for model in all_models if all_models.count(model) > cls.models_count.get(model, 0)})
+        for model in all_models:
+            count = all_models.count(model)
+            if count > cls.models_count.get(model, 0):
+                cls.models_count.update({model: count})
 
         # Deduplicate and store
         cls.models_storage[ignored_key] = list(dict.fromkeys([model if model else cls.default_model for model in all_models]))
```
```diff
@@ -320,11 +318,15 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
         if isinstance(api_key, dict):
             for provider in api_key:
                 if api_key.get(provider):
-                    if provider in __map__ and __map__[provider] not in MAIN_PROVIERS:
+                    if provider in __map__ and __map__[provider] not in PROVIERS_LIST_1:
                         extra_providers.append(__map__[provider])
-        for provider in MAIN_PROVIERS + extra_providers:
+        for provider in PROVIERS_LIST_1 + extra_providers:
             if provider.working:
-                if not model or model in provider.get_models() or model in provider.model_aliases:
+                provider_api_key = api_key
+                if isinstance(api_key, dict):
+                    provider_api_key = api_key.get(provider.get_parent())
+                provider_models = provider.get_models(api_key=provider_api_key) if provider_api_key else provider.get_models()
+                if not model or model in provider_models or provider.model_aliases and model in provider.model_aliases:
                     providers.append(provider)
         if model in models.__models__:
             for provider in models.__models__[model][1]:
```
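When `api_key` arrives as a dict, it is now resolved through each provider's parent name before the model lookup, so authenticated providers can report a larger model list. A minimal sketch of that resolution step, with the provider reduced to the two attributes the logic touches (class and key names are illustrative):

```python
class FakeProvider:
    """Reduced stand-in exposing only what the resolution logic needs."""
    def __init__(self, parent: str, models: list):
        self.parent = parent
        self._models = models
    def get_parent(self) -> str:
        return self.parent
    def get_models(self, api_key: str = None) -> list:
        # A real provider may return a larger, authenticated list here.
        return self._models

api_key = {"PuterJS": "sk-example", "HuggingFace": None}  # illustrative keys
provider = FakeProvider("PuterJS", ["gpt-4o", "claude-3.5-sonnet"])

provider_api_key = api_key.get(provider.get_parent()) if isinstance(api_key, dict) else api_key
provider_models = provider.get_models(api_key=provider_api_key) if provider_api_key else provider.get_models()
print(provider_models)  # ['gpt-4o', 'claude-3.5-sonnet']
```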
@ -334,7 +336,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
|
|
||||||
if len(providers) == 0:
|
if len(providers) == 0:
|
||||||
raise ModelNotFoundError(f"AnyProvider: Model {model} not found in any provider.")
|
raise ModelNotFoundError(f"AnyProvider: Model {model} not found in any provider.")
|
||||||
|
|
||||||
async for chunk in IterListProvider(providers).create_async_generator(
|
async for chunk in IterListProvider(providers).create_async_generator(
|
||||||
model,
|
model,
|
||||||
messages,
|
messages,
|
||||||
|
|
|
||||||
|
|
```diff
@@ -367,20 +367,19 @@ class ProviderModelMixin:
     @classmethod
     def get_models(cls, **kwargs) -> list[str]:
         if not cls.models and cls.default_model is not None:
-            return [cls.default_model]
+            cls.models = [cls.default_model]
         return cls.models
 
     @classmethod
     def get_model(cls, model: str, **kwargs) -> str:
         if not model and cls.default_model is not None:
             model = cls.default_model
-        elif model in cls.model_aliases:
+        if model in cls.model_aliases:
             model = cls.model_aliases[model]
         else:
             if model not in cls.get_models(**kwargs) and cls.models:
-                raise ModelNotFoundError(f"Model is not supported: {model} in: {cls.__name__} Valid models: {cls.models}")
+                raise ModelNotFoundError(f"Model not found: {model} in: {cls.__name__} Valid models: {cls.models}")
         cls.last_model = model
-        debug.last_model = model
         return model
 
 class RaiseErrorMixin():
```
```diff
@@ -55,16 +55,16 @@ class IterListProvider(BaseRetryProvider):
             self.last_provider = provider
             if not model:
                 model = getattr(provider, "default_model", None)
-            model = provider.model_aliases.get(model, model) if hasattr(provider, "model_aliases") else model
-            debug.log(f"Using {provider.__name__} provider with model {model}")
-            yield ProviderInfo(**provider.get_dict(), model=model)
+            alias = provider.model_aliases.get(model, model) if hasattr(provider, "model_aliases") else model
+            debug.log(f"Using {provider.__name__} provider with model {alias}")
+            yield ProviderInfo(**provider.get_dict(), model=alias)
             extra_body = kwargs.copy()
             if isinstance(api_key, dict):
                 api_key = api_key.get(provider.get_parent())
             if api_key:
                 extra_body["api_key"] = api_key
             try:
-                response = provider.create_function(model, messages, stream=stream, **extra_body)
+                response = provider.create_function(alias, messages, stream=stream, **extra_body)
                 for chunk in response:
                     if chunk:
                         yield chunk
```
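Binding the resolved name to a separate `alias` keeps the caller's `model` argument intact across loop iterations; previously the first provider's alias overwrote `model`, so a retry against the next provider could look up the wrong name. A condensed, runnable illustration of the fixed lookup (the two provider classes are fakes):

```python
class P1:
    model_aliases = {"qwen-3-235b": "qwen3-235b-a22b"}

class P2:
    model_aliases = {}

def resolve_alias(provider, model: str) -> str:
    """Resolve a provider-specific alias without mutating the requested model name."""
    aliases = getattr(provider, "model_aliases", None) or {}
    return aliases.get(model, model)

model = "qwen-3-235b"
for provider in (P1, P2):
    alias = resolve_alias(provider, model)
    print(provider.__name__, alias)
# P1 qwen3-235b-a22b
# P2 qwen-3-235b   <- still the original request, not P1's alias
```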
```diff
@@ -48,6 +48,14 @@ def get_github_version(repo: str) -> str:
     except requests.RequestException as e:
         raise VersionNotFoundError(f"Failed to get GitHub release version: {e}")
 
+def get_git_version() -> str:
+    # Read from git repository
+    try:
+        command = ["git", "describe", "--tags", "--abbrev=0"]
+        return check_output(command, text=True, stderr=PIPE).strip()
+    except CalledProcessError:
+        return None
+
 class VersionUtils:
     """
     Utility class for managing and comparing package versions of 'g4f'.
```
```diff
@@ -78,14 +86,7 @@ class VersionUtils:
         if version:
             return version
 
-        # Read from git repository
-        try:
-            command = ["git", "describe", "--tags", "--abbrev=0"]
-            return check_output(command, text=True, stderr=PIPE).strip()
-        except CalledProcessError:
-            pass
-
-        return None
+        return get_git_version()
 
     @property
     def latest_version(self) -> str:
```
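With the helper extracted, version resolution falls back to `git describe` only after the earlier sources fail. A standalone sketch of the extracted helper and its use, assuming the same `subprocess` imports `g4f/version.py` already carries:

```python
from subprocess import check_output, CalledProcessError, PIPE

def get_git_version() -> str:
    """Return the latest tag from the local git repository, or None outside a checkout."""
    try:
        command = ["git", "describe", "--tags", "--abbrev=0"]
        return check_output(command, text=True, stderr=PIPE).strip()
    except CalledProcessError:
        return None

if __name__ == "__main__":
    print(get_git_version() or "not in a git checkout")
```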