Merge pull request #2530 from hlohaus/cont

Add Anthropic provider
This commit is contained in:
H Lohaus 2025-01-03 02:56:06 +01:00 committed by GitHub
commit c5ba78c7e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
40 changed files with 800 additions and 161 deletions

View file

@ -70,8 +70,7 @@ def read_json(text: str) -> dict:
try:
return json.loads(text.strip())
except json.JSONDecodeError:
print("No valid json:", text)
return {}
raise RuntimeError(f"Invalid JSON: {text}")
def read_text(text: str) -> str:
"""
@ -86,7 +85,8 @@ def read_text(text: str) -> str:
match = re.search(r"```(markdown|)\n(?P<text>[\S\s]+?)\n```", text)
if match:
return match.group("text")
return text
else:
raise RuntimeError(f"Invalid markdown: {text}")
def get_ai_response(prompt: str, as_json: bool = True) -> Union[dict, str]:
"""
@ -197,6 +197,7 @@ def create_review_prompt(pull: PullRequest, diff: str):
return f"""Your task is to review a pull request. Instructions:
- Write in name of g4f copilot. Don't use placeholder.
- Write the review in GitHub Markdown format.
- Enclose your response in backticks ```response```
- Thank the author for contributing to the project.
Pull request author: {pull.user.name}

View file

@ -16,8 +16,6 @@ async def test_async(provider: ProviderType):
return False
messages = [{"role": "user", "content": "Hello Assistant!"}]
try:
if "webdriver" in provider.get_parameters():
return False
response = await asyncio.wait_for(ChatCompletion.create_async(
model=models.default,
messages=messages,

View file

@ -46,4 +46,4 @@ class TestBackendApi(unittest.TestCase):
self.skipTest(e)
except MissingRequirementsError:
self.skipTest("search is not installed")
self.assertTrue(len(result) >= 4)
self.assertGreater(len(result), 0)

View file

@ -1,36 +1,11 @@
import unittest
import asyncio
import g4f
from g4f import ChatCompletion, get_last_provider
import g4f.version
from g4f.errors import VersionNotFoundError
from g4f.Provider import RetryProvider
from .mocks import ProviderMock
DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
class TestGetLastProvider(unittest.TestCase):
def test_get_last_provider(self):
ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
self.assertEqual(get_last_provider(), ProviderMock)
def test_get_last_provider_retry(self):
ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, RetryProvider([ProviderMock]))
self.assertEqual(get_last_provider(), ProviderMock)
def test_get_last_provider_async(self):
coroutine = ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
asyncio.run(coroutine)
self.assertEqual(get_last_provider(), ProviderMock)
def test_get_last_provider_as_dict(self):
ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
last_provider_dict = get_last_provider(True)
self.assertIsInstance(last_provider_dict, dict)
self.assertIn('name', last_provider_dict)
self.assertEqual(ProviderMock.__name__, last_provider_dict['name'])
def test_get_latest_version(self):
try:
self.assertIsInstance(g4f.version.utils.current_version, str)

View file

@ -35,7 +35,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
supports_system_message = True
supports_message_history = True
default_model = "gpt-4o-mini"
default_model = "llama-3.1-70b-chat"
default_image_model = "flux"
models = []
@ -113,7 +113,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
@classmethod
def get_model(cls, model: str) -> str:
"""Get the actual model name from alias"""
return cls.model_aliases.get(model, model)
return cls.model_aliases.get(model, model or cls.default_model)
@classmethod
async def check_api_key(cls, api_key: str) -> bool:
@ -162,6 +162,9 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
"""
Filters the full response to remove system errors and other unwanted text.
"""
if "Model not found or too long input. Or any other error (xD)" in response:
raise ValueError(response)
filtered_response = re.sub(r"\[ERROR\] '\w{8}-\w{4}-\w{4}-\w{4}-\w{12}'", '', response) # any-uncensored
filtered_response = re.sub(r'<\|im_end\|>', '', filtered_response) # remove <|im_end|> token
filtered_response = re.sub(r'</s>', '', filtered_response) # neural-chat-7b-v3-1

View file

@ -246,6 +246,8 @@ class BlackboxCreateAgent(AsyncGeneratorProvider, ProviderModelMixin):
Returns:
AsyncResult: The response from the provider
"""
if not model:
model = cls.default_model
if model in cls.chat_models:
async for text in cls._generate_text(model, messages, proxy=proxy, **kwargs):
return text

View file

@ -80,4 +80,6 @@ class ChatGptEs(AsyncGeneratorProvider, ProviderModelMixin):
async with session.post(cls.api_endpoint, headers=headers, data=payload) as response:
response.raise_for_status()
result = await response.json()
if "Du musst das Kästchen anklicken!" in result['data']:
raise ValueError(result['data'])
yield result['data']

View file

@ -127,8 +127,8 @@ class Copilot(AbstractProvider, ProviderModelMixin):
response = session.post(cls.conversation_url)
raise_for_status(response)
conversation_id = response.json().get("id")
conversation = Conversation(conversation_id)
if return_conversation:
conversation = Conversation(conversation_id)
yield conversation
if prompt is None:
prompt = format_prompt_max_length(messages, 10000)

View file

@ -15,7 +15,8 @@ from .needs_auth.OpenaiAPI import OpenaiAPI
"""
class Mhystical(OpenaiAPI):
url = "https://api.mhystical.cc"
label = "Mhystical"
url = "https://mhystical.cc"
api_endpoint = "https://api.mhystical.cc/v1/completions"
working = True
needs_auth = False

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import json
import random
import requests
from urllib.parse import quote
from typing import Optional
from aiohttp import ClientSession
@ -170,7 +171,7 @@ class PollinationsAI(OpenaiAPI):
params = {k: v for k, v in params.items() if v is not None}
async with ClientSession(headers=headers) as session:
prompt = quote(messages[-1]["content"])
prompt = quote(messages[-1]["content"] if prompt is None else prompt)
param_string = "&".join(f"{k}={v}" for k, v in params.items())
url = f"{cls.image_api_endpoint}/prompt/{prompt}?{param_string}"

View file

@ -0,0 +1,195 @@
from __future__ import annotations
import requests
import json
import base64
from typing import Optional
from ..helper import filter_none
from ...typing import AsyncResult, Messages, ImagesType
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ToolCalls, Usage
from ...errors import MissingAuthError
from ...image import to_bytes, is_accepted_format
from .OpenaiAPI import OpenaiAPI
class Anthropic(OpenaiAPI):
label = "Anthropic API"
url = "https://console.anthropic.com"
login_url = "https://console.anthropic.com/settings/keys"
working = True
api_base = "https://api.anthropic.com/v1"
needs_auth = True
supports_stream = True
supports_system_message = True
supports_message_history = True
default_model = "claude-3-5-sonnet-latest"
models = [
default_model,
"claude-3-5-sonnet-20241022",
"claude-3-5-haiku-latest",
"claude-3-5-haiku-20241022",
"claude-3-opus-latest",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307"
]
models_aliases = {
"claude-3.5-sonnet": default_model,
"claude-3-opus": "claude-3-opus-latest",
"claude-3-sonnet": "claude-3-sonnet-20240229",
"claude-3-haiku": "claude-3-haiku-20240307",
}
@classmethod
def get_models(cls, api_key: str = None, **kwargs):
if not cls.models:
url = f"https://api.anthropic.com/v1/models"
response = requests.get(url, headers={
"Content-Type": "application/json",
"x-api-key": api_key,
"anthropic-version": "2023-06-01"
})
raise_for_status(response)
models = response.json()
cls.models = [model["id"] for model in models["data"]]
return cls.models
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
proxy: str = None,
timeout: int = 120,
images: ImagesType = None,
api_key: str = None,
temperature: float = None,
max_tokens: int = 4096,
top_k: int = None,
top_p: float = None,
stop: list[str] = None,
stream: bool = False,
headers: dict = None,
impersonate: str = None,
tools: Optional[list] = None,
extra_data: dict = {},
**kwargs
) -> AsyncResult:
if api_key is None:
raise MissingAuthError('Add a "api_key"')
if images is not None:
insert_images = []
for image, _ in images:
data = to_bytes(image)
insert_images.append({
"type": "image",
"source": {
"type": "base64",
"media_type": is_accepted_format(data),
"data": base64.b64encode(data).decode(),
}
})
messages[-1]["content"] = [
*insert_images,
{
"type": "text",
"text": messages[-1]["content"]
}
]
system = "\n".join([message for message in messages if message.get("role") == "system"])
if system:
messages = [message for message in messages if message.get("role") != "system"]
else:
system = None
async with StreamSession(
proxy=proxy,
headers=cls.get_headers(stream, api_key, headers),
timeout=timeout,
impersonate=impersonate,
) as session:
data = filter_none(
messages=messages,
model=cls.get_model(model, api_key=api_key),
temperature=temperature,
max_tokens=max_tokens,
top_k=top_k,
top_p=top_p,
stop_sequences=stop,
system=system,
stream=stream,
tools=tools,
**extra_data
)
async with session.post(f"{cls.api_base}/messages", json=data) as response:
await raise_for_status(response)
if not stream:
data = await response.json()
cls.raise_error(data)
if "type" in data and data["type"] == "message":
for content in data["content"]:
if content["type"] == "text":
yield content["text"]
elif content["type"] == "tool_use":
tool_calls.append({
"id": content["id"],
"type": "function",
"function": { "name": content["name"], "arguments": content["input"] }
})
if data["stop_reason"] == "end_turn":
yield FinishReason("stop")
elif data["stop_reason"] == "max_tokens":
yield FinishReason("length")
yield Usage(**data["usage"])
else:
content_block = None
partial_json = []
tool_calls = []
async for line in response.iter_lines():
if line.startswith(b"data: "):
chunk = line[6:]
if chunk == b"[DONE]":
break
data = json.loads(chunk)
cls.raise_error(data)
if "type" in data:
if data["type"] == "content_block_start":
content_block = data["content_block"]
if content_block is None:
pass # Message start
elif data["type"] == "content_block_delta":
if content_block["type"] == "text":
yield data["delta"]["text"]
elif content_block["type"] == "tool_use":
partial_json.append(data["delta"]["partial_json"])
elif data["type"] == "message_delta":
if data["delta"]["stop_reason"] == "end_turn":
yield FinishReason("stop")
elif data["delta"]["stop_reason"] == "max_tokens":
yield FinishReason("length")
yield Usage(**data["usage"])
elif data["type"] == "content_block_stop":
if content_block["type"] == "tool_use":
tool_calls.append({
"id": content_block["id"],
"type": "function",
"function": { "name": content_block["name"], "arguments": partial_json.join("") }
})
partial_json = []
if tool_calls:
yield ToolCalls(tool_calls)
@classmethod
def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
return {
"Accept": "text/event-stream" if stream else "application/json",
"Content-Type": "application/json",
**(
{"x-api-key": api_key}
if api_key is not None else {}
),
"anthropic-version": "2023-06-01",
**({} if headers is None else headers)
}

View file

@ -10,6 +10,7 @@ from ...cookies import get_cookies
class Cerebras(OpenaiAPI):
label = "Cerebras Inference"
url = "https://inference.cerebras.ai/"
login_url = "https://cloud.cerebras.ai"
api_base = "https://api.cerebras.ai/v1"
working = True
default_model = "llama3.1-70b"

View file

@ -16,6 +16,7 @@ from ... import debug
class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
label = "Google Gemini API"
url = "https://ai.google.dev"
login_url = "https://aistudio.google.com/u/0/apikey"
api_base = "https://generativelanguage.googleapis.com/v1beta"
working = True
@ -24,7 +25,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
default_model = "gemini-1.5-pro"
default_vision_model = default_model
fallback_models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
fallback_models = [default_model, "gemini-2.0-flash-exp", "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
model_aliases = {
"gemini-flash": "gemini-1.5-flash",
"gemini-flash": "gemini-1.5-flash-8b",

View file

@ -5,6 +5,7 @@ from .OpenaiAPI import OpenaiAPI
class Groq(OpenaiAPI):
label = "Groq"
url = "https://console.groq.com/playground"
login_url = "https://console.groq.com/keys"
api_base = "https://api.groq.com/openai/v1"
working = True
default_model = "mixtral-8x7b-32768"

View file

@ -47,8 +47,8 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
proxy: str = None,
api_base: str = "https://api-inference.huggingface.co",
api_key: str = None,
max_new_tokens: int = 1024,
temperature: float = 0.7,
max_tokens: int = 1024,
temperature: float = None,
prompt: str = None,
action: str = None,
extra_data: dict = {},
@ -84,7 +84,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
else:
params = {
"return_full_text": False,
"max_new_tokens": max_new_tokens,
"max_new_tokens": max_tokens,
"temperature": temperature,
**extra_data
}

View file

@ -6,6 +6,7 @@ from .HuggingChat import HuggingChat
class HuggingFaceAPI(OpenaiAPI):
label = "HuggingFace (Inference API)"
url = "https://api-inference.huggingface.co"
login_url = "https://huggingface.co/settings/tokens"
api_base = "https://api-inference.huggingface.co/v1"
working = True
default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"

View file

@ -4,7 +4,7 @@ import json
import requests
from ..helper import filter_none
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
from ...typing import Union, Optional, AsyncResult, Messages, ImagesType
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ToolCalls, Usage
@ -12,9 +12,10 @@ from ...errors import MissingAuthError, ResponseError
from ...image import to_data_uri
from ... import debug
class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
label = "OpenAI API"
url = "https://platform.openai.com"
login_url = "https://platform.openai.com/settings/organization/api-keys"
api_base = "https://api.openai.com/v1"
working = True
needs_auth = True
@ -141,18 +142,6 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
if "finish_reason" in choice and choice["finish_reason"] is not None:
return FinishReason(choice["finish_reason"])
@staticmethod
def raise_error(data: dict):
if "error_message" in data:
raise ResponseError(data["error_message"])
elif "error" in data:
if "code" in data["error"]:
raise ResponseError(f'Error {data["error"]["code"]}: {data["error"]["message"]}')
elif "message" in data["error"]:
raise ResponseError(data["error"]["message"])
else:
raise ResponseError(data["error"])
@classmethod
def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
return {

View file

@ -5,6 +5,7 @@ from .OpenaiAPI import OpenaiAPI
class PerplexityApi(OpenaiAPI):
label = "Perplexity API"
url = "https://www.perplexity.ai"
login_url = "https://www.perplexity.ai/settings/api"
working = True
api_base = "https://api.perplexity.ai"
default_model = "llama-3-sonar-large-32k-online"

View file

@ -9,6 +9,7 @@ from ...errors import ResponseError, MissingAuthError
class Replicate(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://replicate.com"
login_url = "https://replicate.com/account/api-tokens"
working = True
needs_auth = True
default_model = "meta/meta-llama-3-70b-instruct"
@ -25,7 +26,7 @@ class Replicate(AsyncGeneratorProvider, ProviderModelMixin):
proxy: str = None,
timeout: int = 180,
system_prompt: str = None,
max_new_tokens: int = None,
max_tokens: int = None,
temperature: float = None,
top_p: float = None,
top_k: float = None,
@ -55,7 +56,7 @@ class Replicate(AsyncGeneratorProvider, ProviderModelMixin):
"prompt": format_prompt(messages),
**filter_none(
system_prompt=system_prompt,
max_new_tokens=max_new_tokens,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,

View file

@ -1,5 +1,6 @@
from .gigachat import *
from .Anthropic import Anthropic
from .BingCreateImages import BingCreateImages
from .Cerebras import Cerebras
from .CopilotAccount import CopilotAccount

View file

@ -5,6 +5,7 @@ from .OpenaiAPI import OpenaiAPI
class glhfChat(OpenaiAPI):
label = "glhf.chat"
url = "https://glhf.chat"
login_url = "https://glhf.chat/users/settings/api"
api_base = "https://glhf.chat/api/openai/v1"
working = True
model_aliases = {

View file

@ -5,5 +5,6 @@ from .OpenaiAPI import OpenaiAPI
class xAI(OpenaiAPI):
label = "xAI"
url = "https://console.x.ai"
login_url = "https://console.x.ai"
api_base = "https://api.x.ai/v1"
working = True

View file

@ -12,7 +12,7 @@ from .errors import StreamNotSupportedError
from .cookies import get_cookies, set_cookies
from .providers.types import ProviderType
from .providers.helper import concat_chunks
from .client.service import get_model_and_provider, get_last_provider
from .client.service import get_model_and_provider
#Configure "g4f" logger
logger = logging.getLogger(__name__)
@ -47,7 +47,8 @@ class ChatCompletion:
if ignore_stream:
kwargs["ignore_stream"] = True
result = provider.create_completion(model, messages, stream=stream, **kwargs)
create_method = provider.create_authed if hasattr(provider, "create_authed") else provider.create_completion
result = create_method(model, messages, stream=stream, **kwargs)
return result if stream else concat_chunks(result)
@ -72,7 +73,9 @@ class ChatCompletion:
kwargs["ignore_stream"] = True
if stream:
if hasattr(provider, "create_async_generator"):
if hasattr(provider, "create_async_authed_generator"):
return provider.create_async_authed_generator(model, messages, **kwargs)
elif hasattr(provider, "create_async_generator"):
return provider.create_async_generator(model, messages, **kwargs)
raise StreamNotSupportedError(f'{provider.__name__} does not support "stream" argument in "create_async"')

View file

@ -197,11 +197,10 @@ class Api:
)
def register_routes(self):
@self.app.get("/")
async def read_root():
if AppConfig.gui:
return RedirectResponse("/chat/", 302)
return RedirectResponse("/v1", 302)
if not AppConfig.gui:
@self.app.get("/")
async def read_root():
return RedirectResponse("/v1", 302)
@self.app.get("/v1")
async def read_root_v1():

View file

@ -7,9 +7,9 @@ import string
import asyncio
import aiohttp
import base64
from typing import Union, AsyncIterator, Iterator, Coroutine, Optional
from typing import Union, AsyncIterator, Iterator, Awaitable, Optional
from ..image import ImageResponse, copy_images, images_dir
from ..image import ImageResponse, copy_images
from ..typing import Messages, ImageType
from ..providers.types import ProviderType, BaseRetryProvider
from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData, ToolCalls, Usage
@ -22,7 +22,7 @@ from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .image_models import ImageModels
from .types import IterResponse, ImageProvider, Client as BaseClient
from .service import get_model_and_provider, convert_to_provider
from .helper import find_stop, filter_json, filter_none, safe_aclose
from .helper import find_stop, filter_json, filter_none, safe_aclose, to_async_iterator
from .. import debug
ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
@ -220,7 +220,7 @@ class Completions:
ignore_working: Optional[bool] = False,
ignore_stream: Optional[bool] = False,
**kwargs
) -> IterResponse:
) -> ChatCompletion:
model, provider = get_model_and_provider(
model,
self.provider if provider is None else provider,
@ -236,7 +236,7 @@ class Completions:
kwargs["ignore_stream"] = True
response = iter_run_tools(
provider.create_completion,
provider.create_authed if hasattr(provider, "create_authed") else provider.create_completion,
model,
messages,
stream=stream,
@ -248,9 +248,6 @@ class Completions:
),
**kwargs
)
if asyncio.iscoroutinefunction(provider.create_completion):
# Run the asynchronous function in an event loop
response = asyncio.run(response)
if stream and hasattr(response, '__aiter__'):
# It's an async generator, wrap it into a sync iterator
response = to_sync_generator(response)
@ -264,6 +261,14 @@ class Completions:
else:
return next(response)
def stream(
self,
messages: Messages,
model: str,
**kwargs
) -> IterResponse:
return self.create(messages, model, stream=True, **kwargs)
class Chat:
completions: Completions
@ -507,7 +512,7 @@ class AsyncCompletions:
ignore_working: Optional[bool] = False,
ignore_stream: Optional[bool] = False,
**kwargs
) -> Union[Coroutine[ChatCompletion], AsyncIterator[ChatCompletionChunk, BaseConversation]]:
) -> Awaitable[ChatCompletion]:
model, provider = get_model_and_provider(
model,
self.provider if provider is None else provider,
@ -521,6 +526,8 @@ class AsyncCompletions:
kwargs["images"] = [(image, image_name)]
if ignore_stream:
kwargs["ignore_stream"] = True
if hasattr(provider, "create_async_authed_generator"):
create_handler = provider.create_async_authed_generator
if hasattr(provider, "create_async_generator"):
create_handler = provider.create_async_generator
else:
@ -538,10 +545,20 @@ class AsyncCompletions:
),
**kwargs
)
if not hasattr(response, '__aiter__'):
response = to_async_iterator(response)
response = async_iter_response(response, stream, response_format, max_tokens, stop)
response = async_iter_append_model_and_provider(response, model, provider)
return response if stream else anext(response)
def stream(
self,
messages: Messages,
model: str,
**kwargs
) -> AsyncIterator[ChatCompletionChunk, BaseConversation]:
return self.create(messages, model, stream=True, **kwargs)
class AsyncImages(Images):
def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
self.client: AsyncClient = client

View file

@ -5,6 +5,22 @@ import logging
from typing import AsyncIterator, Iterator, AsyncGenerator, Optional
def filter_markdown(text: str, allowd_types=None, default=None) -> str:
"""
Parses code block from a string.
Args:
text (str): A string containing a code block.
Returns:
dict: A dictionary parsed from the code block.
"""
match = re.search(r"```(.+)\n(?P<code>[\S\s]+?)(\n```|$)", text)
if match:
if allowd_types is None or match.group(1) in allowd_types:
return match.group("code")
return default
def filter_json(text: str) -> str:
"""
Parses JSON code block from a string.
@ -15,10 +31,7 @@ def filter_json(text: str) -> str:
Returns:
dict: A dictionary parsed from the JSON code block.
"""
match = re.search(r"```(json|)\n(?P<code>[\S\s]+?)\n```", text)
if match:
return match.group("code")
return text
return filter_markdown(text, ["", "json"], text)
def find_stop(stop: Optional[list[str]], content: str, chunk: str = None):
first = -1

230
g4f/gui/client/home.html Normal file
View file

@ -0,0 +1,230 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>G4F GUI</title>
<style>
:root {
--colour-1: #000000;
--colour-2: #ccc;
--colour-3: #e4d4ff;
--colour-4: #f0f0f0;
--colour-5: #181818;
--colour-6: #242424;
--accent: #8b3dff;
--gradient: #1a1a1a;
--background: #16101b;
--size: 70vw;
--top: 50%;
--blur: 40px;
--opacity: 0.6;
}
@import url("https://fonts.googleapis.com/css2?family=Inter:wght@100;200;300;400;500;600;700;800;900&display=swap");
.gradient {
position: absolute;
z-index: -1;
left: 50vw;
border-radius: 50%;
background: radial-gradient(circle at center, var(--accent), var(--gradient));
width: var(--size);
height: var(--size);
top: var(--top);
transform: translate(-50%, -50%);
filter: blur(var(--blur)) opacity(var(--opacity));
animation: zoom_gradient 6s infinite alternate;
display: none;
max-height: 100%;
transition: max-height 0.25s ease-in;
}
.gradient.hidden {
max-height: 0;
transition: max-height 0.15s ease-out;
}
@media only screen and (min-width: 40em) {
body .gradient{
display: block;
}
}
@keyframes zoom_gradient {
0% {
transform: translate(-50%, -50%) scale(1);
}
100% {
transform: translate(-50%, -50%) scale(1.2);
}
}
/* Body and text color */
body {
background: var(--background);
color: var(--colour-3);
font-family: "Inter", sans-serif;
height: 100vh;
margin: 0;
padding: 0;
overflow: hidden;
font-weight: bold;
}
/* Container for the main content */
.container {
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
height: 100%;
text-align: center;
z-index: 1;
}
header {
font-size: 3rem;
text-transform: uppercase;
margin: 20px;
color: var(--colour-4);
}
iframe {
background: transparent;
width: 100%;
border: none;
}
#background {
height: 100%;
position: absolute;
z-index: -1;
}
iframe.stream {
max-height: 0;
transition: max-height 0.15s ease-out;
}
iframe.stream.show {
max-height: 1000px;
height: 1000px;
transition: max-height 0.25s ease-in;
background: rgba(255,255,255,0.7);
border-top: 2px solid rgba(255,255,255,0.5);
}
.description {
font-size: 1.2rem;
margin-bottom: 30px;
color: var(--colour-2);
} return app
.input-field {
width: 80%;
max-width: 400px;
padding: 12px;
margin: 10px 0;
border: 2px solid var(--colour-6);
background-color: var(--colour-5);
color: var(--colour-3);
border-radius: 8px;
font-size: 1.1rem;
}
.input-field:focus {
outline: none;
border-color: var(--accent);
}
.button {
background-color: var(--accent);
color: var(--colour-3);
border: none;
padding: 15px 30px;
font-size: 1.1rem;
border-radius: 8px;
cursor: pointer;
transition: background-color 0.3s ease;
margin-top: 15px;
width: 100%;
max-width: 400px;
font-weight: bold;
}
.button:hover {
background-color: #7a2ccd;
}
.footer {
margin-top: 30px;
font-size: 0.9rem;
color: var(--colour-2);
}
/* Animation for the gradient circle */
@keyframes zoom_gradient {
0% {
transform: translate(-50%, -50%) scale(1);
}
100% {
transform: translate(-50%, -50%) scale(1.5);
}
}
</style>
</head>
<body>
<iframe id="background"></iframe>
<!-- Gradient Background Circle -->
<div class="gradient"></div>
<!-- Main Content -->
<div class="container">
<header>
G4F GUI
</header>
<div class="description">
Welcome to the G4F GUI! <br>
Your AI assistant is ready to assist you.
</div>
<!-- Input and Button -->
<form action="/chat/">
<!--
<input type="text" name="prompt" class="input-field" placeholder="Enter your query...">
-->
<button class="button">Open Chat</button>
</form>
<!-- Footer -->
<div class="footer">
<p>&copy; 2025 G4F. All Rights Reserved.</p>
<p>Powered by the G4F framework</p>
</div>
<iframe id="stream-widget" class="stream" data-src="/backend-api/v2/create?prompt=Create of overview of the news in plain text&stream=1&web_search=news in " class="" frameborder="0"></iframe>
</div>
<script>
const iframe = document.getElementById('stream-widget');
iframe.src = iframe.dataset.src + navigator.language;
setTimeout(()=>iframe.classList.add('show'), 5000);
(async () => {
const prompt = `
Today is ${new Date().toJSON().slice(0, 10)}.
Create a single-page HTML screensaver reflecting the current season (based on the date).
For example, if it's Spring, it might use floral patterns or pastel colors.
Avoid using any text. Consider a subtle animation or transition effect.`;
const response = await fetch(`/backend-api/v2/create?prompt=${prompt}&filter_markdown=html`)
const text = await response.text()
background.src = `data:text/html;charset=utf-8,${encodeURIComponent(text)}`;
const gradient = document.querySelector('.gradient');
gradient.classList.add('hidden');
})();
</script>
</body>
</html>

View file

@ -157,6 +157,10 @@
<label for="Cerebras-api_key" class="label" title="">Cerebras Inference:</label>
<textarea id="Cerebras-api_key" name="Cerebras[api_key]" placeholder="api_key"></textarea>
</div>
<div class="field box hidden">
<label for="Anthropic-api_key" class="label" title="">Anthropic API:</label>
<textarea id="Anthropic-api_key" name="Anthropic[api_key]" placeholder="api_key"></textarea>
</div>
<div class="field box hidden">
<label for="DeepInfra-api_key" class="label" title="">DeepInfra:</label>
<textarea id="DeepInfra-api_key" name="DeepInfra[api_key]" class="DeepInfraImage-api_key" placeholder="api_key"></textarea>

View file

@ -239,7 +239,7 @@ body:not(.white) a:visited{
max-width: 210px;
margin-left: auto;
margin-right: 8px;
margin-bottom: 8px;
margin-top: 12px;
}
.convo-title {
@ -530,6 +530,7 @@ body:not(.white) a:visited{
z-index: 100000;
top: 0;
right: 0;
animation: show_popup 0.4s;
}
.stop_generating button, .toolbar .regenerate button, button.regenerate_button, button.continue_button, button.options_button {
@ -545,7 +546,6 @@ body:not(.white) a:visited{
align-items: center;
gap: 12px;
cursor: pointer;
animation: show_popup 0.4s;
height: 28px;
}
@ -712,6 +712,8 @@ form label.toogle {
position: relative;
overflow: hidden;
transition: 0.33s;
min-width: 60px;
margin-left: 0;
}
.buttons label:after,
@ -1280,12 +1282,18 @@ ul {
border: 1px solid #e4d4ffc9;
}
.settings textarea, form textarea {
form textarea {
height: 20px;
min-height: 20px;
padding: 0;
}
.settings textarea {
height: 30px;
min-height: 30px;
padding: 6px;
}
form .field .fa-xmark {
line-height: 20px;
cursor: pointer;
@ -1346,6 +1354,7 @@ form .field.saved .fa-xmark {
.settings .label, form .label, .settings label, form label {
font-size: 15px;
margin-left: var(--inner-gap);
min-width: 120px;
}
.settings .label, form .label {

View file

@ -137,14 +137,13 @@ class HtmlRenderPlugin {
}
}
}
if (window.hljs) {
hljs.addPlugin(new HtmlRenderPlugin())
hljs.addPlugin(new CopyButtonPlugin());
}
let typesetPromise = Promise.resolve();
const highlight = (container) => {
if (window.hljs) {
hljs.addPlugin(new HtmlRenderPlugin())
if (window.CopyButtonPlugin) {
hljs.addPlugin(new CopyButtonPlugin());
}
container.querySelectorAll('code:not(.hljs').forEach((el) => {
if (el.className != "hljs") {
hljs.highlightElement(el);
@ -518,6 +517,7 @@ const prepare_messages = (messages, message_index = -1, do_continue = false) =>
delete new_message.synthesize;
delete new_message.finish;
delete new_message.conversation;
delete new_message.continue;
// Append message to new messages
new_messages.push(new_message)
}
@ -571,7 +571,7 @@ async function load_provider_parameters(provider) {
field_el = document.createElement("div");
field_el.classList.add("field");
field_el.classList.add("box");
if (typeof value == "object") {
if (typeof value == "object" && value != null) {
value = JSON.stringify(value, null, 4);
}
if (saved_value) {
@ -580,14 +580,15 @@ async function load_provider_parameters(provider) {
saved_value = value;
}
let placeholder;
if (key in ["api_key", "proof_token"]) {
placeholder = value.length >= 22 ? (value.substring(0, 10) + "*".repeat(8) + value.substring(value.length-10)) : value;
if (["api_key", "proof_token"].includes(key)) {
placeholder = saved_value && saved_value.length >= 22 ? (saved_value.substring(0, 12) + "*".repeat(12) + saved_value.substring(saved_value.length-12)) : value;
} else {
placeholder = value;
placeholder = value == null ? "null" : value;
}
field_el.innerHTML = `<label for="${el_id}" title="">${key}:</label>`;
if (Number.isInteger(value)) {
field_el.innerHTML += `<input type="range" id="${el_id}" name="${provider}[${key}]" value="${escapeHtml(value)}" class="slider" min="0" max="4096" step="1"/><output>${escapeHtml(value)}</output>`;
if (Number.isInteger(value) && value != 1) {
max = value >= 4096 ? 8192 : 4096;
field_el.innerHTML += `<input type="range" id="${el_id}" name="${provider}[${key}]" value="${escapeHtml(value)}" class="slider" min="0" max="${max}" step="1"/><output>${escapeHtml(value)}</output>`;
field_el.innerHTML += `<i class="fa-solid fa-xmark"></i>`;
} else if (typeof value == "number") {
field_el.innerHTML += `<input type="range" id="${el_id}" name="${provider}[${key}]" value="${escapeHtml(value)}" class="slider" min="0" max="2" step="0.1"/><output>${escapeHtml(value)}</output>`;
@ -596,10 +597,12 @@ async function load_provider_parameters(provider) {
field_el.innerHTML += `<textarea id="${el_id}" name="${provider}[${key}]"></textarea>`;
field_el.innerHTML += `<i class="fa-solid fa-xmark"></i>`;
input_el = field_el.querySelector("textarea");
input_el.dataset.text = value;
if (value != null) {
input_el.dataset.text = value;
}
input_el.placeholder = placeholder;
if (!key in ["api_key", "proof_token"]) {
input_el.innerHTML = saved_value;
if (!["api_key", "proof_token"].includes(key)) {
input_el.value = saved_value;
} else {
input_el.dataset.saved_value = saved_value;
}
@ -610,14 +613,16 @@ async function load_provider_parameters(provider) {
};
input_el.onfocus = () => {
if (input_el.dataset.saved_value) {
input_el.innerHTML = input_el.dataset.saved_value;
input_el.value = input_el.dataset.saved_value;
} else if (["api_key", "proof_token"].includes(key)) {
input_el.value = input_el.dataset.text;
}
input_el.style.removeProperty("height");
input_el.style.height = (input_el.scrollHeight) + "px";
}
input_el.onblur = () => {
input_el.style.removeProperty("height");
if (key in ["api_key", "proof_token"]) {
if (["api_key", "proof_token"].includes(key)) {
input_el.value = "";
}
}
@ -642,13 +647,14 @@ async function load_provider_parameters(provider) {
input_el.value = input_el.dataset.value;
input_el.nextElementSibling.value = input_el.dataset.value;
} else if (input_el.dataset.text) {
input_el.innerHTML = input_el.dataset.text;
input_el.value = input_el.dataset.text;
}
delete input_el.dataset.saved_value;
appStorage.removeItem(el_id);
field_el.classList.remove("saved");
}
});
provider_forms.prepend(form_el);
provider_forms.appendChild(form_el);
}
}
@ -1004,8 +1010,9 @@ const load_conversation = async (conversation_id, scroll=true) => {
let lines = buffer.trim().split("\n");
let lastLine = lines[lines.length - 1];
let newContent = item.content;
if (newContent.startsWith("```\n")) {
newContent = item.content.substring(4);
if (newContent.startsWith("```")) {
const index = str.indexOf("\n");
newContent = newContent.substring(index);
}
if (newContent.startsWith(lastLine)) {
newContent = newContent.substring(lastLine.length);
@ -1073,7 +1080,7 @@ const load_conversation = async (conversation_id, scroll=true) => {
reason = "stop"
}
}
if (reason == "max_tokens" || reason == "error") {
if (reason == "length" || reason == "max_tokens" || reason == "error") {
actions.push("continue")
}
}
@ -1405,22 +1412,23 @@ function open_settings() {
const register_settings_storage = async () => {
const optionElements = document.querySelectorAll(optionElementsSelector);
optionElements.forEach((element) => {
element.name = element.name || element.id;
if (element.type == "textarea") {
element.addEventListener('input', async (event) => {
appStorage.setItem(element.id, element.value);
appStorage.setItem(element.name, element.value);
});
} else {
element.addEventListener('change', async (event) => {
switch (element.type) {
case "checkbox":
appStorage.setItem(element.id, element.checked);
appStorage.setItem(element.name, element.checked);
break;
case "select-one":
appStorage.setItem(element.id, element.selectedIndex);
appStorage.setItem(element.name, element.value);
break;
case "text":
case "number":
appStorage.setItem(element.id, element.value);
appStorage.setItem(element.name, element.value);
break;
default:
console.warn("Unresolved element type");
@ -1433,7 +1441,8 @@ const register_settings_storage = async () => {
const load_settings_storage = async () => {
const optionElements = document.querySelectorAll(optionElementsSelector);
optionElements.forEach((element) => {
if (!(value = appStorage.getItem(element.id))) {
element.name = element.name || element.id;
if (!(value = appStorage.getItem(element.name))) {
return;
}
if (value) {
@ -1442,7 +1451,7 @@ const load_settings_storage = async () => {
element.checked = value === "true";
break;
case "select-one":
element.selectedIndex = parseInt(value);
element.value = value;
break;
case "text":
case "number":
@ -1683,7 +1692,7 @@ async function on_api() {
console.error(e)
// Redirect to show basic authenfication
if (document.location.pathname == "/chat/") {
document.location.href = `/chat/error`;
//document.location.href = `/chat/error`;
}
}
register_settings_storage();
@ -1753,7 +1762,7 @@ async function load_version() {
new_version = document.createElement("div");
new_version.classList.add("new_version");
const link = `<a href="${release_url}" target="_blank" title="${title}">v${versions["latest_version"]}</a>`;
new_version.innerHTML = `g4f ${link}&nbsp;&nbsp;🆕`;
new_version.innerHTML = `G4F ${link}&nbsp;&nbsp;🆕`;
new_version.addEventListener("click", ()=>new_version.parentElement.removeChild(new_version));
document.body.appendChild(new_version);
} else {
@ -1951,7 +1960,10 @@ async function api(ressource, args=null, files=null, message_id=null) {
return read_response(response, message_id, args.provider || null);
}
response = await fetch(url, {headers: headers});
return await response.json();
if (response.status == 200) {
return await response.json();
}
console.error(response);
}
async function read_response(response, message_id, provider) {
@ -1987,19 +1999,19 @@ function get_api_key_by_provider(provider) {
return api_key;
}
async function load_provider_models(providerIndex=null) {
if (!providerIndex) {
providerIndex = providerSelect.selectedIndex;
async function load_provider_models(provider=null) {
if (!provider) {
provider = providerSelect.value;
}
modelProvider.innerHTML = '';
const provider = providerSelect.options[providerIndex].value;
modelProvider.name = `model[${provider}]`;
if (!provider) {
modelProvider.classList.add("hidden");
modelSelect.classList.remove("hidden");
return;
}
const models = await api('models', provider);
if (models.length > 0) {
if (models && models.length > 0) {
modelSelect.classList.add("hidden");
modelProvider.classList.remove("hidden");
models.forEach((model) => {
@ -2010,6 +2022,10 @@ async function load_provider_models(providerIndex=null) {
option.selected = model.default;
modelProvider.appendChild(option);
});
let value = appStorage.getItem(modelProvider.name);
if (value) {
modelProvider.value = value;
}
} else {
modelProvider.classList.add("hidden");
modelSelect.classList.remove("hidden");

View file

@ -7,7 +7,7 @@ class CopyButtonPlugin {
el,
text
}) {
if (el.classList.contains("language-plaintext")) {
if (el.parentElement.tagName != "PRE") {
return;
}
let button = Object.assign(document.createElement("button"), {

View file

@ -64,7 +64,6 @@ class Api:
"parent": getattr(provider, "parent", None),
"image": getattr(provider, "image_models", None) is not None,
"vision": getattr(provider, "default_vision_model", None) is not None,
"webdriver": "webdriver" in provider.get_parameters(),
"auth": provider.needs_auth,
} for provider in __providers__ if provider.working]

View file

@ -6,4 +6,6 @@ def create_app() -> Flask:
template_folder = os.path.join(sys._MEIPASS, "client")
else:
template_folder = "../client"
return Flask(__name__, template_folder=template_folder, static_folder=f"{template_folder}/static")
app = Flask(__name__, template_folder=template_folder, static_folder=f"{template_folder}/static")
app.config["TEMPLATES_AUTO_RELOAD"] = True # Enable auto reload in debug mode
return app

View file

@ -14,9 +14,12 @@ from werkzeug.utils import secure_filename
from ...image import is_allowed_extension, to_image
from ...client.service import convert_to_provider
from ...providers.asyncio import to_sync_generator
from ...client.helper import filter_markdown
from ...tools.files import supports_filename, get_streaming, get_bucket_dir, get_buckets
from ...tools.run_tools import iter_run_tools
from ...errors import ProviderNotFoundError
from ...cookies import get_cookies_dir
from ... import ChatCompletion
from .api import Api
logger = logging.getLogger(__name__)
@ -101,6 +104,44 @@ class Backend_Api(Api):
}
}
@app.route('/backend-api/v2/create', methods=['GET', 'POST'])
def create():
try:
tool_calls = [{
"function": {
"name": "bucket_tool"
},
"type": "function"
}]
web_search = request.args.get("web_search")
if web_search:
tool_calls.append({
"function": {
"name": "search_tool",
"arguments": {"query": web_search, "instructions": ""} if web_search != "true" else {}
},
"type": "function"
})
do_filter_markdown = request.args.get("filter_markdown")
response = iter_run_tools(
ChatCompletion.create,
model=request.args.get("model"),
messages=[{"role": "user", "content": request.args.get("prompt")}],
provider=request.args.get("provider", None),
stream=not do_filter_markdown,
ignore_stream=not request.args.get("stream"),
tool_calls=tool_calls,
)
if do_filter_markdown:
return Response(filter_markdown(response, do_filter_markdown), mimetype='text/plain')
def cast_str():
for chunk in response:
yield str(chunk)
return Response(cast_str(), mimetype='text/plain')
except Exception as e:
logger.exception(e)
return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 500
@app.route('/backend-api/v2/buckets', methods=['GET'])
def list_buckets():
try:

View file

@ -9,7 +9,7 @@ class Website:
self.app = app
self.routes = {
'/': {
'function': redirect_home,
'function': self._home,
'methods': ['GET', 'POST']
},
'/chat/': {
@ -40,4 +40,7 @@ class Website:
return render_template('index.html', chat_id=conversation_id)
def _index(self):
return render_template('index.html', chat_id=str(uuid.uuid4()))
return render_template('index.html', chat_id=str(uuid.uuid4()))
def _home(self):
return render_template('home.html')

View file

@ -72,7 +72,7 @@ default = Model(
Blackbox,
Copilot,
DeepInfraChat,
Airforce,
Airforce,
Cloudflare,
PollinationsAI,
ChatGptEs,
@ -98,26 +98,26 @@ gpt_35_turbo = Model(
gpt_4 = Model(
name = 'gpt-4',
base_provider = 'OpenAI',
best_provider = IterListProvider([DDG, Blackbox, ChatGptEs, PollinationsAI, Copilot, OpenaiChat, Liaobots, Airforce, Mhystical])
best_provider = IterListProvider([DDG, Blackbox, ChatGptEs, PollinationsAI, Copilot, OpenaiChat, Liaobots, Mhystical])
)
gpt_4_turbo = Model(
name = 'gpt-4-turbo',
base_provider = 'OpenAI',
best_provider = Airforce
best_provider = None
)
# gpt-4o
gpt_4o = Model(
name = 'gpt-4o',
base_provider = 'OpenAI',
best_provider = IterListProvider([Blackbox, ChatGptEs, PollinationsAI, DarkAI, ChatGpt, Airforce, Liaobots, OpenaiChat])
best_provider = IterListProvider([Blackbox, ChatGptEs, PollinationsAI, DarkAI, ChatGpt, Liaobots, OpenaiChat])
)
gpt_4o_mini = Model(
name = 'gpt-4o-mini',
base_provider = 'OpenAI',
best_provider = IterListProvider([DDG, ChatGptEs, Pizzagpt, ChatGpt, Airforce, RubiksAI, Liaobots, OpenaiChat])
best_provider = IterListProvider([DDG, ChatGptEs, Pizzagpt, ChatGpt, RubiksAI, Liaobots, OpenaiChat])
)
# o1
@ -136,7 +136,7 @@ o1_preview = Model(
o1_mini = Model(
name = 'o1-mini',
base_provider = 'OpenAI',
best_provider = IterListProvider([Liaobots, Airforce])
best_provider = IterListProvider([Liaobots])
)
### GigaChat ###

View file

@ -5,8 +5,10 @@ import asyncio
from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
from abc import abstractmethod
import json
from inspect import signature, Parameter
from typing import Optional, _GenericAlias
from typing import Optional, Awaitable, _GenericAlias
from pathlib import Path
try:
from types import NoneType
except ImportError:
@ -15,9 +17,10 @@ except ImportError:
from ..typing import CreateResult, AsyncResult, Messages
from .types import BaseProvider
from .asyncio import get_running_loop, to_sync_generator
from .response import BaseConversation
from .response import BaseConversation, AuthResult
from .helper import concat_chunks, async_concat_chunks
from ..errors import ModelNotSupportedError
from ..cookies import get_cookies_dir
from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError
from .. import debug
SAFE_PARAMETERS = [
@ -33,15 +36,14 @@ SAFE_PARAMETERS = [
]
BASIC_PARAMETERS = {
"provider": None,
"model": "",
"messages": [],
"provider": None,
"stream": False,
"timeout": 0,
"response_format": None,
"max_tokens": None,
"stop": None,
"web_search": False,
}
PARAMETER_EXAMPLES = {
@ -99,7 +101,7 @@ class AbstractProvider(BaseProvider):
loop.run_in_executor(executor, create_func),
timeout=timeout
)
@classmethod
def get_parameters(cls, as_json: bool = False) -> dict[str, Parameter]:
params = {name: parameter for name, parameter in signature(
@ -109,12 +111,8 @@ class AbstractProvider(BaseProvider):
).parameters.items() if name in SAFE_PARAMETERS
and (name != "stream" or cls.supports_stream)}
if as_json:
def get_type_as_var(annotation: type, key: str):
if key == "model":
return getattr(cls, "default_model", "")
elif key == "stream":
return cls.supports_stream
elif key in PARAMETER_EXAMPLES:
def get_type_as_var(annotation: type, key: str, default):
if key in PARAMETER_EXAMPLES:
if key == "messages" and not cls.supports_system_message:
return [PARAMETER_EXAMPLES[key][-1]]
return PARAMETER_EXAMPLES[key]
@ -137,18 +135,21 @@ class AbstractProvider(BaseProvider):
return {}
elif annotation is None:
return None
elif isinstance(annotation, _GenericAlias) and annotation.__origin__ is Optional:
return get_type_as_var(annotation.__args__[0])
elif annotation == "str" or annotation == "list[str]":
return default
elif isinstance(annotation, _GenericAlias):
if annotation.__origin__ is Optional:
return get_type_as_var(annotation.__args__[0])
else:
return str(annotation)
return { name: (
param.default
if isinstance(param, Parameter) and param.default is not Parameter.empty and param.default is not None
else get_type_as_var(param.annotation if isinstance(param, Parameter) else type(param), name)
else get_type_as_var(param.annotation, name, param.default) if isinstance(param, Parameter) else param
) for name, param in {
**BASIC_PARAMETERS,
**{"provider": cls.__name__},
**params
**params,
**{"provider": cls.__name__, "stream": cls.supports_stream, "model": getattr(cls, "default_model", "")},
}.items()}
return params
@ -310,6 +311,12 @@ class AsyncGeneratorProvider(AsyncProvider):
"""
raise NotImplementedError()
create_authed = create_completion
create_authed_async = create_async
create_async_authed = create_async_generator
class ProviderModelMixin:
default_model: str = None
models: list[str] = []
@ -334,4 +341,113 @@ class ProviderModelMixin:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
cls.last_model = model
debug.last_model = model
return model
return model
class RaiseErrorMixin():
@staticmethod
def raise_error(data: dict):
if "error_message" in data:
raise ResponseError(data["error_message"])
elif "error" in data:
if "code" in data["error"]:
raise ResponseError(f'Error {data["error"]["code"]}: {data["error"]["message"]}')
elif "message" in data["error"]:
raise ResponseError(data["error"]["message"])
else:
raise ResponseError(data["error"])
class AuthedMixin():
@classmethod
def on_auth(cls, **kwargs) -> Optional[AuthResult]:
if "api_key" not in kwargs:
raise MissingAuthError(f"API key is required for {cls.__name__}")
return None
@classmethod
def create_authed(
cls,
model: str,
messages: Messages,
**kwargs
) -> CreateResult:
auth_result = {}
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
if cache_file.exists():
with cache_file.open("r") as f:
auth_result = json.load(f)
return cls.create_completion(model, messages, **kwargs, **auth_result)
auth_result = cls.on_auth(**kwargs)
try:
return cls.create_completion(model, messages, **kwargs)
finally:
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
class AsyncAuthedMixin(AuthedMixin):
@classmethod
async def create_async_authed(
cls,
model: str,
messages: Messages,
**kwargs
) -> str:
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
if cache_file.exists():
auth_result = {}
with cache_file.open("r") as f:
auth_result = json.load(f)
return cls.create_completion(model, messages, **kwargs, **auth_result)
auth_result = cls.on_auth(**kwargs)
try:
return await cls.create_async(model, messages, **kwargs)
finally:
if auth_result is not None:
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
class AsyncAuthedGeneratorMixin(AsyncAuthedMixin):
@classmethod
async def create_async_authed(
cls,
model: str,
messages: Messages,
**kwargs
) -> str:
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
if cache_file.exists():
auth_result = {}
with cache_file.open("r") as f:
auth_result = json.load(f)
return cls.create_completion(model, messages, **kwargs, **auth_result)
auth_result = cls.on_auth(**kwargs)
try:
return await async_concat_chunks(cls.create_async_generator(model, messages, stream=False, **kwargs))
finally:
if auth_result is not None:
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
@classmethod
def create_async_authed_generator(
cls,
model: str,
messages: Messages,
stream: bool = True,
**kwargs
) -> Awaitable[AsyncResult]:
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
if cache_file.exists():
auth_result = {}
with cache_file.open("r") as f:
auth_result = json.load(f)
return cls.create_completion(model, messages, **kwargs, **auth_result)
auth_result = cls.on_auth(**kwargs)
try:
return cls.create_async_generator(model, messages, stream=stream, **kwargs)
finally:
if auth_result is not None:
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))

View file

@ -105,6 +105,10 @@ class Usage(ResponseType, JsonMixin):
def __str__(self) -> str:
return ""
class AuthResult(JsonMixin):
def __str__(self) -> str:
return ""
class TitleGeneration(ResponseType):
def __init__(self, title: str) -> None:
self.title = title
@ -182,4 +186,5 @@ class ImagePreview(ImageResponse):
return super().__str__()
class Parameters(ResponseType, JsonMixin):
pass
def __str__(self):
return ""

View file

@ -40,7 +40,7 @@ async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls:
)
elif tool.get("function", {}).get("name") == "continue":
last_line = messages[-1]["content"].strip().splitlines()[-1]
content = f"Continue after this line.\n{last_line}"
content = f"Carry on from this point:\n{last_line}"
messages.append({"role": "user", "content": content})
elif tool.get("function", {}).get("name") == "bucket_tool":
def on_bucket(match):
@ -90,7 +90,7 @@ def iter_run_tools(
elif tool.get("function", {}).get("name") == "continue_tool":
if provider not in ("OpenaiAccount", "HuggingFace"):
last_line = messages[-1]["content"].strip().splitlines()[-1]
content = f"Continue after this line:\n{last_line}"
content = f"Carry on from this point:\n{last_line}"
messages.append({"role": "user", "content": content})
else:
# Enable provider native continue

View file

@ -5,6 +5,7 @@ import json
import hashlib
from pathlib import Path
from urllib.parse import urlparse
from datetime import datetime
import datetime
import asyncio
@ -65,7 +66,7 @@ class SearchResultEntry():
def set_text(self, text: str):
self.text = text
def scrape_text(html: str, max_words: int = None) -> Iterator[str]:
def scrape_text(html: str, max_words: int = None, add_source=True) -> Iterator[str]:
source = BeautifulSoup(html, "html.parser")
soup = source
for selector in [
@ -87,7 +88,7 @@ def scrape_text(html: str, max_words: int = None) -> Iterator[str]:
if select:
select.extract()
for paragraph in soup.select("p, table, ul, h1, h2, h3, h4, h5, h6"):
for paragraph in soup.select("p, table:not(:has(p)), ul:not(:has(p)), h1, h2, h3, h4, h5, h6"):
for line in paragraph.text.splitlines():
words = [word for word in line.replace("\t", " ").split(" ") if word]
count = len(words)
@ -99,24 +100,25 @@ def scrape_text(html: str, max_words: int = None) -> Iterator[str]:
break
yield " ".join(words) + "\n"
canonical_link = source.find("link", rel="canonical")
if canonical_link and "href" in canonical_link.attrs:
link = canonical_link["href"]
domain = urlparse(link).netloc
yield f"\nSource: [{domain}]({link})"
if add_source:
canonical_link = source.find("link", rel="canonical")
if canonical_link and "href" in canonical_link.attrs:
link = canonical_link["href"]
domain = urlparse(link).netloc
yield f"\nSource: [{domain}]({link})"
async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = None) -> str:
async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = None, add_source: bool = False) -> str:
try:
bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "fetch_and_scrape"
bucket_dir.mkdir(parents=True, exist_ok=True)
md5_hash = hashlib.md5(url.encode()).hexdigest()
cache_file = bucket_dir / f"{url.split('/')[3]}.{datetime.date.today()}.{md5_hash}.txt"
cache_file = bucket_dir / f"{url.split('?')[0].split('//')[1].replace('/', '+')[:16]}.{datetime.date.today()}.{md5_hash}.txt"
if cache_file.exists():
return cache_file.read_text()
async with session.get(url) as response:
if response.status == 200:
html = await response.text()
text = "".join(scrape_text(html, max_words))
text = "".join(scrape_text(html, max_words, add_source))
with open(cache_file, "w") as f:
f.write(text)
return text
@ -136,6 +138,8 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backen
max_results=max_results,
backend=backend,
):
if ".google.com" in result["href"]:
continue
results.append(SearchResultEntry(
result["title"],
result["href"],
@ -146,7 +150,7 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backen
requests = []
async with ClientSession(timeout=ClientTimeout(timeout)) as session:
for entry in results:
requests.append(fetch_and_scrape(session, entry.url, int(max_words / (max_results - 1))))
requests.append(fetch_and_scrape(session, entry.url, int(max_words / (max_results - 1)), False))
texts = await asyncio.gather(*requests)
formatted_results = []
@ -173,7 +177,7 @@ async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_
query = spacy_get_keywords(prompt)
json_bytes = json.dumps({"query": query, **kwargs}, sort_keys=True).encode()
md5_hash = hashlib.md5(json_bytes).hexdigest()
bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "web_search"
bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / f"web_search:{datetime.date.today()}"
bucket_dir.mkdir(parents=True, exist_ok=True)
cache_file = bucket_dir / f"{query[:20]}.{md5_hash}.txt"
if cache_file.exists():
@ -192,7 +196,9 @@ Instruction: {instructions}
User request:
{prompt}
"""
debug.log(f"Web search: '{query.strip()[:50]}...' {len(search_results.results)} Results {search_results.used_words} Words")
debug.log(f"Web search: '{query.strip()[:50]}...'")
if isinstance(search_results, SearchResults):
debug.log(f"with {len(search_results.results)} Results {search_results.used_words} Words")
return new_prompt
def get_search_message(prompt: str, raise_search_exceptions=False, **kwargs) -> str: