Improve tools support in OpenaiTemplate and GeminiPro

Update models in DDG, PerplexityLabs, Gemini
Fix issues with curl_cffi in new versions
This commit is contained in:
hlohaus 2025-02-21 04:36:54 +01:00
parent c3ed6d0f8f
commit e53483d85b
33 changed files with 300 additions and 172 deletions

1
.gitignore vendored
View file

@ -66,3 +66,4 @@ bench.py
to-reverse.txt to-reverse.txt
g4f/Provider/OpenaiChat2.py g4f/Provider/OpenaiChat2.py
generated_images/ generated_images/
projects/windows/

View file

@ -1,5 +1,9 @@
import unittest import unittest
import g4f.debug
g4f.debug.version_check = False
from .asyncio import * from .asyncio import *
from .backend import * from .backend import *
from .main import * from .main import *

View file

@ -6,8 +6,12 @@ from g4f.errors import VersionNotFoundError
DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}] DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
class TestGetLastProvider(unittest.TestCase): class TestGetLastProvider(unittest.TestCase):
def test_get_latest_version(self): def test_get_latest_version(self):
current_version = g4f.version.utils.current_version current_version = g4f.version.utils.current_version
if current_version is not None: if current_version is not None:
self.assertIsInstance(g4f.version.utils.current_version, str) self.assertIsInstance(g4f.version.utils.current_version, str)
self.assertIsInstance(g4f.version.utils.latest_version, str) try:
self.assertIsInstance(g4f.version.utils.latest_version, str)
except VersionNotFoundError:
pass

View file

@ -42,7 +42,7 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin):
"gpt-4": "gpt-4o-mini", "gpt-4": "gpt-4o-mini",
"claude-3-haiku": "claude-3-haiku-20240307", "claude-3-haiku": "claude-3-haiku-20240307",
"llama-3.3-70b": "meta-llama/Llama-3.3-70B-Instruct-Turbo", "llama-3.3-70b": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
"mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1", "mixtral-8x7b": "mistralai/Mistral-Small-24B-Instruct-2501",
} }
last_request_time = 0 last_request_time = 0

View file

@ -5,7 +5,8 @@ import json
from ..typing import AsyncResult, Messages from ..typing import AsyncResult, Messages
from ..requests import StreamSession, raise_for_status from ..requests import StreamSession, raise_for_status
from ..providers.response import FinishReason from ..errors import ResponseError
from ..providers.response import FinishReason, Sources
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
API_URL = "https://www.perplexity.ai/socket.io/" API_URL = "https://www.perplexity.ai/socket.io/"
@ -15,10 +16,11 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://labs.perplexity.ai" url = "https://labs.perplexity.ai"
working = True working = True
default_model = "sonar-pro" default_model = "r1-1776"
models = [ models = [
"sonar",
default_model, default_model,
"sonar-pro",
"sonar",
"sonar-reasoning", "sonar-reasoning",
"sonar-reasoning-pro", "sonar-reasoning-pro",
] ]
@ -32,19 +34,10 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
headers = { headers = {
"User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:121.0) Gecko/20100101 Firefox/121.0",
"Accept": "*/*",
"Accept-Language": "de,en-US;q=0.7,en;q=0.3",
"Accept-Encoding": "gzip, deflate, br",
"Origin": cls.url, "Origin": cls.url,
"Connection": "keep-alive",
"Referer": f"{cls.url}/", "Referer": f"{cls.url}/",
"Sec-Fetch-Dest": "empty",
"Sec-Fetch-Mode": "cors",
"Sec-Fetch-Site": "same-site",
"TE": "trailers",
} }
async with StreamSession(headers=headers, proxies={"all": proxy}) as session: async with StreamSession(headers=headers, proxy=proxy, impersonate="chrome") as session:
t = format(random.getrandbits(32), "08x") t = format(random.getrandbits(32), "08x")
async with session.get( async with session.get(
f"{API_URL}?EIO=4&transport=polling&t={t}" f"{API_URL}?EIO=4&transport=polling&t={t}"
@ -60,17 +53,22 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
) as response: ) as response:
await raise_for_status(response) await raise_for_status(response)
assert await response.text() == "OK" assert await response.text() == "OK"
async with session.get(
f"{API_URL}?EIO=4&transport=polling&t={t}&sid={sid}",
data=post_data
) as response:
await raise_for_status(response)
assert (await response.text()).startswith("40")
async with session.ws_connect(f"{WS_URL}?EIO=4&transport=websocket&sid={sid}", autoping=False) as ws: async with session.ws_connect(f"{WS_URL}?EIO=4&transport=websocket&sid={sid}", autoping=False) as ws:
await ws.send_str("2probe") await ws.send_str("2probe")
assert(await ws.receive_str() == "3probe") assert(await ws.receive_str() == "3probe")
await ws.send_str("5") await ws.send_str("5")
assert(await ws.receive_str())
assert(await ws.receive_str() == "6") assert(await ws.receive_str() == "6")
message_data = { message_data = {
"version": "2.16", "version": "2.18",
"source": "default", "source": "default",
"model": model, "model": model,
"messages": messages "messages": messages,
} }
await ws.send_str("42" + json.dumps(["perplexity_labs", message_data])) await ws.send_str("42" + json.dumps(["perplexity_labs", message_data]))
last_message = 0 last_message = 0
@ -82,12 +80,15 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
await ws.send_str("3") await ws.send_str("3")
continue continue
try: try:
if last_message == 0 and model == cls.default_model:
yield "<think>"
data = json.loads(message[2:])[1] data = json.loads(message[2:])[1]
yield data["output"][last_message:] yield data["output"][last_message:]
last_message = len(data["output"]) last_message = len(data["output"])
if data["final"]: if data["final"]:
if data["citations"]:
yield Sources(data["citations"])
yield FinishReason("stop") yield FinishReason("stop")
break break
except Exception as e: except Exception as e:
print(f"Error processing message: {message} - {e}") raise ResponseError(f"Message: {message}") from e
raise RuntimeError(f"Message: {message}") from e

View file

@ -122,9 +122,9 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
except ModelNotFoundError: except ModelNotFoundError:
if model not in cls.image_models: if model not in cls.image_models:
raise raise
if not cache and seed is None: if not cache and seed is None:
seed = random.randint(0, 10000) seed = random.randint(1000, 999999)
if model in cls.image_models: if model in cls.image_models:
async for chunk in cls._generate_image( async for chunk in cls._generate_image(
@ -182,9 +182,10 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
} }
params = {k: v for k, v in params.items() if v is not None} params = {k: v for k, v in params.items() if v is not None}
query = "&".join(f"{k}={quote_plus(v)}" for k, v in params.items()) query = "&".join(f"{k}={quote_plus(v)}" for k, v in params.items())
url = f"{cls.image_api_endpoint}prompt/{quote_plus(prompt)}?{query}" prefix = f"{model}_{seed}" if seed is not None else model
url = f"{cls.image_api_endpoint}prompt/{prefix}_{quote_plus(prompt)}?{query}"
yield ImagePreview(url, prompt) yield ImagePreview(url, prompt)
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session: async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
async with session.get(url, allow_redirects=True) as response: async with session.get(url, allow_redirects=True) as response:
await raise_for_status(response) await raise_for_status(response)

View file

@ -39,14 +39,15 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
default_model = default_model default_model = default_model
model_aliases = model_aliases model_aliases = model_aliases
image_models = image_models image_models = image_models
text_models = fallback_models
@classmethod @classmethod
def get_models(cls): def get_models(cls):
if not cls.models: if not cls.models:
try: try:
text = requests.get(cls.url).text text = requests.get(cls.url).text
text = re.sub(r',parameters:{[^}]+?}', '', text)
text = re.search(r'models:(\[.+?\]),oldModels:', text).group(1) text = re.search(r'models:(\[.+?\]),oldModels:', text).group(1)
text = re.sub(r',parameters:{[^}]+?}', '', text)
text = text.replace('void 0', 'null') text = text.replace('void 0', 'null')
def add_quotation_mark(match): def add_quotation_mark(match):
return f'{match.group(1)}"{match.group(2)}":' return f'{match.group(1)}"{match.group(2)}":'
@ -56,7 +57,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
cls.models = cls.text_models + cls.image_models cls.models = cls.text_models + cls.image_models
cls.vision_models = [model["id"] for model in models if model["multimodal"]] cls.vision_models = [model["id"] for model in models if model["multimodal"]]
except Exception as e: except Exception as e:
debug.log(f"HuggingChat: Error reading models: {type(e).__name__}: {e}") debug.error(f"{cls.__name__}: Error reading models: {type(e).__name__}: {e}")
cls.models = [*fallback_models] cls.models = [*fallback_models]
return cls.models return cls.models

View file

@ -10,8 +10,8 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_p
from ...errors import ModelNotFoundError, ModelNotSupportedError, ResponseError from ...errors import ModelNotFoundError, ModelNotSupportedError, ResponseError
from ...requests import StreamSession, raise_for_status from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ImageResponse from ...providers.response import FinishReason, ImageResponse
from ..helper import format_image_prompt from ..helper import format_image_prompt, get_last_user_message
from .models import default_model, default_image_model, model_aliases, fallback_models from .models import default_model, default_image_model, model_aliases, fallback_models, image_models
from ... import debug from ... import debug
class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin): class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
@ -22,26 +22,26 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
default_model = default_model default_model = default_model
default_image_model = default_image_model default_image_model = default_image_model
model_aliases = model_aliases model_aliases = model_aliases
image_models = image_models
@classmethod @classmethod
def get_models(cls) -> list[str]: def get_models(cls) -> list[str]:
if not cls.models: if not cls.models:
models = fallback_models.copy() models = fallback_models.copy()
url = "https://huggingface.co/api/models?inference=warm&pipeline_tag=text-generation" url = "https://huggingface.co/api/models?inference=warm&pipeline_tag=text-generation"
response = requests.get(url) response = requests.get(url)
response.raise_for_status() if response.ok:
extra_models = [model["id"] for model in response.json()] extra_models = [model["id"] for model in response.json()]
extra_models.sort() extra_models.sort()
models.extend([model for model in extra_models if model not in models]) models.extend([model for model in extra_models if model not in models])
if not cls.image_models: url = "https://huggingface.co/api/models?pipeline_tag=text-to-image"
url = "https://huggingface.co/api/models?pipeline_tag=text-to-image" response = requests.get(url)
response = requests.get(url) if response.ok:
response.raise_for_status() cls.image_models = [model["id"] for model in response.json() if model.get("trendingScore", 0) >= 20]
cls.image_models = [model["id"] for model in response.json() if model.get("trendingScore", 0) >= 20] cls.image_models.sort()
cls.image_models.sort() models.extend([model for model in cls.image_models if model not in models])
models.extend([model for model in cls.image_models if model not in models]) cls.models = models
cls.models = models return cls.models
return cls.models
@classmethod @classmethod
async def create_async_generator( async def create_async_generator(
@ -57,6 +57,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
prompt: str = None, prompt: str = None,
action: str = None, action: str = None,
extra_data: dict = {}, extra_data: dict = {},
seed: int = None,
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
try: try:
@ -104,7 +105,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
if pipeline_tag == "text-to-image": if pipeline_tag == "text-to-image":
stream = False stream = False
inputs = format_image_prompt(messages, prompt) inputs = format_image_prompt(messages, prompt)
payload = {"inputs": inputs, "parameters": {"seed": random.randint(0, 2**32), **extra_data}} payload = {"inputs": inputs, "parameters": {"seed": random.randint(0, 2**32) if seed is None else seed, **extra_data}}
elif pipeline_tag in ("text-generation", "image-text-to-text"): elif pipeline_tag in ("text-generation", "image-text-to-text"):
model_type = None model_type = None
if "config" in model_data and "model_type" in model_data["config"]: if "config" in model_data and "model_type" in model_data["config"]:
@ -116,11 +117,13 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
if len(messages) > 6: if len(messages) > 6:
messages = messages[:3] + messages[-3:] messages = messages[:3] + messages[-3:]
else: else:
messages = [m for m in messages if m["role"] == "system"] + [messages[-1]] messages = [m for m in messages if m["role"] == "system"] + [get_last_user_message(messages)]
inputs = get_inputs(messages, model_data, model_type, do_continue) inputs = get_inputs(messages, model_data, model_type, do_continue)
debug.log(f"New len: {len(inputs)}") debug.log(f"New len: {len(inputs)}")
if model_type == "gpt2" and max_tokens >= 1024: if model_type == "gpt2" and max_tokens >= 1024:
params["max_new_tokens"] = 512 params["max_new_tokens"] = 512
if seed is not None:
params["seed"] = seed
payload = {"inputs": inputs, "parameters": params, "stream": stream} payload = {"inputs": inputs, "parameters": params, "stream": stream}
else: else:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}") raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}")

View file

@ -48,7 +48,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
except Exception as e: except Exception as e:
if is_started: if is_started:
raise e raise e
debug.log(f"Inference failed: {e.__class__.__name__}: {e}") debug.error(f"{cls.__name__} {type(e).__name__}; {e}")
if not cls.image_models: if not cls.image_models:
cls.get_models() cls.get_models()
if model in cls.image_models: if model in cls.image_models:

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import json import json
import uuid import uuid
import re import re
import time import random
from datetime import datetime, timezone, timedelta from datetime import datetime, timezone, timedelta
import urllib.parse import urllib.parse
@ -88,7 +88,7 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
prompt = format_prompt(messages) if prompt is None and conversation is None else prompt prompt = format_prompt(messages) if prompt is None and conversation is None else prompt
prompt = format_image_prompt(messages, prompt) prompt = format_image_prompt(messages, prompt)
if seed is None: if seed is None:
seed = int(time.time()) seed = random.randint(1000, 999999)
session_hash = generate_session_hash() if conversation is None else getattr(conversation, "session_hash") session_hash = generate_session_hash() if conversation is None else getattr(conversation, "session_hash")
async with StreamSession(proxy=proxy, impersonate="chrome") as session: async with StreamSession(proxy=proxy, impersonate="chrome") as session:

View file

@ -32,9 +32,7 @@ class CopilotAccount(AsyncAuthedProvider, Copilot):
except NoValidHarFileError as h: except NoValidHarFileError as h:
debug.log(f"Copilot: {h}") debug.log(f"Copilot: {h}")
if has_nodriver: if has_nodriver:
login_url = os.environ.get("G4F_LOGIN_URL") yield RequestLogin(cls.label, os.environ.get("G4F_LOGIN_URL", ""))
if login_url:
yield RequestLogin(cls.label, login_url)
Copilot._access_token, Copilot._cookies = await get_access_token_and_cookies(cls.url, proxy) Copilot._access_token, Copilot._cookies = await get_access_token_and_cookies(cls.url, proxy)
else: else:
raise h raise h

View file

@ -65,7 +65,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
default_image_model = default_model default_image_model = default_model
default_vision_model = default_model default_vision_model = default_model
image_models = [default_image_model] image_models = [default_image_model]
models = [default_model, "gemini-1.5-flash", "gemini-1.5-pro"] models = [default_model, "gemini-2.0"]
synthesize_content_type = "audio/vnd.wav" synthesize_content_type = "audio/vnd.wav"
@ -179,7 +179,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
yield Conversation(response_part[1][0], response_part[1][1], response_part[4][0][0]) yield Conversation(response_part[1][0], response_part[1][1], response_part[4][0][0])
content = response_part[4][0][1][0] content = response_part[4][0][1][0]
except (ValueError, KeyError, TypeError, IndexError) as e: except (ValueError, KeyError, TypeError, IndexError) as e:
debug.log(f"{cls.__name__}:{e.__class__.__name__}:{e}") debug.error(f"{cls.__name__} {type(e).__name__}: {e}")
continue continue
match = re.search(r'\[Imagen of (.*?)\]', content) match = re.search(r'\[Imagen of (.*?)\]', content)
if match: if match:

View file

@ -51,7 +51,9 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
] ]
cls.models.sort() cls.models.sort()
except Exception as e: except Exception as e:
debug.log(e) debug.error(e)
if api_key is not None:
raise MissingAuthError("Invalid API key")
return cls.fallback_models return cls.fallback_models
return cls.models return cls.models
@ -111,8 +113,18 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
"topK": kwargs.get("top_k"), "topK": kwargs.get("top_k"),
}, },
"tools": [{ "tools": [{
"functionDeclarations": tools "function_declarations": [{
}] if tools else None "name": tool["function"]["name"],
"description": tool["function"]["description"],
"parameters": {
"type": "object",
"properties": {key: {
"type": value["type"],
"description": value["title"]
} for key, value in tool["function"]["parameters"]["properties"].items()}
},
} for tool in tools]
}] if tools else None
} }
system_prompt = "\n".join( system_prompt = "\n".join(
message["content"] message["content"]

View file

@ -322,8 +322,8 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
try: try:
image_requests = await cls.upload_images(session, auth_result, images) if images else None image_requests = await cls.upload_images(session, auth_result, images) if images else None
except Exception as e: except Exception as e:
debug.log("OpenaiChat: Upload image failed") debug.error("OpenaiChat: Upload image failed")
debug.log(f"{e.__class__.__name__}: {e}") debug.error(e)
model = cls.get_model(model) model = cls.get_model(model)
if conversation is None: if conversation is None:
conversation = Conversation(conversation_id, str(uuid.uuid4()), getattr(auth_result, "cookies", {}).get("oai-did")) conversation = Conversation(conversation_id, str(uuid.uuid4()), getattr(auth_result, "cookies", {}).get("oai-did"))
@ -360,12 +360,14 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
# if auth_result.arkose_token is None: # if auth_result.arkose_token is None:
# raise MissingAuthError("No arkose token found in .har file") # raise MissingAuthError("No arkose token found in .har file")
if "proofofwork" in chat_requirements: if "proofofwork" in chat_requirements:
if getattr(auth_result, "proof_token", None) is None: user_agent = getattr(auth_result, "headers", {}).get("user-agent")
auth_result.proof_token = get_config(auth_result.headers.get("user-agent")) proof_token = getattr(auth_result, "proof_token", None)
if proof_token is None:
auth_result.proof_token = get_config(user_agent)
proofofwork = generate_proof_token( proofofwork = generate_proof_token(
**chat_requirements["proofofwork"], **chat_requirements["proofofwork"],
user_agent=getattr(auth_result, "headers", {}).get("user-agent"), user_agent=user_agent,
proof_token=getattr(auth_result, "proof_token", None) proof_token=proof_token
) )
[debug.log(text) for text in ( [debug.log(text) for text in (
#f"Arkose: {'False' if not need_arkose else auth_result.arkose_token[:12]+'...'}", #f"Arkose: {'False' if not need_arkose else auth_result.arkose_token[:12]+'...'}",
@ -425,8 +427,8 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
) as response: ) as response:
cls._update_request_args(auth_result, session) cls._update_request_args(auth_result, session)
if response.status == 403: if response.status == 403:
auth_result.proof_token = None
cls.request_config.proof_token = None cls.request_config.proof_token = None
raise MissingAuthError("Access token is not valid")
await raise_for_status(response) await raise_for_status(response)
buffer = u"" buffer = u""
async for line in response.iter_lines(): async for line in response.iter_lines():
@ -469,14 +471,6 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
await asyncio.sleep(5) await asyncio.sleep(5)
else: else:
break break
yield Parameters(**{
"action": "continue" if conversation.finish_reason == "max_tokens" else "variant",
"conversation": conversation.get_dict(),
"proof_token": cls.request_config.proof_token,
"cookies": cls._cookies,
"headers": cls._headers,
"web_search": web_search,
})
yield FinishReason(conversation.finish_reason) yield FinishReason(conversation.finish_reason)
@classmethod @classmethod

View file

@ -42,7 +42,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
if cls.sort_models: if cls.sort_models:
cls.models.sort() cls.models.sort()
except Exception as e: except Exception as e:
debug.log(e) debug.error(e)
return cls.fallback_models return cls.fallback_models
return cls.models return cls.models
@ -65,7 +65,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
prompt: str = None, prompt: str = None,
headers: dict = None, headers: dict = None,
impersonate: str = None, impersonate: str = None,
tools: Optional[list] = None, extra_parameters: list[str] = ["tools", "parallel_tool_calls", "", "reasoning_effort", "logit_bias"],
extra_data: dict = {}, extra_data: dict = {},
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
@ -112,6 +112,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
} }
] ]
messages[-1] = last_message messages[-1] = last_message
extra_parameters = {key: kwargs[key] for key in extra_parameters if key in kwargs}
data = filter_none( data = filter_none(
messages=messages, messages=messages,
model=model, model=model,
@ -120,7 +121,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
top_p=top_p, top_p=top_p,
stop=stop, stop=stop,
stream=stream, stream=stream,
tools=tools, **extra_parameters,
**extra_data **extra_data
) )
if api_endpoint is None: if api_endpoint is None:

View file

@ -588,7 +588,7 @@ class Api:
target=target) target=target)
debug.log(f"Image copied from {source_url}") debug.log(f"Image copied from {source_url}")
except Exception as e: except Exception as e:
debug.log(f"{type(e).__name__}: Download failed: {source_url}\n{e}") debug.error(f"Download failed: {source_url}\n{type(e).__name__}: {e}")
return RedirectResponse(url=source_url) return RedirectResponse(url=source_url)
if not os.path.isfile(target): if not os.path.isfile(target):
return ErrorResponse.from_message("File not found", HTTP_404_NOT_FOUND) return ErrorResponse.from_message("File not found", HTTP_404_NOT_FOUND)

View file

@ -18,7 +18,7 @@ from ..providers.retry_provider import IterListProvider
from ..providers.asyncio import to_sync_generator from ..providers.asyncio import to_sync_generator
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
from ..tools.run_tools import async_iter_run_tools, iter_run_tools from ..tools.run_tools import async_iter_run_tools, iter_run_tools
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse, UsageModel, ToolCallModel
from .image_models import ImageModels from .image_models import ImageModels
from .types import IterResponse, ImageProvider, Client as BaseClient from .types import IterResponse, ImageProvider, Client as BaseClient
from .service import get_model_and_provider, convert_to_provider from .service import get_model_and_provider, convert_to_provider
@ -103,14 +103,14 @@ def iter_response(
idx += 1 idx += 1
if usage is None: if usage is None:
usage = Usage(prompt_tokens=0, completion_tokens=idx, total_tokens=idx) usage = Usage(completion_tokens=idx, total_tokens=idx)
finish_reason = "stop" if finish_reason is None else finish_reason finish_reason = "stop" if finish_reason is None else finish_reason
if stream: if stream:
yield ChatCompletionChunk.model_construct( yield ChatCompletionChunk.model_construct(
None, finish_reason, completion_id, int(time.time()), None, finish_reason, completion_id, int(time.time()),
usage=usage.get_dict() usage=usage
) )
else: else:
if response_format is not None and "type" in response_format: if response_format is not None and "type" in response_format:
@ -118,7 +118,8 @@ def iter_response(
content = filter_json(content) content = filter_json(content)
yield ChatCompletion.model_construct( yield ChatCompletion.model_construct(
content, finish_reason, completion_id, int(time.time()), content, finish_reason, completion_id, int(time.time()),
usage=usage.get_dict(), **filter_none(tool_calls=tool_calls) usage=UsageModel.model_construct(**usage.get_dict()),
**filter_none(tool_calls=[ToolCallModel.model_construct(**tool_call) for tool_call in tool_calls]) if tool_calls is not None else {}
) )
# Synchronous iter_append_model_and_provider function # Synchronous iter_append_model_and_provider function
@ -186,7 +187,7 @@ async def async_iter_response(
finish_reason = "stop" if finish_reason is None else finish_reason finish_reason = "stop" if finish_reason is None else finish_reason
if usage is None: if usage is None:
usage = Usage(prompt_tokens=0, completion_tokens=idx, total_tokens=idx) usage = Usage(completion_tokens=idx, total_tokens=idx)
if stream: if stream:
yield ChatCompletionChunk.model_construct( yield ChatCompletionChunk.model_construct(
@ -199,7 +200,8 @@ async def async_iter_response(
content = filter_json(content) content = filter_json(content)
yield ChatCompletion.model_construct( yield ChatCompletion.model_construct(
content, finish_reason, completion_id, int(time.time()), content, finish_reason, completion_id, int(time.time()),
usage=usage.get_dict(), **filter_none(tool_calls=tool_calls) usage=UsageModel.model_construct(**usage.get_dict()),
**filter_none(tool_calls=[ToolCallModel.model_construct(**tool_call) for tool_call in tool_calls]) if tool_calls is not None else {}
) )
finally: finally:
await safe_aclose(response) await safe_aclose(response)
@ -363,7 +365,7 @@ class Images:
break break
except Exception as e: except Exception as e:
error = e error = e
debug.log(f"Image provider {provider.__name__}: {e}") debug.error(e, name=f"{provider.__name__} {type(e).__name__}")
else: else:
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs) response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)
@ -458,7 +460,7 @@ class Images:
break break
except Exception as e: except Exception as e:
error = e error = e
debug.log(f"Image provider {provider.__name__}: {e}") debug.error(e, name=f"{provider.__name__} {type(e).__name__}")
else: else:
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs) response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)
@ -583,7 +585,7 @@ class AsyncCompletions:
messages: Messages, messages: Messages,
model: str, model: str,
**kwargs **kwargs
) -> AsyncIterator[ChatCompletionChunk, BaseConversation]: ) -> AsyncIterator[ChatCompletionChunk]:
return self.create(messages, model, stream=True, **kwargs) return self.create(messages, model, stream=True, **kwargs)
class AsyncImages(Images): class AsyncImages(Images):

View file

@ -5,9 +5,6 @@ from time import time
from .helper import filter_none from .helper import filter_none
ToolCalls = Optional[List[Dict[str, Any]]]
Usage = Optional[Dict[str, int]]
try: try:
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
except ImportError: except ImportError:
@ -29,6 +26,40 @@ class BaseModel(BaseModel):
return super().model_construct(**data) return super().model_construct(**data)
return cls.construct(**data) return cls.construct(**data)
class UsageModel(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
prompt_tokens_details: Optional[Dict[str, Any]]
completion_tokens_details: Optional[Dict[str, Any]]
@classmethod
def model_construct(cls, prompt_tokens=0, completion_tokens=0, total_tokens=0, prompt_tokens_details=None, completion_tokens_details=None, **kwargs):
return super().model_construct(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_tokens_details=prompt_tokens_details,
completion_tokens_details=completion_tokens_details,
**kwargs
)
class ToolFunctionModel(BaseModel):
name: str
arguments: str
class ToolCallModel(BaseModel):
id: str
type: str
function: ToolFunctionModel
@classmethod
def model_construct(cls, function=None, **kwargs):
return super().model_construct(
**kwargs,
function=ToolFunctionModel.model_construct(**function),
)
class ChatCompletionChunk(BaseModel): class ChatCompletionChunk(BaseModel):
id: str id: str
object: str object: str
@ -36,7 +67,7 @@ class ChatCompletionChunk(BaseModel):
model: str model: str
provider: Optional[str] provider: Optional[str]
choices: List[ChatCompletionDeltaChoice] choices: List[ChatCompletionDeltaChoice]
usage: Usage usage: UsageModel
@classmethod @classmethod
def model_construct( def model_construct(
@ -45,7 +76,7 @@ class ChatCompletionChunk(BaseModel):
finish_reason: str, finish_reason: str,
completion_id: str = None, completion_id: str = None,
created: int = None, created: int = None,
usage: Usage = None usage: UsageModel = None
): ):
return super().model_construct( return super().model_construct(
id=f"chatcmpl-{completion_id}" if completion_id else None, id=f"chatcmpl-{completion_id}" if completion_id else None,
@ -63,10 +94,10 @@ class ChatCompletionChunk(BaseModel):
class ChatCompletionMessage(BaseModel): class ChatCompletionMessage(BaseModel):
role: str role: str
content: str content: str
tool_calls: ToolCalls tool_calls: list[ToolCallModel] = None
@classmethod @classmethod
def model_construct(cls, content: str, tool_calls: ToolCalls = None): def model_construct(cls, content: str, tool_calls: list = None):
return super().model_construct(role="assistant", content=content, **filter_none(tool_calls=tool_calls)) return super().model_construct(role="assistant", content=content, **filter_none(tool_calls=tool_calls))
class ChatCompletionChoice(BaseModel): class ChatCompletionChoice(BaseModel):
@ -85,11 +116,7 @@ class ChatCompletion(BaseModel):
model: str model: str
provider: Optional[str] provider: Optional[str]
choices: List[ChatCompletionChoice] choices: List[ChatCompletionChoice]
usage: Usage = Field(default={ usage: UsageModel
"prompt_tokens": 0, #prompt_tokens,
"completion_tokens": 0, #completion_tokens,
"total_tokens": 0, #prompt_tokens + completion_tokens,
})
@classmethod @classmethod
def model_construct( def model_construct(
@ -98,8 +125,8 @@ class ChatCompletion(BaseModel):
finish_reason: str, finish_reason: str,
completion_id: str = None, completion_id: str = None,
created: int = None, created: int = None,
tool_calls: ToolCalls = None, tool_calls: list[ToolCallModel] = None,
usage: Usage = None usage: UsageModel = None
): ):
return super().model_construct( return super().model_construct(
id=f"chatcmpl-{completion_id}" if completion_id else None, id=f"chatcmpl-{completion_id}" if completion_id else None,

View file

@ -1,3 +1,4 @@
import sys
from .providers.types import ProviderType from .providers.types import ProviderType
logging: bool = False logging: bool = False
@ -8,6 +9,9 @@ version: str = None
log_handler: callable = print log_handler: callable = print
logs: list = [] logs: list = []
def log(text): def log(text, file = None):
if logging: if logging:
log_handler(text) log_handler(text, file=file)
def error(error, name: str = None):
log(error if isinstance(error, str) else f"{type(error).__name__ if name is None else name}: {error}", file=sys.stderr)

View file

@ -181,7 +181,12 @@
<script> <script>
(async () => { (async () => {
const isIframe = window.self !== window.top; const isIframe = window.self !== window.top;
const backendUrl = "{{backend_url}}";
let url = new URL(window.location.href) let url = new URL(window.location.href)
if (isIframe && backendUrl) {
window.location.replace(url.search ? `${backendUrl}?${url.search}` : backendUrl);
return;
}
let params = new URLSearchParams(url.search); let params = new URLSearchParams(url.search);
if (params.get("__sign")) { if (params.get("__sign")) {
localStorage.setItem("zerogpu_token", params.get("__sign")); localStorage.setItem("zerogpu_token", params.get("__sign"));
@ -232,12 +237,11 @@
import * as hub from "@huggingface/hub"; import * as hub from "@huggingface/hub";
import { init } from "@huggingface/space-header"; import { init } from "@huggingface/space-header";
const isIframe = window.self !== window.top;
const button = document.querySelector('form a.button'); const button = document.querySelector('form a.button');
if (isIframe) { if (isIframe) {
button.classList.remove('hidden'); button.classList.remove('hidden');
} else { } else {
init("roxky/g4f-space"); init("roxky/g4f-space-new");
} }
const form = document.querySelector("form"); const form = document.querySelector("form");
@ -282,13 +286,11 @@
const cache_id = Math.floor(Math.random() * max); const cache_id = Math.floor(Math.random() * max);
let prompt; let prompt;
if (cache_id % 2 == 0) { if (cache_id % 2 == 0) {
prompt = ` prompt = `Today is ${new Date().toJSON().slice(0, 10)}.
Today is ${new Date().toJSON().slice(0, 10)}.
Create a single-page HTML screensaver reflecting the current season (based on the date). Create a single-page HTML screensaver reflecting the current season (based on the date).
Avoid using any text.`; Avoid using any text.`;
} else { } else {
prompt = `Create a single-page HTML screensaver. Avoid using any text.`; prompt = `Create a single-page HTML screensaver. Avoid using any text.`;
const response = await fetch(`/backend-api/v2/create?prompt=${prompt}&filter_markdown=html&cache=${cache_id}`);
} }
const response = await fetch(`/backend-api/v2/create?prompt=${prompt}&filter_markdown=html&cache=${cache_id}`); const response = await fetch(`/backend-api/v2/create?prompt=${prompt}&filter_markdown=html&cache=${cache_id}`);
const text = await response.text() const text = await response.text()

View file

@ -293,7 +293,7 @@
</div> </div>
<div class="field"> <div class="field">
<select name="model" id="model"> <select name="model" id="model">
<option value="">Model: Default</option> <option value="" selected="selected">Model: Default</option>
<option value="gpt-4">gpt-4</option> <option value="gpt-4">gpt-4</option>
<option value="gpt-4o">gpt-4o</option> <option value="gpt-4o">gpt-4o</option>
<option value="gpt-4o-mini">gpt-4o-mini</option> <option value="gpt-4o-mini">gpt-4o-mini</option>

View file

@ -72,13 +72,16 @@ class Api:
@staticmethod @staticmethod
def get_version() -> dict: def get_version() -> dict:
current_version = None
latest_version = None
try: try:
current_version = version.utils.current_version current_version = version.utils.current_version
latest_version = version.utils.latest_version
except VersionNotFoundError: except VersionNotFoundError:
current_version = None pass
return { return {
"version": current_version, "version": current_version,
"latest_version": version.utils.latest_version, "latest_version": latest_version,
} }
def serve_images(self, name): def serve_images(self, name):
@ -137,10 +140,10 @@ class Api:
} }
def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator: def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator:
def decorated_log(text: str): def decorated_log(text: str, file = None):
debug.logs.append(text) debug.logs.append(text)
if debug.logging: if debug.logging:
debug.log_handler(text) debug.log_handler(text, file)
debug.log = decorated_log debug.log = decorated_log
proxy = os.environ.get("G4F_PROXY") proxy = os.environ.get("G4F_PROXY")
provider = kwargs.get("provider") provider = kwargs.get("provider")
@ -154,6 +157,7 @@ class Api:
) )
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
debug.error(e)
yield self._format_json('error', type(e).__name__, message=get_error_message(e)) yield self._format_json('error', type(e).__name__, message=get_error_message(e))
return return
if not isinstance(provider_handler, BaseRetryProvider): if not isinstance(provider_handler, BaseRetryProvider):
@ -183,6 +187,7 @@ class Api:
yield self._format_json("conversation_id", conversation_id) yield self._format_json("conversation_id", conversation_id)
elif isinstance(chunk, Exception): elif isinstance(chunk, Exception):
logger.exception(chunk) logger.exception(chunk)
debug.error(e)
yield self._format_json('message', get_error_message(chunk), error=type(chunk).__name__) yield self._format_json('message', get_error_message(chunk), error=type(chunk).__name__)
elif isinstance(chunk, PreviewResponse): elif isinstance(chunk, PreviewResponse):
yield self._format_json("preview", chunk.to_string()) yield self._format_json("preview", chunk.to_string())
@ -215,20 +220,19 @@ class Api:
yield self._format_json(chunk.type, **chunk.get_dict()) yield self._format_json(chunk.type, **chunk.get_dict())
else: else:
yield self._format_json("content", str(chunk)) yield self._format_json("content", str(chunk))
if debug.logs: yield from self._yield_logs()
for log in debug.logs:
yield self._format_json("log", str(log))
debug.logs = []
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
if debug.logging: debug.error(e)
debug.log_handler(get_error_message(e)) yield from self._yield_logs()
if debug.logs:
for log in debug.logs:
yield self._format_json("log", str(log))
debug.logs = []
yield self._format_json('error', type(e).__name__, message=get_error_message(e)) yield self._format_json('error', type(e).__name__, message=get_error_message(e))
def _yield_logs(self):
if debug.logs:
for log in debug.logs:
yield self._format_json("log", log)
debug.logs = []
def _format_json(self, response_type: str, content = None, **kwargs): def _format_json(self, response_type: str, content = None, **kwargs):
if content is not None and isinstance(response_type, str): if content is not None and isinstance(response_type, str):
return { return {

View file

@ -76,7 +76,7 @@ class Backend_Api(Api):
@app.route('/', methods=['GET']) @app.route('/', methods=['GET'])
@limiter.exempt @limiter.exempt
def home(): def home():
return render_template('demo.html') return render_template('demo.html', backend_url=os.environ.get("G4F_BACKEND_URL", ""))
else: else:
@app.route('/', methods=['GET']) @app.route('/', methods=['GET'])
def home(): def home():

View file

@ -123,7 +123,7 @@ async def copy_images(
return f"/images/{url_filename}{'?url=' + quote(image) if add_url and not image.startswith('data:') else ''}" return f"/images/{url_filename}{'?url=' + quote(image) if add_url and not image.startswith('data:') else ''}"
except (ClientError, IOError, OSError) as e: except (ClientError, IOError, OSError) as e:
debug.log(f"Image processing failed: {e.__class__.__name__}: {e}") debug.error(f"Image processing failed: {type(e).__name__}: {e}")
if target_path and os.path.exists(target_path): if target_path and os.path.exists(target_path):
os.unlink(target_path) os.unlink(target_path)
return get_source_url(image, image) return get_source_url(image, image)

View file

@ -243,7 +243,7 @@ llama_3_3_70b = Model(
mixtral_8x7b = Model( mixtral_8x7b = Model(
name = "mixtral-8x7b", name = "mixtral-8x7b",
base_provider = "Mistral", base_provider = "Mistral",
best_provider = IterListProvider([DDG, Jmuz]) best_provider = Jmuz
) )
mixtral_8x22b = Model( mixtral_8x22b = Model(
name = "mixtral-8x22b", name = "mixtral-8x22b",
@ -300,7 +300,7 @@ wizardlm_2_8x22b = Model(
### Google DeepMind ### ### Google DeepMind ###
# gemini # gemini
gemini = Model( gemini = Model(
name = 'gemini', name = 'gemini-2.0',
base_provider = 'Google', base_provider = 'Google',
best_provider = Gemini best_provider = Gemini
) )
@ -316,13 +316,13 @@ gemini_exp = Model(
gemini_1_5_flash = Model( gemini_1_5_flash = Model(
name = 'gemini-1.5-flash', name = 'gemini-1.5-flash',
base_provider = 'Google DeepMind', base_provider = 'Google DeepMind',
best_provider = IterListProvider([Blackbox, Jmuz, Gemini, GeminiPro, Liaobots]) best_provider = IterListProvider([Blackbox, Jmuz, GeminiPro, Liaobots])
) )
gemini_1_5_pro = Model( gemini_1_5_pro = Model(
name = 'gemini-1.5-pro', name = 'gemini-1.5-pro',
base_provider = 'Google DeepMind', base_provider = 'Google DeepMind',
best_provider = IterListProvider([Blackbox, Jmuz, Gemini, GeminiPro, Liaobots]) best_provider = IterListProvider([Blackbox, Jmuz, GeminiPro, Liaobots])
) )
# gemini-2.0 # gemini-2.0
@ -713,6 +713,7 @@ class ModelUtils:
### Google ### ### Google ###
### Gemini ### Gemini
"gemini": gemini,
gemini.name: gemini, gemini.name: gemini,
gemini_exp.name: gemini_exp, gemini_exp.name: gemini_exp,
gemini_1_5_pro.name: gemini_1_5_pro, gemini_1_5_pro.name: gemini_1_5_pro,
@ -812,7 +813,6 @@ class ModelUtils:
demo_models = { demo_models = {
gpt_4o.name: [gpt_4o, [PollinationsAI, Blackbox]],
"default": [llama_3_2_11b, [HuggingFace]], "default": [llama_3_2_11b, [HuggingFace]],
qwen_2_vl_7b.name: [qwen_2_vl_7b, [HuggingFaceAPI]], qwen_2_vl_7b.name: [qwen_2_vl_7b, [HuggingFaceAPI]],
qvq_72b.name: [qvq_72b, [HuggingSpace]], qvq_72b.name: [qvq_72b, [HuggingSpace]],

View file

@ -409,6 +409,14 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
def get_cache_file(cls) -> Path: def get_cache_file(cls) -> Path:
return Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json" return Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
@classmethod
def write_cache_file(cls, cache_file: Path, auth_result: AuthResult = None):
if auth_result is not None:
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
elif cache_file.exists():
cache_file.unlink()
@classmethod @classmethod
def create_completion( def create_completion(
cls, cls,
@ -416,35 +424,25 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
messages: Messages, messages: Messages,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
auth_result = AuthResult() auth_result: AuthResult = None
cache_file = cls.get_cache_file() cache_file = cls.get_cache_file()
try: try:
if cache_file.exists(): if cache_file.exists():
with cache_file.open("r") as f: with cache_file.open("r") as f:
auth_result = AuthResult(**json.load(f)) auth_result = AuthResult(**json.load(f))
else: else:
auth_result = cls.on_auth(**kwargs) raise MissingAuthError
for chunk in auth_result:
if hasattr(chunk, "get_dict"):
auth_result = chunk
else:
yield chunk
yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs)) yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
except (MissingAuthError, NoValidHarFileError): except (MissingAuthError, NoValidHarFileError):
auth_result = cls.on_auth(**kwargs) auth_result = cls.on_auth(**kwargs)
for chunk in auth_result: for chunk in auth_result:
if hasattr(chunk, "get_dict"): if isinstance(chunk, AuthResult):
auth_result = chunk auth_result = chunk
else: else:
yield chunk yield chunk
yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs)) yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
finally: finally:
if hasattr(auth_result, "get_dict"): cls.write_cache_file(cache_file, auth_result)
data = auth_result.get_dict()
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(data))
elif cache_file.exists():
cache_file.unlink()
@classmethod @classmethod
async def create_async_generator( async def create_async_generator(
@ -453,19 +451,14 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
messages: Messages, messages: Messages,
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
auth_result: AuthResult = None
cache_file = cls.get_cache_file()
try: try:
auth_result = AuthResult()
cache_file = Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
if cache_file.exists(): if cache_file.exists():
with cache_file.open("r") as f: with cache_file.open("r") as f:
auth_result = AuthResult(**json.load(f)) auth_result = AuthResult(**json.load(f))
else: else:
auth_result = cls.on_auth_async(**kwargs) raise MissingAuthError
async for chunk in auth_result:
if hasattr(chunk, "get_dict"):
auth_result = chunk
else:
yield chunk
response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result)) response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
async for chunk in response: async for chunk in response:
yield chunk yield chunk
@ -474,16 +467,16 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
cache_file.unlink() cache_file.unlink()
auth_result = cls.on_auth_async(**kwargs) auth_result = cls.on_auth_async(**kwargs)
async for chunk in auth_result: async for chunk in auth_result:
if hasattr(chunk, "get_dict"): if isinstance(chunk, AuthResult):
auth_result = chunk auth_result = chunk
else: else:
yield chunk yield chunk
response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result)) response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
async for chunk in response: async for chunk in response:
if cache_file is not None:
cls.write_cache_file(cache_file, auth_result)
cache_file = None
yield chunk yield chunk
finally: finally:
if hasattr(auth_result, "get_dict"): if cache_file is not None:
cache_file.parent.mkdir(parents=True, exist_ok=True) cls.write_cache_file(cache_file, auth_result)
cache_file.write_text(json.dumps(auth_result.get_dict()))
elif cache_file.exists():
cache_file.unlink()

View file

@ -19,7 +19,6 @@ def quote_url(url: str) -> str:
def quote_title(title: str) -> str: def quote_title(title: str) -> str:
if title: if title:
title = title.strip()
title = " ".join(title.split()) title = " ".join(title.split())
return title.replace('[', '').replace(']', '') return title.replace('[', '').replace(']', '')
return "" return ""
@ -154,6 +153,7 @@ class Sources(ResponseType):
self.add_source(source) self.add_source(source)
def add_source(self, source: dict[str, str]): def add_source(self, source: dict[str, str]):
source = source if isinstance(source, dict) else {"url": source}
url = source.get("url", source.get("link", None)) url = source.get("url", source.get("link", None))
if url is not None: if url is not None:
url = re.sub(r"[&?]utm_source=.+", "", url) url = re.sub(r"[&?]utm_source=.+", "", url)

View file

@ -65,7 +65,7 @@ class IterListProvider(BaseRetryProvider):
return return
except Exception as e: except Exception as e:
exceptions[provider.__name__] = e exceptions[provider.__name__] = e
debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}") debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
if started: if started:
raise e raise e
yield e yield e
@ -105,7 +105,7 @@ class IterListProvider(BaseRetryProvider):
return return
except Exception as e: except Exception as e:
exceptions[provider.__name__] = e exceptions[provider.__name__] = e
debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}") debug.error(name=f"{provider.__name__} {type(e).__name__}: {e}")
if started: if started:
raise e raise e
yield e yield e

View file

@ -2,12 +2,12 @@ from __future__ import annotations
from curl_cffi.requests import AsyncSession, Response from curl_cffi.requests import AsyncSession, Response
try: try:
from curl_cffi.requests import CurlMime from curl_cffi import CurlMime
has_curl_mime = True has_curl_mime = True
except ImportError: except ImportError:
has_curl_mime = False has_curl_mime = False
try: try:
from curl_cffi.requests import CurlWsFlag from curl_cffi import CurlWsFlag
has_curl_ws = True has_curl_ws = True
except ImportError: except ImportError:
has_curl_ws = False has_curl_ws = False
@ -73,7 +73,7 @@ class StreamSession(AsyncSession):
def request( def request(
self, method: str, url: str, ssl = None, **kwargs self, method: str, url: str, ssl = None, **kwargs
) -> StreamResponse: ) -> StreamResponse:
if isinstance(kwargs.get("data"), CurlMime): if kwargs.get("data") and isinstance(kwargs.get("data"), CurlMime):
kwargs["multipart"] = kwargs.pop("data") kwargs["multipart"] = kwargs.pop("data")
"""Create and return a StreamResponse object for the given HTTP request.""" """Create and return a StreamResponse object for the given HTTP request."""
return StreamResponse(super().request(method, url, stream=True, verify=ssl, **kwargs)) return StreamResponse(super().request(method, url, stream=True, verify=ssl, **kwargs))
@ -100,12 +100,12 @@ if has_curl_mime:
else: else:
class FormData(): class FormData():
def __init__(self) -> None: def __init__(self) -> None:
raise RuntimeError("CurlMimi in curl_cffi is missing | pip install -U g4f[curl_cffi]") raise RuntimeError("CurlMimi in curl_cffi is missing | pip install -U curl_cffi")
class WebSocket(): class WebSocket():
def __init__(self, session, url, **kwargs) -> None: def __init__(self, session, url, **kwargs) -> None:
if not has_curl_ws: if not has_curl_ws:
raise RuntimeError("CurlWsFlag in curl_cffi is missing | pip install -U g4f[curl_cffi]") raise RuntimeError("CurlWsFlag in curl_cffi is missing | pip install -U curl_cffi")
self.session: StreamSession = session self.session: StreamSession = session
self.url: str = url self.url: str = url
del kwargs["autoping"] del kwargs["autoping"]
@ -116,11 +116,13 @@ class WebSocket():
return self return self
async def __aexit__(self, *args): async def __aexit__(self, *args):
await self.inner.aclose() await self.inner.aclose() if hasattr(self.inner, "aclose") else await self.inner.close()
async def receive_str(self, **kwargs) -> str: async def receive_str(self, **kwargs) -> str:
bytes, _ = await self.inner.arecv() method = self.inner.arecv if hasattr(self.inner, "arecv") else self.inner.recv
bytes, _ = await method()
return bytes.decode(errors="ignore") return bytes.decode(errors="ignore")
async def send_str(self, data: str): async def send_str(self, data: str):
await self.inner.asend(data.encode(), CurlWsFlag.TEXT) method = self.inner.asend if hasattr(self.inner, "asend") else self.inner.send
await method(data.encode(), CurlWsFlag.TEXT)

71
g4f/tools/pydantic_ai.py Normal file
View file

@ -0,0 +1,71 @@
from __future__ import annotations
from typing import Optional
from functools import partial
from dataclasses import dataclass, field
from pydantic_ai.models import Model, KnownModelName, infer_model
from pydantic_ai.models.openai import OpenAIModel, OpenAISystemPromptRole
from ..client import AsyncClient
@dataclass(init=False)
class AIModel(OpenAIModel):
"""A model that uses the G4F API."""
client: AsyncClient = field(repr=False)
system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
_model_name: str = field(repr=False)
_provider: str = field(repr=False)
_system: Optional[str] = field(repr=False)
def __init__(
self,
model_name: str,
provider: str | None = None,
*,
system_prompt_role: OpenAISystemPromptRole | None = None,
system: str | None = 'openai',
**kwargs
):
"""Initialize an AI model.
Args:
model_name: The name of the AI model to use. List of model names available
[here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
(Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API).
system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
In the future, this may be inferred from the model name.
system: The model provider used, defaults to `openai`. This is for observability purposes, you must
customize the `base_url` and `api_key` to use a different provider.
"""
self._model_name = model_name
self._provider = provider
self.client = AsyncClient(provider=provider, **kwargs)
self.system_prompt_role = system_prompt_role
self._system = system
def name(self) -> str:
if self._provider:
return f'g4f:{self._provider}:{self._model_name}'
return f'g4f:{self._model_name}'
def new_infer_model(model: Model | KnownModelName, api_key: str = None) -> Model:
if isinstance(model, Model):
return model
if model.startswith("g4f:"):
model = model[4:]
if ":" in model:
provider, model = model.split(":", 1)
return AIModel(model, provider=provider, api_key=api_key)
return AIModel(model)
return infer_model(model)
def apply_patch(api_key: str | None = None):
import pydantic_ai.models
import pydantic_ai.models.openai
pydantic_ai.models.infer_model = partial(new_infer_model, api_key=api_key)
pydantic_ai.models.AIModel = AIModel
pydantic_ai.models.openai.NOT_GIVEN = None

View file

@ -44,7 +44,7 @@ async def async_iter_run_tools(provider: ProviderType, model: str, messages, too
web_search = web_search if isinstance(web_search, str) and web_search != "true" else None web_search = web_search if isinstance(web_search, str) and web_search != "true" else None
messages[-1]["content"] = await do_search(messages[-1]["content"], web_search) messages[-1]["content"] = await do_search(messages[-1]["content"], web_search)
except Exception as e: except Exception as e:
debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}") debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
# Keep web_search in kwargs for provider native support # Keep web_search in kwargs for provider native support
pass pass
@ -82,7 +82,8 @@ async def async_iter_run_tools(provider: ProviderType, model: str, messages, too
has_bucket = True has_bucket = True
message["content"] = new_message_content message["content"] = new_message_content
if has_bucket and isinstance(messages[-1]["content"], str): if has_bucket and isinstance(messages[-1]["content"], str):
messages[-1]["content"] += BUCKET_INSTRUCTIONS if "\nSource: " in messages[-1]["content"]:
messages[-1]["content"] += BUCKET_INSTRUCTIONS
create_function = provider.get_async_create_function() create_function = provider.get_async_create_function()
response = to_async_iterator(create_function(model=model, messages=messages, **kwargs)) response = to_async_iterator(create_function(model=model, messages=messages, **kwargs))
async for chunk in response: async for chunk in response:
@ -149,7 +150,7 @@ def iter_run_tools(
web_search = web_search if isinstance(web_search, str) and web_search != "true" else None web_search = web_search if isinstance(web_search, str) and web_search != "true" else None
messages[-1]["content"] = asyncio.run(do_search(messages[-1]["content"], web_search)) messages[-1]["content"] = asyncio.run(do_search(messages[-1]["content"], web_search))
except Exception as e: except Exception as e:
debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}") debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
# Keep web_search in kwargs for provider native support # Keep web_search in kwargs for provider native support
pass pass
@ -192,7 +193,8 @@ def iter_run_tools(
has_bucket = True has_bucket = True
message["content"] = new_message_content message["content"] = new_message_content
if has_bucket and isinstance(messages[-1]["content"], str): if has_bucket and isinstance(messages[-1]["content"], str):
messages[-1]["content"] += BUCKET_INSTRUCTIONS if "\nSource: " in messages[-1]["content"]:
messages[-1]["content"] = messages[-1]["content"]["content"] + BUCKET_INSTRUCTIONS
thinking_start_time = 0 thinking_start_time = 0
for chunk in iter_callback(model=model, messages=messages, provider=provider, **kwargs): for chunk in iter_callback(model=model, messages=messages, provider=provider, **kwargs):

View file

@ -237,7 +237,7 @@ def get_search_message(prompt: str, raise_search_exceptions=False, **kwargs) ->
except (DuckDuckGoSearchException, MissingRequirementsError) as e: except (DuckDuckGoSearchException, MissingRequirementsError) as e:
if raise_search_exceptions: if raise_search_exceptions:
raise e raise e
debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}") debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
return prompt return prompt
def spacy_get_keywords(text: str): def spacy_get_keywords(text: str):

View file

@ -44,8 +44,9 @@ def get_github_version(repo: str) -> str:
VersionNotFoundError: If there is an error in fetching the version from GitHub. VersionNotFoundError: If there is an error in fetching the version from GitHub.
""" """
try: try:
response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest").json() response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest")
return response["tag_name"] response.raise_for_status()
return response.json()["tag_name"]
except requests.RequestException as e: except requests.RequestException as e:
raise VersionNotFoundError(f"Failed to get GitHub release version: {e}") raise VersionNotFoundError(f"Failed to get GitHub release version: {e}")