Merge pull request #2530 from hlohaus/cont

Add Anthropic provider
H Lohaus 2025-01-03 02:56:06 +01:00 committed by GitHub
commit c5ba78c7e1
40 changed files with 800 additions and 161 deletions


@@ -70,8 +70,7 @@ def read_json(text: str) -> dict:
     try:
         return json.loads(text.strip())
     except json.JSONDecodeError:
-        print("No valid json:", text)
-        return {}
+        raise RuntimeError(f"Invalid JSON: {text}")

 def read_text(text: str) -> str:
     """
@@ -86,7 +85,8 @@ def read_text(text: str) -> str:
     match = re.search(r"```(markdown|)\n(?P<text>[\S\s]+?)\n```", text)
     if match:
         return match.group("text")
-    return text
+    else:
+        raise RuntimeError(f"Invalid markdown: {text}")

 def get_ai_response(prompt: str, as_json: bool = True) -> Union[dict, str]:
     """
@@ -197,6 +197,7 @@ def create_review_prompt(pull: PullRequest, diff: str):
     return f"""Your task is to review a pull request. Instructions:
 - Write in name of g4f copilot. Don't use placeholder.
 - Write the review in GitHub Markdown format.
+- Enclose your response in backticks ```response```
 - Thank the author for contributing to the project.

 Pull request author: {pull.user.name}


@@ -16,8 +16,6 @@ async def test_async(provider: ProviderType):
         return False
     messages = [{"role": "user", "content": "Hello Assistant!"}]
     try:
-        if "webdriver" in provider.get_parameters():
-            return False
         response = await asyncio.wait_for(ChatCompletion.create_async(
             model=models.default,
             messages=messages,


@@ -46,4 +46,4 @@ class TestBackendApi(unittest.TestCase):
             self.skipTest(e)
         except MissingRequirementsError:
             self.skipTest("search is not installed")
-        self.assertTrue(len(result) >= 4)
+        self.assertGreater(len(result), 0)


@@ -1,36 +1,11 @@
 import unittest
-import asyncio

-import g4f
-from g4f import ChatCompletion, get_last_provider
+import g4f.version
 from g4f.errors import VersionNotFoundError
-from g4f.Provider import RetryProvider
-from .mocks import ProviderMock

 DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]

 class TestGetLastProvider(unittest.TestCase):
-    def test_get_last_provider(self):
-        ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
-        self.assertEqual(get_last_provider(), ProviderMock)
-
-    def test_get_last_provider_retry(self):
-        ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, RetryProvider([ProviderMock]))
-        self.assertEqual(get_last_provider(), ProviderMock)
-
-    def test_get_last_provider_async(self):
-        coroutine = ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
-        asyncio.run(coroutine)
-        self.assertEqual(get_last_provider(), ProviderMock)
-
-    def test_get_last_provider_as_dict(self):
-        ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
-        last_provider_dict = get_last_provider(True)
-        self.assertIsInstance(last_provider_dict, dict)
-        self.assertIn('name', last_provider_dict)
-        self.assertEqual(ProviderMock.__name__, last_provider_dict['name'])
-
     def test_get_latest_version(self):
         try:
             self.assertIsInstance(g4f.version.utils.current_version, str)


@@ -35,7 +35,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
     supports_system_message = True
     supports_message_history = True

-    default_model = "gpt-4o-mini"
+    default_model = "llama-3.1-70b-chat"
     default_image_model = "flux"

     models = []
@@ -113,7 +113,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
     @classmethod
     def get_model(cls, model: str) -> str:
         """Get the actual model name from alias"""
-        return cls.model_aliases.get(model, model)
+        return cls.model_aliases.get(model, model or cls.default_model)

     @classmethod
     async def check_api_key(cls, api_key: str) -> bool:
@@ -162,6 +162,9 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
         """
         Filters the full response to remove system errors and other unwanted text.
         """
+        if "Model not found or too long input. Or any other error (xD)" in response:
+            raise ValueError(response)
         filtered_response = re.sub(r"\[ERROR\] '\w{8}-\w{4}-\w{4}-\w{4}-\w{12}'", '', response)  # any-uncensored
         filtered_response = re.sub(r'<\|im_end\|>', '', filtered_response)  # remove <|im_end|> token
         filtered_response = re.sub(r'</s>', '', filtered_response)  # neural-chat-7b-v3-1


@@ -246,6 +246,8 @@ class BlackboxCreateAgent(AsyncGeneratorProvider, ProviderModelMixin):
         Returns:
             AsyncResult: The response from the provider
         """
+        if not model:
+            model = cls.default_model
         if model in cls.chat_models:
             async for text in cls._generate_text(model, messages, proxy=proxy, **kwargs):
                 yield text


@@ -80,4 +80,6 @@ class ChatGptEs(AsyncGeneratorProvider, ProviderModelMixin):
         async with session.post(cls.api_endpoint, headers=headers, data=payload) as response:
             response.raise_for_status()
             result = await response.json()
+            if "Du musst das Kästchen anklicken!" in result['data']:
+                raise ValueError(result['data'])
             yield result['data']


@@ -127,8 +127,8 @@ class Copilot(AbstractProvider, ProviderModelMixin):
             response = session.post(cls.conversation_url)
             raise_for_status(response)
             conversation_id = response.json().get("id")
-            if return_conversation:
-                conversation = Conversation(conversation_id)
+            conversation = Conversation(conversation_id)
+            if return_conversation:
                 yield conversation
             if prompt is None:
                 prompt = format_prompt_max_length(messages, 10000)


@@ -15,7 +15,8 @@ from .needs_auth.OpenaiAPI import OpenaiAPI
 """

 class Mhystical(OpenaiAPI):
-    url = "https://api.mhystical.cc"
+    label = "Mhystical"
+    url = "https://mhystical.cc"
     api_endpoint = "https://api.mhystical.cc/v1/completions"
     working = True
     needs_auth = False


@@ -3,6 +3,7 @@ from __future__ import annotations
 import json
 import random
 import requests
+from urllib.parse import quote
 from typing import Optional
 from aiohttp import ClientSession

@@ -170,7 +171,7 @@ class PollinationsAI(OpenaiAPI):
         params = {k: v for k, v in params.items() if v is not None}

         async with ClientSession(headers=headers) as session:
-            prompt = quote(messages[-1]["content"])
+            prompt = quote(messages[-1]["content"] if prompt is None else prompt)
             param_string = "&".join(f"{k}={v}" for k, v in params.items())
             url = f"{cls.image_api_endpoint}/prompt/{prompt}?{param_string}"


@@ -0,0 +1,195 @@
from __future__ import annotations

import requests
import json
import base64
from typing import Optional

from ..helper import filter_none
from ...typing import AsyncResult, Messages, ImagesType
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ToolCalls, Usage
from ...errors import MissingAuthError
from ...image import to_bytes, is_accepted_format
from .OpenaiAPI import OpenaiAPI

class Anthropic(OpenaiAPI):
    label = "Anthropic API"
    url = "https://console.anthropic.com"
    login_url = "https://console.anthropic.com/settings/keys"
    working = True
    api_base = "https://api.anthropic.com/v1"
    needs_auth = True
    supports_stream = True
    supports_system_message = True
    supports_message_history = True
    default_model = "claude-3-5-sonnet-latest"
    models = [
        default_model,
        "claude-3-5-sonnet-20241022",
        "claude-3-5-haiku-latest",
        "claude-3-5-haiku-20241022",
        "claude-3-opus-latest",
        "claude-3-opus-20240229",
        "claude-3-sonnet-20240229",
        "claude-3-haiku-20240307"
    ]
    models_aliases = {
        "claude-3.5-sonnet": default_model,
        "claude-3-opus": "claude-3-opus-latest",
        "claude-3-sonnet": "claude-3-sonnet-20240229",
        "claude-3-haiku": "claude-3-haiku-20240307",
    }

    @classmethod
    def get_models(cls, api_key: str = None, **kwargs):
        if not cls.models:
            url = f"https://api.anthropic.com/v1/models"
            response = requests.get(url, headers={
                "Content-Type": "application/json",
                "x-api-key": api_key,
                "anthropic-version": "2023-06-01"
            })
            raise_for_status(response)
            models = response.json()
            cls.models = [model["id"] for model in models["data"]]
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        timeout: int = 120,
        images: ImagesType = None,
        api_key: str = None,
        temperature: float = None,
        max_tokens: int = 4096,
        top_k: int = None,
        top_p: float = None,
        stop: list[str] = None,
        stream: bool = False,
        headers: dict = None,
        impersonate: str = None,
        tools: Optional[list] = None,
        extra_data: dict = {},
        **kwargs
    ) -> AsyncResult:
        if api_key is None:
            raise MissingAuthError('Add a "api_key"')

        if images is not None:
            insert_images = []
            for image, _ in images:
                data = to_bytes(image)
                insert_images.append({
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": is_accepted_format(data),
                        "data": base64.b64encode(data).decode(),
                    }
                })
            messages[-1]["content"] = [
                *insert_images,
                {
                    "type": "text",
                    "text": messages[-1]["content"]
                }
            ]
        system = "\n".join([message["content"] for message in messages if message.get("role") == "system"])
        if system:
            messages = [message for message in messages if message.get("role") != "system"]
        else:
            system = None

        async with StreamSession(
            proxy=proxy,
            headers=cls.get_headers(stream, api_key, headers),
            timeout=timeout,
            impersonate=impersonate,
        ) as session:
            data = filter_none(
                messages=messages,
                model=cls.get_model(model, api_key=api_key),
                temperature=temperature,
                max_tokens=max_tokens,
                top_k=top_k,
                top_p=top_p,
                stop_sequences=stop,
                system=system,
                stream=stream,
                tools=tools,
                **extra_data
            )
            async with session.post(f"{cls.api_base}/messages", json=data) as response:
                await raise_for_status(response)
                tool_calls = []
                if not stream:
                    data = await response.json()
                    cls.raise_error(data)
                    if "type" in data and data["type"] == "message":
                        for content in data["content"]:
                            if content["type"] == "text":
                                yield content["text"]
                            elif content["type"] == "tool_use":
                                tool_calls.append({
                                    "id": content["id"],
                                    "type": "function",
                                    "function": { "name": content["name"], "arguments": content["input"] }
                                })
                        if data["stop_reason"] == "end_turn":
                            yield FinishReason("stop")
                        elif data["stop_reason"] == "max_tokens":
                            yield FinishReason("length")
                        yield Usage(**data["usage"])
                else:
                    content_block = None
                    partial_json = []
                    async for line in response.iter_lines():
                        if line.startswith(b"data: "):
                            chunk = line[6:]
                            if chunk == b"[DONE]":
                                break
                            data = json.loads(chunk)
                            cls.raise_error(data)
                            if "type" in data:
                                if data["type"] == "content_block_start":
                                    content_block = data["content_block"]
                                if content_block is None:
                                    pass # Message start
                                elif data["type"] == "content_block_delta":
                                    if content_block["type"] == "text":
                                        yield data["delta"]["text"]
                                    elif content_block["type"] == "tool_use":
                                        partial_json.append(data["delta"]["partial_json"])
                                elif data["type"] == "message_delta":
                                    if data["delta"]["stop_reason"] == "end_turn":
                                        yield FinishReason("stop")
                                    elif data["delta"]["stop_reason"] == "max_tokens":
                                        yield FinishReason("length")
                                    yield Usage(**data["usage"])
                                elif data["type"] == "content_block_stop":
                                    if content_block["type"] == "tool_use":
                                        tool_calls.append({
                                            "id": content_block["id"],
                                            "type": "function",
                                            "function": { "name": content_block["name"], "arguments": "".join(partial_json) }
                                        })
                                        partial_json = []
                if tool_calls:
                    yield ToolCalls(tool_calls)

    @classmethod
    def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
        return {
            "Accept": "text/event-stream" if stream else "application/json",
            "Content-Type": "application/json",
            **(
                {"x-api-key": api_key}
                if api_key is not None else {}
            ),
            "anthropic-version": "2023-06-01",
            **({} if headers is None else headers)
        }
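Not part of the commit: a minimal usage sketch for the new provider, assuming it is re-exported through g4f.Provider (the needs_auth __init__.py change below adds the import). The API key is a placeholder.

```python
# Hedged usage sketch for the new Anthropic provider (illustrative only).
import asyncio
import g4f
from g4f.Provider import Anthropic

async def main():
    answer = await g4f.ChatCompletion.create_async(
        model=Anthropic.default_model,  # "claude-3-5-sonnet-latest"
        messages=[{"role": "user", "content": "Hello, Claude!"}],
        provider=Anthropic,
        api_key="sk-ant-...",  # placeholder; see login_url above
    )
    print(answer)

asyncio.run(main())
```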


@@ -10,6 +10,7 @@ from ...cookies import get_cookies
 class Cerebras(OpenaiAPI):
     label = "Cerebras Inference"
     url = "https://inference.cerebras.ai/"
+    login_url = "https://cloud.cerebras.ai"
     api_base = "https://api.cerebras.ai/v1"
     working = True
     default_model = "llama3.1-70b"


@@ -16,6 +16,7 @@ from ... import debug
 class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
     label = "Google Gemini API"
     url = "https://ai.google.dev"
+    login_url = "https://aistudio.google.com/u/0/apikey"
     api_base = "https://generativelanguage.googleapis.com/v1beta"
     working = True
@@ -24,7 +25,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
     default_model = "gemini-1.5-pro"
     default_vision_model = default_model
-    fallback_models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
+    fallback_models = [default_model, "gemini-2.0-flash-exp", "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
     model_aliases = {
         "gemini-flash": "gemini-1.5-flash",
         "gemini-flash": "gemini-1.5-flash-8b",


@@ -5,6 +5,7 @@ from .OpenaiAPI import OpenaiAPI
 class Groq(OpenaiAPI):
     label = "Groq"
     url = "https://console.groq.com/playground"
+    login_url = "https://console.groq.com/keys"
     api_base = "https://api.groq.com/openai/v1"
     working = True
     default_model = "mixtral-8x7b-32768"


@@ -47,8 +47,8 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
         proxy: str = None,
         api_base: str = "https://api-inference.huggingface.co",
         api_key: str = None,
-        max_new_tokens: int = 1024,
-        temperature: float = 0.7,
+        max_tokens: int = 1024,
+        temperature: float = None,
         prompt: str = None,
         action: str = None,
         extra_data: dict = {},
@@ -84,7 +84,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
         else:
             params = {
                 "return_full_text": False,
-                "max_new_tokens": max_new_tokens,
+                "max_new_tokens": max_tokens,
                 "temperature": temperature,
                 **extra_data
             }


@@ -6,6 +6,7 @@ from .HuggingChat import HuggingChat
 class HuggingFaceAPI(OpenaiAPI):
     label = "HuggingFace (Inference API)"
     url = "https://api-inference.huggingface.co"
+    login_url = "https://huggingface.co/settings/tokens"
     api_base = "https://api-inference.huggingface.co/v1"
     working = True
     default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"


@@ -4,7 +4,7 @@ import json
 import requests

 from ..helper import filter_none
-from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
 from ...typing import Union, Optional, AsyncResult, Messages, ImagesType
 from ...requests import StreamSession, raise_for_status
 from ...providers.response import FinishReason, ToolCalls, Usage
@@ -12,9 +12,10 @@ from ...errors import MissingAuthError, ResponseError
 from ...image import to_data_uri
 from ... import debug

-class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
+class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
     label = "OpenAI API"
     url = "https://platform.openai.com"
+    login_url = "https://platform.openai.com/settings/organization/api-keys"
     api_base = "https://api.openai.com/v1"
     working = True
     needs_auth = True
@@ -141,18 +142,6 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
         if "finish_reason" in choice and choice["finish_reason"] is not None:
             return FinishReason(choice["finish_reason"])

-    @staticmethod
-    def raise_error(data: dict):
-        if "error_message" in data:
-            raise ResponseError(data["error_message"])
-        elif "error" in data:
-            if "code" in data["error"]:
-                raise ResponseError(f'Error {data["error"]["code"]}: {data["error"]["message"]}')
-            elif "message" in data["error"]:
-                raise ResponseError(data["error"]["message"])
-            else:
-                raise ResponseError(data["error"])
-
     @classmethod
     def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
         return {
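The raise_error staticmethod removed here presumably moves into the new RaiseErrorMixin base imported above; the mixin's own definition is not shown in this diff, so the following is only a sketch under that assumption.

```python
# Assumed shape of RaiseErrorMixin (not in this diff): it simply rehosts
# the staticmethod deleted from OpenaiAPI above.
from g4f.errors import ResponseError

class RaiseErrorMixin:
    @staticmethod
    def raise_error(data: dict):
        if "error_message" in data:
            raise ResponseError(data["error_message"])
        elif "error" in data:
            if "code" in data["error"]:
                raise ResponseError(f'Error {data["error"]["code"]}: {data["error"]["message"]}')
            elif "message" in data["error"]:
                raise ResponseError(data["error"]["message"])
            else:
                raise ResponseError(data["error"])
```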


@@ -5,6 +5,7 @@ from .OpenaiAPI import OpenaiAPI
 class PerplexityApi(OpenaiAPI):
     label = "Perplexity API"
     url = "https://www.perplexity.ai"
+    login_url = "https://www.perplexity.ai/settings/api"
     working = True
     api_base = "https://api.perplexity.ai"
     default_model = "llama-3-sonar-large-32k-online"


@@ -9,6 +9,7 @@ from ...errors import ResponseError, MissingAuthError
 class Replicate(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://replicate.com"
+    login_url = "https://replicate.com/account/api-tokens"
     working = True
     needs_auth = True
     default_model = "meta/meta-llama-3-70b-instruct"
@@ -25,7 +26,7 @@ class Replicate(AsyncGeneratorProvider, ProviderModelMixin):
         proxy: str = None,
         timeout: int = 180,
         system_prompt: str = None,
-        max_new_tokens: int = None,
+        max_tokens: int = None,
         temperature: float = None,
         top_p: float = None,
         top_k: float = None,
@@ -55,7 +56,7 @@ class Replicate(AsyncGeneratorProvider, ProviderModelMixin):
             "prompt": format_prompt(messages),
             **filter_none(
                 system_prompt=system_prompt,
-                max_new_tokens=max_new_tokens,
+                max_new_tokens=max_tokens,
                 temperature=temperature,
                 top_p=top_p,
                 top_k=top_k,


@@ -1,5 +1,6 @@
 from .gigachat import *

+from .Anthropic import Anthropic
 from .BingCreateImages import BingCreateImages
 from .Cerebras import Cerebras
 from .CopilotAccount import CopilotAccount


@@ -5,6 +5,7 @@ from .OpenaiAPI import OpenaiAPI
 class glhfChat(OpenaiAPI):
     label = "glhf.chat"
     url = "https://glhf.chat"
+    login_url = "https://glhf.chat/users/settings/api"
     api_base = "https://glhf.chat/api/openai/v1"
     working = True
     model_aliases = {


@@ -5,5 +5,6 @@ from .OpenaiAPI import OpenaiAPI
 class xAI(OpenaiAPI):
     label = "xAI"
     url = "https://console.x.ai"
+    login_url = "https://console.x.ai"
     api_base = "https://api.x.ai/v1"
     working = True


@@ -12,7 +12,7 @@ from .errors import StreamNotSupportedError
 from .cookies import get_cookies, set_cookies
 from .providers.types import ProviderType
 from .providers.helper import concat_chunks
-from .client.service import get_model_and_provider, get_last_provider
+from .client.service import get_model_and_provider

 #Configure "g4f" logger
 logger = logging.getLogger(__name__)
@@ -47,7 +47,8 @@ class ChatCompletion:
         if ignore_stream:
             kwargs["ignore_stream"] = True

-        result = provider.create_completion(model, messages, stream=stream, **kwargs)
+        create_method = provider.create_authed if hasattr(provider, "create_authed") else provider.create_completion
+        result = create_method(model, messages, stream=stream, **kwargs)

         return result if stream else concat_chunks(result)
@@ -72,7 +73,9 @@ class ChatCompletion:
             kwargs["ignore_stream"] = True

         if stream:
-            if hasattr(provider, "create_async_generator"):
+            if hasattr(provider, "create_async_authed_generator"):
+                return provider.create_async_authed_generator(model, messages, **kwargs)
+            elif hasattr(provider, "create_async_generator"):
                 return provider.create_async_generator(model, messages, **kwargs)
             raise StreamNotSupportedError(f'{provider.__name__} does not support "stream" argument in "create_async"')
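The effect of the new dispatch: a provider that defines create_authed is preferred over create_completion. A small illustration with a hypothetical provider class (names are illustrative, not from the diff):

```python
# Illustration of the new create_authed dispatch order.
class ExampleProvider:
    @classmethod
    def create_authed(cls, model, messages, stream=False, **kwargs):
        yield "authed path"  # chosen whenever the attribute exists

    @classmethod
    def create_completion(cls, model, messages, stream=False, **kwargs):
        yield "legacy path"

create_method = (ExampleProvider.create_authed
                 if hasattr(ExampleProvider, "create_authed")
                 else ExampleProvider.create_completion)
print("".join(create_method("any-model", [{"role": "user", "content": "Hi"}])))
# -> "authed path"
```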


@@ -197,10 +197,9 @@ class Api:
         )

     def register_routes(self):
-        if not AppConfig.gui:
         @self.app.get("/")
         async def read_root():
+            if AppConfig.gui:
+                return RedirectResponse("/chat/", 302)
             return RedirectResponse("/v1", 302)

         @self.app.get("/v1")


@@ -7,9 +7,9 @@ import string
 import asyncio
 import aiohttp
 import base64
-from typing import Union, AsyncIterator, Iterator, Coroutine, Optional
+from typing import Union, AsyncIterator, Iterator, Awaitable, Optional

-from ..image import ImageResponse, copy_images, images_dir
+from ..image import ImageResponse, copy_images
 from ..typing import Messages, ImageType
 from ..providers.types import ProviderType, BaseRetryProvider
 from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData, ToolCalls, Usage
@@ -22,7 +22,7 @@ from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
 from .image_models import ImageModels
 from .types import IterResponse, ImageProvider, Client as BaseClient
 from .service import get_model_and_provider, convert_to_provider
-from .helper import find_stop, filter_json, filter_none, safe_aclose
+from .helper import find_stop, filter_json, filter_none, safe_aclose, to_async_iterator
 from .. import debug

 ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
@@ -220,7 +220,7 @@ class Completions:
         ignore_working: Optional[bool] = False,
         ignore_stream: Optional[bool] = False,
         **kwargs
-    ) -> IterResponse:
+    ) -> ChatCompletion:
         model, provider = get_model_and_provider(
             model,
             self.provider if provider is None else provider,
@@ -236,7 +236,7 @@ class Completions:
             kwargs["ignore_stream"] = True

         response = iter_run_tools(
-            provider.create_completion,
+            provider.create_authed if hasattr(provider, "create_authed") else provider.create_completion,
             model,
             messages,
             stream=stream,
@@ -248,9 +248,6 @@ class Completions:
             ),
             **kwargs
         )
-        if asyncio.iscoroutinefunction(provider.create_completion):
-            # Run the asynchronous function in an event loop
-            response = asyncio.run(response)
         if stream and hasattr(response, '__aiter__'):
             # It's an async generator, wrap it into a sync iterator
             response = to_sync_generator(response)
@@ -264,6 +261,14 @@ class Completions:
         else:
             return next(response)

+    def stream(
+        self,
+        messages: Messages,
+        model: str,
+        **kwargs
+    ) -> IterResponse:
+        return self.create(messages, model, stream=True, **kwargs)
+
 class Chat:
     completions: Completions

@@ -507,7 +512,7 @@ class AsyncCompletions:
         ignore_working: Optional[bool] = False,
         ignore_stream: Optional[bool] = False,
         **kwargs
-    ) -> Union[Coroutine[ChatCompletion], AsyncIterator[ChatCompletionChunk, BaseConversation]]:
+    ) -> Awaitable[ChatCompletion]:
         model, provider = get_model_and_provider(
             model,
             self.provider if provider is None else provider,
@@ -521,6 +526,8 @@ class AsyncCompletions:
             kwargs["images"] = [(image, image_name)]
         if ignore_stream:
             kwargs["ignore_stream"] = True
+        if hasattr(provider, "create_async_authed_generator"):
+            create_handler = provider.create_async_authed_generator
         if hasattr(provider, "create_async_generator"):
             create_handler = provider.create_async_generator
         else:
@@ -538,10 +545,20 @@ class AsyncCompletions:
             ),
             **kwargs
         )
+        if not hasattr(response, '__aiter__'):
+            response = to_async_iterator(response)
         response = async_iter_response(response, stream, response_format, max_tokens, stop)
         response = async_iter_append_model_and_provider(response, model, provider)
         return response if stream else anext(response)

+    def stream(
+        self,
+        messages: Messages,
+        model: str,
+        **kwargs
+    ) -> AsyncIterator[ChatCompletionChunk, BaseConversation]:
+        return self.create(messages, model, stream=True, **kwargs)
+
 class AsyncImages(Images):
     def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
         self.client: AsyncClient = client
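The new stream() methods are plain wrappers around create(..., stream=True). A usage sketch, assuming the chunk objects mirror the OpenAI shape as in the stubs imported above:

```python
# Hypothetical use of the new Completions.stream() helper.
from g4f.client import Client

client = Client()
for chunk in client.chat.completions.stream(
    messages=[{"role": "user", "content": "Write a haiku about diffs"}],
    model="gpt-4o-mini",
):
    # Guard against empty delta chunks at the end of the stream.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
```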


@@ -5,6 +5,22 @@ import logging

 from typing import AsyncIterator, Iterator, AsyncGenerator, Optional

+def filter_markdown(text: str, allowd_types=None, default=None) -> str:
+    """
+    Parses code block from a string.
+
+    Args:
+        text (str): A string containing a code block.
+
+    Returns:
+        str: The contents of the code block, or the default value.
+    """
+    match = re.search(r"```(.+)\n(?P<code>[\S\s]+?)(\n```|$)", text)
+    if match:
+        if allowd_types is None or match.group(1) in allowd_types:
+            return match.group("code")
+    return default
+
 def filter_json(text: str) -> str:
     """
     Parses JSON code block from a string.
@@ -15,10 +31,7 @@ def filter_json(text: str) -> str:
     Returns:
         dict: A dictionary parsed from the JSON code block.
     """
-    match = re.search(r"```(json|)\n(?P<code>[\S\s]+?)\n```", text)
-    if match:
-        return match.group("code")
-    return text
+    return filter_markdown(text, ["", "json"], text)

 def find_stop(stop: Optional[list[str]], content: str, chunk: str = None):
     first = -1
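Behavior of the new filter_markdown helper, inferred from the regex above: it returns the body of the first fenced block when the fence's language tag is in allowd_types, otherwise the default value.

```python
from g4f.client.helper import filter_markdown

fence = "`" * 3  # build the fences programmatically to keep this example readable
text = f"Intro\n{fence}html\n<p>Hi</p>\n{fence}\nOutro"
print(filter_markdown(text, ["html"]))        # -> "<p>Hi</p>"
print(filter_markdown(text, ["json"], "{}"))  # tag not allowed, falls back to "{}"
```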

g4f/gui/client/home.html (new file)

@@ -0,0 +1,230 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>G4F GUI</title>
<style>
:root {
--colour-1: #000000;
--colour-2: #ccc;
--colour-3: #e4d4ff;
--colour-4: #f0f0f0;
--colour-5: #181818;
--colour-6: #242424;
--accent: #8b3dff;
--gradient: #1a1a1a;
--background: #16101b;
--size: 70vw;
--top: 50%;
--blur: 40px;
--opacity: 0.6;
}
@import url("https://fonts.googleapis.com/css2?family=Inter:wght@100;200;300;400;500;600;700;800;900&display=swap");
.gradient {
position: absolute;
z-index: -1;
left: 50vw;
border-radius: 50%;
background: radial-gradient(circle at center, var(--accent), var(--gradient));
width: var(--size);
height: var(--size);
top: var(--top);
transform: translate(-50%, -50%);
filter: blur(var(--blur)) opacity(var(--opacity));
animation: zoom_gradient 6s infinite alternate;
display: none;
max-height: 100%;
transition: max-height 0.25s ease-in;
}
.gradient.hidden {
max-height: 0;
transition: max-height 0.15s ease-out;
}
@media only screen and (min-width: 40em) {
body .gradient{
display: block;
}
}
@keyframes zoom_gradient {
0% {
transform: translate(-50%, -50%) scale(1);
}
100% {
transform: translate(-50%, -50%) scale(1.2);
}
}
/* Body and text color */
body {
background: var(--background);
color: var(--colour-3);
font-family: "Inter", sans-serif;
height: 100vh;
margin: 0;
padding: 0;
overflow: hidden;
font-weight: bold;
}
/* Container for the main content */
.container {
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
height: 100%;
text-align: center;
z-index: 1;
}
header {
font-size: 3rem;
text-transform: uppercase;
margin: 20px;
color: var(--colour-4);
}
iframe {
background: transparent;
width: 100%;
border: none;
}
#background {
height: 100%;
position: absolute;
z-index: -1;
}
iframe.stream {
max-height: 0;
transition: max-height 0.15s ease-out;
}
iframe.stream.show {
max-height: 1000px;
height: 1000px;
transition: max-height 0.25s ease-in;
background: rgba(255,255,255,0.7);
border-top: 2px solid rgba(255,255,255,0.5);
}
.description {
font-size: 1.2rem;
margin-bottom: 30px;
color: var(--colour-2);
}
.input-field {
width: 80%;
max-width: 400px;
padding: 12px;
margin: 10px 0;
border: 2px solid var(--colour-6);
background-color: var(--colour-5);
color: var(--colour-3);
border-radius: 8px;
font-size: 1.1rem;
}
.input-field:focus {
outline: none;
border-color: var(--accent);
}
.button {
background-color: var(--accent);
color: var(--colour-3);
border: none;
padding: 15px 30px;
font-size: 1.1rem;
border-radius: 8px;
cursor: pointer;
transition: background-color 0.3s ease;
margin-top: 15px;
width: 100%;
max-width: 400px;
font-weight: bold;
}
.button:hover {
background-color: #7a2ccd;
}
.footer {
margin-top: 30px;
font-size: 0.9rem;
color: var(--colour-2);
}
/* Animation for the gradient circle */
@keyframes zoom_gradient {
0% {
transform: translate(-50%, -50%) scale(1);
}
100% {
transform: translate(-50%, -50%) scale(1.5);
}
}
</style>
</head>
<body>
<iframe id="background"></iframe>
<!-- Gradient Background Circle -->
<div class="gradient"></div>
<!-- Main Content -->
<div class="container">
<header>
G4F GUI
</header>
<div class="description">
Welcome to the G4F GUI! <br>
Your AI assistant is ready to assist you.
</div>
<!-- Input and Button -->
<form action="/chat/">
<!--
<input type="text" name="prompt" class="input-field" placeholder="Enter your query...">
-->
<button class="button">Open Chat</button>
</form>
<!-- Footer -->
<div class="footer">
<p>&copy; 2025 G4F. All Rights Reserved.</p>
<p>Powered by the G4F framework</p>
</div>
<iframe id="stream-widget" class="stream" data-src="/backend-api/v2/create?prompt=Create an overview of the news in plain text&stream=1&web_search=news in " frameborder="0"></iframe>
</div>
<script>
const iframe = document.getElementById('stream-widget');
iframe.src = iframe.dataset.src + navigator.language;
setTimeout(()=>iframe.classList.add('show'), 5000);
(async () => {
const prompt = `
Today is ${new Date().toJSON().slice(0, 10)}.
Create a single-page HTML screensaver reflecting the current season (based on the date).
For example, if it's Spring, it might use floral patterns or pastel colors.
Avoid using any text. Consider a subtle animation or transition effect.`;
const response = await fetch(`/backend-api/v2/create?prompt=${prompt}&filter_markdown=html`)
const text = await response.text()
background.src = `data:text/html;charset=utf-8,${encodeURIComponent(text)}`;
const gradient = document.querySelector('.gradient');
gradient.classList.add('hidden');
})();
</script>
</body>
</html>


@@ -157,6 +157,10 @@
             <label for="Cerebras-api_key" class="label" title="">Cerebras Inference:</label>
             <textarea id="Cerebras-api_key" name="Cerebras[api_key]" placeholder="api_key"></textarea>
         </div>
+        <div class="field box hidden">
+            <label for="Anthropic-api_key" class="label" title="">Anthropic API:</label>
+            <textarea id="Anthropic-api_key" name="Anthropic[api_key]" placeholder="api_key"></textarea>
+        </div>
         <div class="field box hidden">
             <label for="DeepInfra-api_key" class="label" title="">DeepInfra:</label>
             <textarea id="DeepInfra-api_key" name="DeepInfra[api_key]" class="DeepInfraImage-api_key" placeholder="api_key"></textarea>


@@ -239,7 +239,7 @@ body:not(.white) a:visited{
     max-width: 210px;
     margin-left: auto;
     margin-right: 8px;
-    margin-bottom: 8px;
+    margin-top: 12px;
 }

 .convo-title {
@@ -530,6 +530,7 @@ body:not(.white) a:visited{
     z-index: 100000;
     top: 0;
     right: 0;
+    animation: show_popup 0.4s;
 }

 .stop_generating button, .toolbar .regenerate button, button.regenerate_button, button.continue_button, button.options_button {
@@ -545,7 +546,6 @@ body:not(.white) a:visited{
     align-items: center;
     gap: 12px;
     cursor: pointer;
-    animation: show_popup 0.4s;
     height: 28px;
 }

@@ -712,6 +712,8 @@ form label.toogle {
     position: relative;
     overflow: hidden;
     transition: 0.33s;
+    min-width: 60px;
+    margin-left: 0;
 }

 .buttons label:after,
@@ -1280,12 +1282,18 @@ ul {
     border: 1px solid #e4d4ffc9;
 }

-.settings textarea, form textarea {
+form textarea {
     height: 20px;
     min-height: 20px;
     padding: 0;
 }

+.settings textarea {
+    height: 30px;
+    min-height: 30px;
+    padding: 6px;
+}
+
 form .field .fa-xmark {
     line-height: 20px;
     cursor: pointer;
@@ -1346,6 +1354,7 @@ form .field.saved .fa-xmark {
 .settings .label, form .label, .settings label, form label {
     font-size: 15px;
     margin-left: var(--inner-gap);
+    min-width: 120px;
 }

 .settings .label, form .label {


@@ -137,14 +137,13 @@ class HtmlRenderPlugin {
             }
         }
     }
-    if (window.hljs) {
-        hljs.addPlugin(new HtmlRenderPlugin())
-        hljs.addPlugin(new CopyButtonPlugin());
-    }
 let typesetPromise = Promise.resolve();
 const highlight = (container) => {
     if (window.hljs) {
+        hljs.addPlugin(new HtmlRenderPlugin())
+        if (window.CopyButtonPlugin) {
+            hljs.addPlugin(new CopyButtonPlugin());
+        }
         container.querySelectorAll('code:not(.hljs)').forEach((el) => {
             if (el.className != "hljs") {
                 hljs.highlightElement(el);
@@ -518,6 +517,7 @@ const prepare_messages = (messages, message_index = -1, do_continue = false) =>
             delete new_message.synthesize;
             delete new_message.finish;
             delete new_message.conversation;
+            delete new_message.continue;
             // Append message to new messages
             new_messages.push(new_message)
         }
@@ -571,7 +571,7 @@ async function load_provider_parameters(provider) {
             field_el = document.createElement("div");
             field_el.classList.add("field");
             field_el.classList.add("box");
-            if (typeof value == "object") {
+            if (typeof value == "object" && value != null) {
                 value = JSON.stringify(value, null, 4);
             }
             if (saved_value) {
@@ -580,14 +580,15 @@ async function load_provider_parameters(provider) {
                 saved_value = value;
             }
             let placeholder;
-            if (key in ["api_key", "proof_token"]) {
-                placeholder = value.length >= 22 ? (value.substring(0, 10) + "*".repeat(8) + value.substring(value.length-10)) : value;
+            if (["api_key", "proof_token"].includes(key)) {
+                placeholder = saved_value && saved_value.length >= 22 ? (saved_value.substring(0, 12) + "*".repeat(12) + saved_value.substring(saved_value.length-12)) : value;
             } else {
-                placeholder = value;
+                placeholder = value == null ? "null" : value;
             }
             field_el.innerHTML = `<label for="${el_id}" title="">${key}:</label>`;
-            if (Number.isInteger(value)) {
-                field_el.innerHTML += `<input type="range" id="${el_id}" name="${provider}[${key}]" value="${escapeHtml(value)}" class="slider" min="0" max="4096" step="1"/><output>${escapeHtml(value)}</output>`;
+            if (Number.isInteger(value) && value != 1) {
+                max = value >= 4096 ? 8192 : 4096;
+                field_el.innerHTML += `<input type="range" id="${el_id}" name="${provider}[${key}]" value="${escapeHtml(value)}" class="slider" min="0" max="${max}" step="1"/><output>${escapeHtml(value)}</output>`;
                 field_el.innerHTML += `<i class="fa-solid fa-xmark"></i>`;
             } else if (typeof value == "number") {
                 field_el.innerHTML += `<input type="range" id="${el_id}" name="${provider}[${key}]" value="${escapeHtml(value)}" class="slider" min="0" max="2" step="0.1"/><output>${escapeHtml(value)}</output>`;
@@ -596,10 +597,12 @@ async function load_provider_parameters(provider) {
                 field_el.innerHTML += `<textarea id="${el_id}" name="${provider}[${key}]"></textarea>`;
                 field_el.innerHTML += `<i class="fa-solid fa-xmark"></i>`;
                 input_el = field_el.querySelector("textarea");
+                if (value != null) {
                     input_el.dataset.text = value;
+                }
                 input_el.placeholder = placeholder;
-                if (!key in ["api_key", "proof_token"]) {
-                    input_el.innerHTML = saved_value;
+                if (!["api_key", "proof_token"].includes(key)) {
+                    input_el.value = saved_value;
                 } else {
                     input_el.dataset.saved_value = saved_value;
                 }
@@ -610,14 +613,16 @@ async function load_provider_parameters(provider) {
             };
             input_el.onfocus = () => {
                 if (input_el.dataset.saved_value) {
-                    input_el.innerHTML = input_el.dataset.saved_value;
+                    input_el.value = input_el.dataset.saved_value;
+                } else if (["api_key", "proof_token"].includes(key)) {
+                    input_el.value = input_el.dataset.text;
                 }
                 input_el.style.removeProperty("height");
                 input_el.style.height = (input_el.scrollHeight) + "px";
             }
             input_el.onblur = () => {
                 input_el.style.removeProperty("height");
-                if (key in ["api_key", "proof_token"]) {
+                if (["api_key", "proof_token"].includes(key)) {
                     input_el.value = "";
                 }
             }
@@ -642,13 +647,14 @@ async function load_provider_parameters(provider) {
                     input_el.value = input_el.dataset.value;
                     input_el.nextElementSibling.value = input_el.dataset.value;
                 } else if (input_el.dataset.text) {
-                    input_el.innerHTML = input_el.dataset.text;
+                    input_el.value = input_el.dataset.text;
                 }
+                delete input_el.dataset.saved_value;
                 appStorage.removeItem(el_id);
                 field_el.classList.remove("saved");
             }
         });
-        provider_forms.prepend(form_el);
+        provider_forms.appendChild(form_el);
     }
 }
@@ -1004,8 +1010,9 @@ const load_conversation = async (conversation_id, scroll=true) => {
                 let lines = buffer.trim().split("\n");
                 let lastLine = lines[lines.length - 1];
                 let newContent = item.content;
-                if (newContent.startsWith("```\n")) {
-                    newContent = item.content.substring(4);
+                if (newContent.startsWith("```")) {
+                    const index = newContent.indexOf("\n");
+                    newContent = newContent.substring(index);
                 }
                 if (newContent.startsWith(lastLine)) {
                     newContent = newContent.substring(lastLine.length);
@@ -1073,7 +1080,7 @@ const load_conversation = async (conversation_id, scroll=true) => {
                     reason = "stop"
                 }
             }
-            if (reason == "max_tokens" || reason == "error") {
+            if (reason == "length" || reason == "max_tokens" || reason == "error") {
                 actions.push("continue")
             }
         }
@@ -1405,22 +1412,23 @@ function open_settings() {
 const register_settings_storage = async () => {
     const optionElements = document.querySelectorAll(optionElementsSelector);
     optionElements.forEach((element) => {
+        element.name = element.name || element.id;
         if (element.type == "textarea") {
             element.addEventListener('input', async (event) => {
-                appStorage.setItem(element.id, element.value);
+                appStorage.setItem(element.name, element.value);
             });
         } else {
             element.addEventListener('change', async (event) => {
                 switch (element.type) {
                     case "checkbox":
-                        appStorage.setItem(element.id, element.checked);
+                        appStorage.setItem(element.name, element.checked);
                         break;
                     case "select-one":
-                        appStorage.setItem(element.id, element.selectedIndex);
+                        appStorage.setItem(element.name, element.value);
                         break;
                     case "text":
                     case "number":
-                        appStorage.setItem(element.id, element.value);
+                        appStorage.setItem(element.name, element.value);
                         break;
                     default:
                         console.warn("Unresolved element type");
const load_settings_storage = async () => { const load_settings_storage = async () => {
const optionElements = document.querySelectorAll(optionElementsSelector); const optionElements = document.querySelectorAll(optionElementsSelector);
optionElements.forEach((element) => { optionElements.forEach((element) => {
if (!(value = appStorage.getItem(element.id))) { element.name = element.name || element.id;
if (!(value = appStorage.getItem(element.name))) {
return; return;
} }
if (value) { if (value) {
@ -1442,7 +1451,7 @@ const load_settings_storage = async () => {
element.checked = value === "true"; element.checked = value === "true";
break; break;
case "select-one": case "select-one":
element.selectedIndex = parseInt(value); element.value = value;
break; break;
case "text": case "text":
case "number": case "number":
@@ -1683,7 +1692,7 @@ async function on_api() {
         console.error(e)
         // Redirect to show basic authentication
         if (document.location.pathname == "/chat/") {
-            document.location.href = `/chat/error`;
+            //document.location.href = `/chat/error`;
         }
     }
     register_settings_storage();
@@ -1753,7 +1762,7 @@ async function load_version() {
         new_version = document.createElement("div");
         new_version.classList.add("new_version");
         const link = `<a href="${release_url}" target="_blank" title="${title}">v${versions["latest_version"]}</a>`;
-        new_version.innerHTML = `g4f ${link}&nbsp;&nbsp;🆕`;
+        new_version.innerHTML = `G4F ${link}&nbsp;&nbsp;🆕`;
         new_version.addEventListener("click", ()=>new_version.parentElement.removeChild(new_version));
         document.body.appendChild(new_version);
     } else {
@@ -1951,8 +1960,11 @@ async function api(ressource, args=null, files=null, message_id=null) {
         return read_response(response, message_id, args.provider || null);
     }
     response = await fetch(url, {headers: headers});
+    if (response.status == 200) {
         return await response.json();
+    }
+    console.error(response);
 }
async function read_response(response, message_id, provider) { async function read_response(response, message_id, provider) {
const reader = response.body.pipeThrough(new TextDecoderStream()).getReader(); const reader = response.body.pipeThrough(new TextDecoderStream()).getReader();
@ -1987,19 +1999,19 @@ function get_api_key_by_provider(provider) {
return api_key; return api_key;
} }
async function load_provider_models(providerIndex=null) { async function load_provider_models(provider=null) {
if (!providerIndex) { if (!provider) {
providerIndex = providerSelect.selectedIndex; provider = providerSelect.value;
} }
modelProvider.innerHTML = ''; modelProvider.innerHTML = '';
const provider = providerSelect.options[providerIndex].value; modelProvider.name = `model[${provider}]`;
if (!provider) { if (!provider) {
modelProvider.classList.add("hidden"); modelProvider.classList.add("hidden");
modelSelect.classList.remove("hidden"); modelSelect.classList.remove("hidden");
return; return;
} }
const models = await api('models', provider); const models = await api('models', provider);
if (models.length > 0) { if (models && models.length > 0) {
modelSelect.classList.add("hidden"); modelSelect.classList.add("hidden");
modelProvider.classList.remove("hidden"); modelProvider.classList.remove("hidden");
models.forEach((model) => { models.forEach((model) => {
@ -2010,6 +2022,10 @@ async function load_provider_models(providerIndex=null) {
option.selected = model.default; option.selected = model.default;
modelProvider.appendChild(option); modelProvider.appendChild(option);
}); });
let value = appStorage.getItem(modelProvider.name);
if (value) {
modelProvider.value = value;
}
} else { } else {
modelProvider.classList.add("hidden"); modelProvider.classList.add("hidden");
modelSelect.classList.remove("hidden"); modelSelect.classList.remove("hidden");


@@ -7,7 +7,7 @@ class CopyButtonPlugin {
         el,
         text
     }) {
-        if (el.classList.contains("language-plaintext")) {
+        if (el.parentElement.tagName != "PRE") {
             return;
         }
         let button = Object.assign(document.createElement("button"), {


@@ -64,7 +64,6 @@ class Api:
             "parent": getattr(provider, "parent", None),
             "image": getattr(provider, "image_models", None) is not None,
             "vision": getattr(provider, "default_vision_model", None) is not None,
-            "webdriver": "webdriver" in provider.get_parameters(),
             "auth": provider.needs_auth,
         } for provider in __providers__ if provider.working]


@@ -6,4 +6,6 @@ def create_app() -> Flask:
         template_folder = os.path.join(sys._MEIPASS, "client")
     else:
         template_folder = "../client"
-    return Flask(__name__, template_folder=template_folder, static_folder=f"{template_folder}/static")
+    app = Flask(__name__, template_folder=template_folder, static_folder=f"{template_folder}/static")
+    app.config["TEMPLATES_AUTO_RELOAD"] = True  # Enable auto reload in debug mode
+    return app


@@ -14,9 +14,12 @@ from werkzeug.utils import secure_filename
 from ...image import is_allowed_extension, to_image
 from ...client.service import convert_to_provider
 from ...providers.asyncio import to_sync_generator
+from ...client.helper import filter_markdown
 from ...tools.files import supports_filename, get_streaming, get_bucket_dir, get_buckets
+from ...tools.run_tools import iter_run_tools
 from ...errors import ProviderNotFoundError
 from ...cookies import get_cookies_dir
+from ... import ChatCompletion
 from .api import Api

 logger = logging.getLogger(__name__)
@@ -101,6 +104,44 @@ class Backend_Api(Api):
                 }
             }

+        @app.route('/backend-api/v2/create', methods=['GET', 'POST'])
+        def create():
+            try:
+                tool_calls = [{
+                    "function": {
+                        "name": "bucket_tool"
+                    },
+                    "type": "function"
+                }]
+                web_search = request.args.get("web_search")
+                if web_search:
+                    tool_calls.append({
+                        "function": {
+                            "name": "search_tool",
+                            "arguments": {"query": web_search, "instructions": ""} if web_search != "true" else {}
+                        },
+                        "type": "function"
+                    })
+                do_filter_markdown = request.args.get("filter_markdown")
+                response = iter_run_tools(
+                    ChatCompletion.create,
+                    model=request.args.get("model"),
+                    messages=[{"role": "user", "content": request.args.get("prompt")}],
+                    provider=request.args.get("provider", None),
+                    stream=not do_filter_markdown,
+                    ignore_stream=not request.args.get("stream"),
+                    tool_calls=tool_calls,
+                )
+                if do_filter_markdown:
+                    return Response(filter_markdown(response, do_filter_markdown), mimetype='text/plain')
+                def cast_str():
+                    for chunk in response:
+                        yield str(chunk)
+                return Response(cast_str(), mimetype='text/plain')
+            except Exception as e:
+                logger.exception(e)
+                return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 500
+
         @app.route('/backend-api/v2/buckets', methods=['GET'])
         def list_buckets():
             try:
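A hedged client-side sketch of the new endpoint, assuming a local GUI server on the default port 8080 (prompt and parameter values are illustrative):

```python
# Query the new /backend-api/v2/create endpoint; filter_markdown asks the
# server to return only the fenced block of the given type as plain text.
import requests

resp = requests.get(
    "http://localhost:8080/backend-api/v2/create",
    params={
        "prompt": "Summarize today's news in two sentences",
        "filter_markdown": "html",
    },
    timeout=120,
)
print(resp.text)
```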


@@ -9,7 +9,7 @@ class Website:
         self.app = app
         self.routes = {
             '/': {
-                'function': redirect_home,
+                'function': self._home,
                 'methods': ['GET', 'POST']
             },
             '/chat/': {
@@ -41,3 +41,6 @@ class Website:
     def _index(self):
         return render_template('index.html', chat_id=str(uuid.uuid4()))
+
+    def _home(self):
+        return render_template('home.html')


@@ -98,26 +98,26 @@ gpt_35_turbo = Model(
gpt_4 = Model(
    name = 'gpt-4',
    base_provider = 'OpenAI',
-    best_provider = IterListProvider([DDG, Blackbox, ChatGptEs, PollinationsAI, Copilot, OpenaiChat, Liaobots, Airforce, Mhystical])
+    best_provider = IterListProvider([DDG, Blackbox, ChatGptEs, PollinationsAI, Copilot, OpenaiChat, Liaobots, Mhystical])
)

gpt_4_turbo = Model(
    name = 'gpt-4-turbo',
    base_provider = 'OpenAI',
-    best_provider = Airforce
+    best_provider = None
)

# gpt-4o
gpt_4o = Model(
    name = 'gpt-4o',
    base_provider = 'OpenAI',
-    best_provider = IterListProvider([Blackbox, ChatGptEs, PollinationsAI, DarkAI, ChatGpt, Airforce, Liaobots, OpenaiChat])
+    best_provider = IterListProvider([Blackbox, ChatGptEs, PollinationsAI, DarkAI, ChatGpt, Liaobots, OpenaiChat])
)

gpt_4o_mini = Model(
    name = 'gpt-4o-mini',
    base_provider = 'OpenAI',
-    best_provider = IterListProvider([DDG, ChatGptEs, Pizzagpt, ChatGpt, Airforce, RubiksAI, Liaobots, OpenaiChat])
+    best_provider = IterListProvider([DDG, ChatGptEs, Pizzagpt, ChatGpt, RubiksAI, Liaobots, OpenaiChat])
)

# o1
@@ -136,7 +136,7 @@ o1_preview = Model(
o1_mini = Model(
    name = 'o1-mini',
    base_provider = 'OpenAI',
-    best_provider = IterListProvider([Liaobots, Airforce])
+    best_provider = IterListProvider([Liaobots])
)

### GigaChat ###
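Each of these hunks shortens a fallback chain by dropping the Airforce provider. A short sketch of how such a chain is consumed in practice; the model and message are illustrative:

```python
from g4f import ChatCompletion

# IterListProvider tries each provider in order until one answers,
# so removing Airforce only removes one fallback candidate.
response = ChatCompletion.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response)
```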


@@ -5,8 +5,10 @@ import asyncio
from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
from abc import abstractmethod
+import json
from inspect import signature, Parameter
-from typing import Optional, _GenericAlias
+from typing import Optional, Awaitable, _GenericAlias
+from pathlib import Path

try:
    from types import NoneType
except ImportError:
@@ -15,9 +17,10 @@ except ImportError:

from ..typing import CreateResult, AsyncResult, Messages
from .types import BaseProvider
from .asyncio import get_running_loop, to_sync_generator
-from .response import BaseConversation
+from .response import BaseConversation, AuthResult
from .helper import concat_chunks, async_concat_chunks
-from ..errors import ModelNotSupportedError
+from ..cookies import get_cookies_dir
+from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError
from .. import debug

SAFE_PARAMETERS = [
@@ -33,15 +36,14 @@ SAFE_PARAMETERS = [
]
BASIC_PARAMETERS = {
-    "provider": None,
    "model": "",
    "messages": [],
+    "provider": None,
    "stream": False,
    "timeout": 0,
    "response_format": None,
    "max_tokens": None,
    "stop": None,
+    "web_search": False,
}

PARAMETER_EXAMPLES = {
@@ -109,12 +111,8 @@ class AbstractProvider(BaseProvider):
        ).parameters.items() if name in SAFE_PARAMETERS
            and (name != "stream" or cls.supports_stream)}
        if as_json:
-            def get_type_as_var(annotation: type, key: str):
-                if key == "model":
-                    return getattr(cls, "default_model", "")
-                elif key == "stream":
-                    return cls.supports_stream
-                elif key in PARAMETER_EXAMPLES:
+            def get_type_as_var(annotation: type, key: str, default):
+                if key in PARAMETER_EXAMPLES:
                    if key == "messages" and not cls.supports_system_message:
                        return [PARAMETER_EXAMPLES[key][-1]]
                    return PARAMETER_EXAMPLES[key]
@@ -137,18 +135,21 @@ class AbstractProvider(BaseProvider):
                    return {}
                elif annotation is None:
                    return None
-                elif isinstance(annotation, _GenericAlias) and annotation.__origin__ is Optional:
-                    return get_type_as_var(annotation.__args__[0])
+                elif annotation == "str" or annotation == "list[str]":
+                    return default
+                elif isinstance(annotation, _GenericAlias):
+                    if annotation.__origin__ is Optional:
+                        return get_type_as_var(annotation.__args__[0], key, default)
                else:
                    return str(annotation)
            return { name: (
                param.default
                if isinstance(param, Parameter) and param.default is not Parameter.empty and param.default is not None
-                else get_type_as_var(param.annotation if isinstance(param, Parameter) else type(param), name)
+                else get_type_as_var(param.annotation, name, param.default) if isinstance(param, Parameter) else param
            ) for name, param in {
                **BASIC_PARAMETERS,
-                **{"provider": cls.__name__},
-                **params
+                **params,
+                **{"provider": cls.__name__, "stream": cls.supports_stream, "model": getattr(cls, "default_model", "")},
            }.items()}
        return params
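A quick sketch of what the reworked `get_parameters(as_json=True)` produces; the provider class is chosen for illustration only:

```python
# Sketch: inspect a provider's JSON-ready parameter defaults.
import json
from g4f.Provider import Copilot

params = Copilot.get_parameters(as_json=True)
# "provider", "stream" and "model" now always come from the class itself,
# overriding whatever the signature inspection yielded.
print(json.dumps(params, indent=2))
```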
@@ -310,6 +311,12 @@ class AsyncGeneratorProvider(AsyncProvider):
        """
        raise NotImplementedError()

+    create_authed = create_completion
+    create_authed_async = create_async
+    create_async_authed = create_async_generator
+
class ProviderModelMixin:
    default_model: str = None
    models: list[str] = []
@@ -335,3 +342,112 @@ class ProviderModelMixin:
        cls.last_model = model
        debug.last_model = model
        return model
+class RaiseErrorMixin():
+
+    @staticmethod
+    def raise_error(data: dict):
+        if "error_message" in data:
+            raise ResponseError(data["error_message"])
+        elif "error" in data:
+            if "code" in data["error"]:
+                raise ResponseError(f'Error {data["error"]["code"]}: {data["error"]["message"]}')
+            elif "message" in data["error"]:
+                raise ResponseError(data["error"]["message"])
+            else:
+                raise ResponseError(data["error"])
+
+class AuthedMixin():
+
+    @classmethod
+    def on_auth(cls, **kwargs) -> Optional[AuthResult]:
+        if "api_key" not in kwargs:
+            raise MissingAuthError(f"API key is required for {cls.__name__}")
+        return None
+
+    @classmethod
+    def create_authed(
+        cls,
+        model: str,
+        messages: Messages,
+        **kwargs
+    ) -> CreateResult:
+        auth_result = {}
+        cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
+        if cache_file.exists():
+            with cache_file.open("r") as f:
+                auth_result = json.load(f)
+            return cls.create_completion(model, messages, **kwargs, **auth_result)
+        auth_result = cls.on_auth(**kwargs)
+        try:
+            return cls.create_completion(model, messages, **kwargs)
+        finally:
+            if auth_result is not None:
+                cache_file.parent.mkdir(parents=True, exist_ok=True)
+                cache_file.write_text(json.dumps(auth_result.get_dict()))
+
+class AsyncAuthedMixin(AuthedMixin):
+
+    @classmethod
+    async def create_async_authed(
+        cls,
+        model: str,
+        messages: Messages,
+        **kwargs
+    ) -> str:
+        cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
+        if cache_file.exists():
+            auth_result = {}
+            with cache_file.open("r") as f:
+                auth_result = json.load(f)
+            # Cached credentials: delegate to the async implementation.
+            return await cls.create_async(model, messages, **kwargs, **auth_result)
+        auth_result = cls.on_auth(**kwargs)
+        try:
+            return await cls.create_async(model, messages, **kwargs)
+        finally:
+            if auth_result is not None:
+                cache_file.parent.mkdir(parents=True, exist_ok=True)
+                cache_file.write_text(json.dumps(auth_result.get_dict()))
+
+class AsyncAuthedGeneratorMixin(AsyncAuthedMixin):
+
+    @classmethod
+    async def create_async_authed(
+        cls,
+        model: str,
+        messages: Messages,
+        **kwargs
+    ) -> str:
+        cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
+        if cache_file.exists():
+            auth_result = {}
+            with cache_file.open("r") as f:
+                auth_result = json.load(f)
+            # Cached credentials: consume the generator into a single string.
+            return await async_concat_chunks(cls.create_async_generator(model, messages, stream=False, **kwargs, **auth_result))
+        auth_result = cls.on_auth(**kwargs)
+        try:
+            return await async_concat_chunks(cls.create_async_generator(model, messages, stream=False, **kwargs))
+        finally:
+            if auth_result is not None:
+                cache_file.parent.mkdir(parents=True, exist_ok=True)
+                cache_file.write_text(json.dumps(auth_result.get_dict()))
+
+    @classmethod
+    def create_async_authed_generator(
+        cls,
+        model: str,
+        messages: Messages,
+        stream: bool = True,
+        **kwargs
+    ) -> Awaitable[AsyncResult]:
+        cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
+        if cache_file.exists():
+            auth_result = {}
+            with cache_file.open("r") as f:
+                auth_result = json.load(f)
+            # Cached credentials: return the async generator directly.
+            return cls.create_async_generator(model, messages, stream=stream, **kwargs, **auth_result)
+        auth_result = cls.on_auth(**kwargs)
+        try:
+            return cls.create_async_generator(model, messages, stream=stream, **kwargs)
+        finally:
+            if auth_result is not None:
+                cache_file.parent.mkdir(parents=True, exist_ok=True)
+                cache_file.write_text(json.dumps(auth_result.get_dict()))
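A minimal sketch of how a provider could opt into the new cached-auth flow; `ExampleProvider` and its key handling are hypothetical, not part of this commit:

```python
# Hypothetical provider built on the new auth mixins.
from g4f.providers.base_provider import AsyncAuthedGeneratorMixin, AsyncGeneratorProvider
from g4f.providers.response import AuthResult

class ExampleProvider(AsyncAuthedGeneratorMixin, AsyncGeneratorProvider):
    @classmethod
    def on_auth(cls, api_key: str = None, **kwargs) -> AuthResult:
        # Returning an AuthResult makes the mixin persist it under the
        # cookies dir as auth_ExampleProvider.json for later calls.
        result = AuthResult()
        result.api_key = api_key
        return result
```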


@@ -105,6 +105,10 @@ class Usage(ResponseType, JsonMixin):
    def __str__(self) -> str:
        return ""

+class AuthResult(JsonMixin):
+    def __str__(self) -> str:
+        return ""
+
class TitleGeneration(ResponseType):
    def __init__(self, title: str) -> None:
        self.title = title
@@ -182,4 +186,5 @@ class ImagePreview(ImageResponse):
        return super().__str__()

class Parameters(ResponseType, JsonMixin):
-    pass
+    def __str__(self):
+        return ""


@@ -40,7 +40,7 @@ async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls:
            )
        elif tool.get("function", {}).get("name") == "continue":
            last_line = messages[-1]["content"].strip().splitlines()[-1]
-            content = f"Continue after this line.\n{last_line}"
+            content = f"Carry on from this point:\n{last_line}"
            messages.append({"role": "user", "content": content})
        elif tool.get("function", {}).get("name") == "bucket_tool":
            def on_bucket(match):
@@ -90,7 +90,7 @@ def iter_run_tools(
        elif tool.get("function", {}).get("name") == "continue_tool":
            if provider not in ("OpenaiAccount", "HuggingFace"):
                last_line = messages[-1]["content"].strip().splitlines()[-1]
-                content = f"Continue after this line:\n{last_line}"
+                content = f"Carry on from this point:\n{last_line}"
                messages.append({"role": "user", "content": content})
            else:
                # Enable provider native continue
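To make the reworded continue prompt concrete, this is what the tool effectively appends; the truncated assistant reply is invented for the example:

```python
messages = [
    {"role": "user", "content": "Write a haiku about winter."},
    {"role": "assistant", "content": "Snow settles softly\nOn branches bending"},
]
last_line = messages[-1]["content"].strip().splitlines()[-1]
messages.append({"role": "user", "content": f"Carry on from this point:\n{last_line}"})
# The follow-up request asks the model to resume from "On branches bending".
```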


@@ -5,6 +5,7 @@ import json
import hashlib
from pathlib import Path
from urllib.parse import urlparse
+from datetime import datetime
import datetime
import asyncio
@@ -65,7 +66,7 @@ class SearchResultEntry():
    def set_text(self, text: str):
        self.text = text

-def scrape_text(html: str, max_words: int = None) -> Iterator[str]:
+def scrape_text(html: str, max_words: int = None, add_source=True) -> Iterator[str]:
    source = BeautifulSoup(html, "html.parser")
    soup = source
    for selector in [
@@ -87,7 +88,7 @@ def scrape_text(html: str, max_words: int = None) -> Iterator[str]:
        if select:
            select.extract()

-    for paragraph in soup.select("p, table, ul, h1, h2, h3, h4, h5, h6"):
+    for paragraph in soup.select("p, table:not(:has(p)), ul:not(:has(p)), h1, h2, h3, h4, h5, h6"):
        for line in paragraph.text.splitlines():
            words = [word for word in line.replace("\t", " ").split(" ") if word]
            count = len(words)
@@ -99,24 +100,25 @@ def scrape_text(html: str, max_words: int = None) -> Iterator[str]:
                break
            yield " ".join(words) + "\n"

+    if add_source:
        canonical_link = source.find("link", rel="canonical")
        if canonical_link and "href" in canonical_link.attrs:
            link = canonical_link["href"]
            domain = urlparse(link).netloc
            yield f"\nSource: [{domain}]({link})"
-async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = None) -> str:
+async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = None, add_source: bool = False) -> str:
    try:
        bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "fetch_and_scrape"
        bucket_dir.mkdir(parents=True, exist_ok=True)
        md5_hash = hashlib.md5(url.encode()).hexdigest()
-        cache_file = bucket_dir / f"{url.split('/')[3]}.{datetime.date.today()}.{md5_hash}.txt"
+        cache_file = bucket_dir / f"{url.split('?')[0].split('//')[1].replace('/', '+')[:16]}.{datetime.date.today()}.{md5_hash}.txt"
        if cache_file.exists():
            return cache_file.read_text()
        async with session.get(url) as response:
            if response.status == 200:
                html = await response.text()
-                text = "".join(scrape_text(html, max_words))
+                text = "".join(scrape_text(html, max_words, add_source))
                with open(cache_file, "w") as f:
                    f.write(text)
                return text
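The new cache-file prefix is dense, so here is how it evaluates for a concrete URL (the URL is invented); unlike the old `url.split('/')[3]`, it no longer fails on URLs with short paths:

```python
url = "https://example.com/docs/intro?ref=1"
prefix = url.split('?')[0].split('//')[1].replace('/', '+')[:16]
print(prefix)  # -> "example.com+docs"
```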
@@ -136,6 +138,8 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backen
        max_results=max_results,
        backend=backend,
    ):
+        if ".google.com" in result["href"]:
+            continue
        results.append(SearchResultEntry(
            result["title"],
            result["href"],
@@ -146,7 +150,7 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backen
    requests = []
    async with ClientSession(timeout=ClientTimeout(timeout)) as session:
        for entry in results:
-            requests.append(fetch_and_scrape(session, entry.url, int(max_words / (max_results - 1))))
+            requests.append(fetch_and_scrape(session, entry.url, int(max_words / (max_results - 1)), False))
        texts = await asyncio.gather(*requests)

    formatted_results = []
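The per-page word budget splits `max_words` across one fewer than `max_results`; with the defaults that works out to:

```python
max_words, max_results = 2500, 5
per_page = int(max_words / (max_results - 1))
print(per_page)  # -> 625 words scraped per result page
```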
@@ -173,7 +177,7 @@ async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_
        query = spacy_get_keywords(prompt)
    json_bytes = json.dumps({"query": query, **kwargs}, sort_keys=True).encode()
    md5_hash = hashlib.md5(json_bytes).hexdigest()
-    bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "web_search"
+    bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / f"web_search:{datetime.date.today()}"
    bucket_dir.mkdir(parents=True, exist_ok=True)
    cache_file = bucket_dir / f"{query[:20]}.{md5_hash}.txt"
    if cache_file.exists():
@@ -192,7 +196,9 @@ Instruction: {instructions}

User request:
{prompt}
"""
-    debug.log(f"Web search: '{query.strip()[:50]}...' {len(search_results.results)} Results {search_results.used_words} Words")
+    debug.log(f"Web search: '{query.strip()[:50]}...'")
+    if isinstance(search_results, SearchResults):
+        debug.log(f"with {len(search_results.results)} Results {search_results.used_words} Words")
    return new_prompt

def get_search_message(prompt: str, raise_search_exceptions=False, **kwargs) -> str:
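Finally, a sketch of the helper this hunk leads into; the prompt is illustrative, the module path is assumed from this file's location, and the optional search dependencies must be installed:

```python
# Assumed import path for the module shown in this diff.
from g4f.tools.web_search import get_search_message

# Runs the (now per-day cached) web search and prepends the results
# to the user prompt before it is sent to a provider.
prompt = get_search_message("What changed in the latest g4f release?")
print(prompt[:200])
```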