mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
commit
c5ba78c7e1
40 changed files with 800 additions and 161 deletions
|
|
@ -70,8 +70,7 @@ def read_json(text: str) -> dict:
|
|||
try:
|
||||
return json.loads(text.strip())
|
||||
except json.JSONDecodeError:
|
||||
print("No valid json:", text)
|
||||
return {}
|
||||
raise RuntimeError(f"Invalid JSON: {text}")
|
||||
|
||||
def read_text(text: str) -> str:
|
||||
"""
|
||||
|
|
@ -86,7 +85,8 @@ def read_text(text: str) -> str:
|
|||
match = re.search(r"```(markdown|)\n(?P<text>[\S\s]+?)\n```", text)
|
||||
if match:
|
||||
return match.group("text")
|
||||
return text
|
||||
else:
|
||||
raise RuntimeError(f"Invalid markdown: {text}")
|
||||
|
||||
def get_ai_response(prompt: str, as_json: bool = True) -> Union[dict, str]:
|
||||
"""
|
||||
|
|
@ -197,6 +197,7 @@ def create_review_prompt(pull: PullRequest, diff: str):
|
|||
return f"""Your task is to review a pull request. Instructions:
|
||||
- Write in name of g4f copilot. Don't use placeholder.
|
||||
- Write the review in GitHub Markdown format.
|
||||
- Enclose your response in backticks ```response```
|
||||
- Thank the author for contributing to the project.
|
||||
|
||||
Pull request author: {pull.user.name}
|
||||
|
|
|
|||
|
|
@ -16,8 +16,6 @@ async def test_async(provider: ProviderType):
|
|||
return False
|
||||
messages = [{"role": "user", "content": "Hello Assistant!"}]
|
||||
try:
|
||||
if "webdriver" in provider.get_parameters():
|
||||
return False
|
||||
response = await asyncio.wait_for(ChatCompletion.create_async(
|
||||
model=models.default,
|
||||
messages=messages,
|
||||
|
|
|
|||
|
|
@ -46,4 +46,4 @@ class TestBackendApi(unittest.TestCase):
|
|||
self.skipTest(e)
|
||||
except MissingRequirementsError:
|
||||
self.skipTest("search is not installed")
|
||||
self.assertTrue(len(result) >= 4)
|
||||
self.assertGreater(len(result), 0)
|
||||
|
|
|
|||
|
|
@ -1,36 +1,11 @@
|
|||
import unittest
|
||||
import asyncio
|
||||
|
||||
import g4f
|
||||
from g4f import ChatCompletion, get_last_provider
|
||||
import g4f.version
|
||||
from g4f.errors import VersionNotFoundError
|
||||
from g4f.Provider import RetryProvider
|
||||
from .mocks import ProviderMock
|
||||
|
||||
DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
|
||||
|
||||
class TestGetLastProvider(unittest.TestCase):
|
||||
|
||||
def test_get_last_provider(self):
|
||||
ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
|
||||
self.assertEqual(get_last_provider(), ProviderMock)
|
||||
|
||||
def test_get_last_provider_retry(self):
|
||||
ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, RetryProvider([ProviderMock]))
|
||||
self.assertEqual(get_last_provider(), ProviderMock)
|
||||
|
||||
def test_get_last_provider_async(self):
|
||||
coroutine = ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
|
||||
asyncio.run(coroutine)
|
||||
self.assertEqual(get_last_provider(), ProviderMock)
|
||||
|
||||
def test_get_last_provider_as_dict(self):
|
||||
ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
|
||||
last_provider_dict = get_last_provider(True)
|
||||
self.assertIsInstance(last_provider_dict, dict)
|
||||
self.assertIn('name', last_provider_dict)
|
||||
self.assertEqual(ProviderMock.__name__, last_provider_dict['name'])
|
||||
|
||||
def test_get_latest_version(self):
|
||||
try:
|
||||
self.assertIsInstance(g4f.version.utils.current_version, str)
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
supports_system_message = True
|
||||
supports_message_history = True
|
||||
|
||||
default_model = "gpt-4o-mini"
|
||||
default_model = "llama-3.1-70b-chat"
|
||||
default_image_model = "flux"
|
||||
|
||||
models = []
|
||||
|
|
@ -113,7 +113,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
@classmethod
|
||||
def get_model(cls, model: str) -> str:
|
||||
"""Get the actual model name from alias"""
|
||||
return cls.model_aliases.get(model, model)
|
||||
return cls.model_aliases.get(model, model or cls.default_model)
|
||||
|
||||
@classmethod
|
||||
async def check_api_key(cls, api_key: str) -> bool:
|
||||
|
|
@ -162,6 +162,9 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
"""
|
||||
Filters the full response to remove system errors and other unwanted text.
|
||||
"""
|
||||
if "Model not found or too long input. Or any other error (xD)" in response:
|
||||
raise ValueError(response)
|
||||
|
||||
filtered_response = re.sub(r"\[ERROR\] '\w{8}-\w{4}-\w{4}-\w{4}-\w{12}'", '', response) # any-uncensored
|
||||
filtered_response = re.sub(r'<\|im_end\|>', '', filtered_response) # remove <|im_end|> token
|
||||
filtered_response = re.sub(r'</s>', '', filtered_response) # neural-chat-7b-v3-1
|
||||
|
|
|
|||
|
|
@ -246,6 +246,8 @@ class BlackboxCreateAgent(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
Returns:
|
||||
AsyncResult: The response from the provider
|
||||
"""
|
||||
if not model:
|
||||
model = cls.default_model
|
||||
if model in cls.chat_models:
|
||||
async for text in cls._generate_text(model, messages, proxy=proxy, **kwargs):
|
||||
return text
|
||||
|
|
|
|||
|
|
@ -80,4 +80,6 @@ class ChatGptEs(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
async with session.post(cls.api_endpoint, headers=headers, data=payload) as response:
|
||||
response.raise_for_status()
|
||||
result = await response.json()
|
||||
if "Du musst das Kästchen anklicken!" in result['data']:
|
||||
raise ValueError(result['data'])
|
||||
yield result['data']
|
||||
|
|
|
|||
|
|
@ -127,8 +127,8 @@ class Copilot(AbstractProvider, ProviderModelMixin):
|
|||
response = session.post(cls.conversation_url)
|
||||
raise_for_status(response)
|
||||
conversation_id = response.json().get("id")
|
||||
conversation = Conversation(conversation_id)
|
||||
if return_conversation:
|
||||
conversation = Conversation(conversation_id)
|
||||
yield conversation
|
||||
if prompt is None:
|
||||
prompt = format_prompt_max_length(messages, 10000)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ from .needs_auth.OpenaiAPI import OpenaiAPI
|
|||
"""
|
||||
|
||||
class Mhystical(OpenaiAPI):
|
||||
url = "https://api.mhystical.cc"
|
||||
label = "Mhystical"
|
||||
url = "https://mhystical.cc"
|
||||
api_endpoint = "https://api.mhystical.cc/v1/completions"
|
||||
working = True
|
||||
needs_auth = False
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||
import json
|
||||
import random
|
||||
import requests
|
||||
from urllib.parse import quote
|
||||
from typing import Optional
|
||||
from aiohttp import ClientSession
|
||||
|
||||
|
|
@ -170,7 +171,7 @@ class PollinationsAI(OpenaiAPI):
|
|||
params = {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
async with ClientSession(headers=headers) as session:
|
||||
prompt = quote(messages[-1]["content"])
|
||||
prompt = quote(messages[-1]["content"] if prompt is None else prompt)
|
||||
param_string = "&".join(f"{k}={v}" for k, v in params.items())
|
||||
url = f"{cls.image_api_endpoint}/prompt/{prompt}?{param_string}"
|
||||
|
||||
|
|
|
|||
195
g4f/Provider/needs_auth/Anthropic.py
Normal file
195
g4f/Provider/needs_auth/Anthropic.py
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import requests
|
||||
import json
|
||||
import base64
|
||||
from typing import Optional
|
||||
|
||||
from ..helper import filter_none
|
||||
from ...typing import AsyncResult, Messages, ImagesType
|
||||
from ...requests import StreamSession, raise_for_status
|
||||
from ...providers.response import FinishReason, ToolCalls, Usage
|
||||
from ...errors import MissingAuthError
|
||||
from ...image import to_bytes, is_accepted_format
|
||||
from .OpenaiAPI import OpenaiAPI
|
||||
|
||||
class Anthropic(OpenaiAPI):
|
||||
label = "Anthropic API"
|
||||
url = "https://console.anthropic.com"
|
||||
login_url = "https://console.anthropic.com/settings/keys"
|
||||
working = True
|
||||
api_base = "https://api.anthropic.com/v1"
|
||||
needs_auth = True
|
||||
supports_stream = True
|
||||
supports_system_message = True
|
||||
supports_message_history = True
|
||||
default_model = "claude-3-5-sonnet-latest"
|
||||
models = [
|
||||
default_model,
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-5-haiku-latest",
|
||||
"claude-3-5-haiku-20241022",
|
||||
"claude-3-opus-latest",
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-haiku-20240307"
|
||||
]
|
||||
models_aliases = {
|
||||
"claude-3.5-sonnet": default_model,
|
||||
"claude-3-opus": "claude-3-opus-latest",
|
||||
"claude-3-sonnet": "claude-3-sonnet-20240229",
|
||||
"claude-3-haiku": "claude-3-haiku-20240307",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_models(cls, api_key: str = None, **kwargs):
|
||||
if not cls.models:
|
||||
url = f"https://api.anthropic.com/v1/models"
|
||||
response = requests.get(url, headers={
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": api_key,
|
||||
"anthropic-version": "2023-06-01"
|
||||
})
|
||||
raise_for_status(response)
|
||||
models = response.json()
|
||||
cls.models = [model["id"] for model in models["data"]]
|
||||
return cls.models
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
proxy: str = None,
|
||||
timeout: int = 120,
|
||||
images: ImagesType = None,
|
||||
api_key: str = None,
|
||||
temperature: float = None,
|
||||
max_tokens: int = 4096,
|
||||
top_k: int = None,
|
||||
top_p: float = None,
|
||||
stop: list[str] = None,
|
||||
stream: bool = False,
|
||||
headers: dict = None,
|
||||
impersonate: str = None,
|
||||
tools: Optional[list] = None,
|
||||
extra_data: dict = {},
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
if api_key is None:
|
||||
raise MissingAuthError('Add a "api_key"')
|
||||
|
||||
if images is not None:
|
||||
insert_images = []
|
||||
for image, _ in images:
|
||||
data = to_bytes(image)
|
||||
insert_images.append({
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": is_accepted_format(data),
|
||||
"data": base64.b64encode(data).decode(),
|
||||
}
|
||||
})
|
||||
messages[-1]["content"] = [
|
||||
*insert_images,
|
||||
{
|
||||
"type": "text",
|
||||
"text": messages[-1]["content"]
|
||||
}
|
||||
]
|
||||
system = "\n".join([message for message in messages if message.get("role") == "system"])
|
||||
if system:
|
||||
messages = [message for message in messages if message.get("role") != "system"]
|
||||
else:
|
||||
system = None
|
||||
|
||||
async with StreamSession(
|
||||
proxy=proxy,
|
||||
headers=cls.get_headers(stream, api_key, headers),
|
||||
timeout=timeout,
|
||||
impersonate=impersonate,
|
||||
) as session:
|
||||
data = filter_none(
|
||||
messages=messages,
|
||||
model=cls.get_model(model, api_key=api_key),
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
stop_sequences=stop,
|
||||
system=system,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
**extra_data
|
||||
)
|
||||
async with session.post(f"{cls.api_base}/messages", json=data) as response:
|
||||
await raise_for_status(response)
|
||||
if not stream:
|
||||
data = await response.json()
|
||||
cls.raise_error(data)
|
||||
if "type" in data and data["type"] == "message":
|
||||
for content in data["content"]:
|
||||
if content["type"] == "text":
|
||||
yield content["text"]
|
||||
elif content["type"] == "tool_use":
|
||||
tool_calls.append({
|
||||
"id": content["id"],
|
||||
"type": "function",
|
||||
"function": { "name": content["name"], "arguments": content["input"] }
|
||||
})
|
||||
if data["stop_reason"] == "end_turn":
|
||||
yield FinishReason("stop")
|
||||
elif data["stop_reason"] == "max_tokens":
|
||||
yield FinishReason("length")
|
||||
yield Usage(**data["usage"])
|
||||
else:
|
||||
content_block = None
|
||||
partial_json = []
|
||||
tool_calls = []
|
||||
async for line in response.iter_lines():
|
||||
if line.startswith(b"data: "):
|
||||
chunk = line[6:]
|
||||
if chunk == b"[DONE]":
|
||||
break
|
||||
data = json.loads(chunk)
|
||||
cls.raise_error(data)
|
||||
if "type" in data:
|
||||
if data["type"] == "content_block_start":
|
||||
content_block = data["content_block"]
|
||||
if content_block is None:
|
||||
pass # Message start
|
||||
elif data["type"] == "content_block_delta":
|
||||
if content_block["type"] == "text":
|
||||
yield data["delta"]["text"]
|
||||
elif content_block["type"] == "tool_use":
|
||||
partial_json.append(data["delta"]["partial_json"])
|
||||
elif data["type"] == "message_delta":
|
||||
if data["delta"]["stop_reason"] == "end_turn":
|
||||
yield FinishReason("stop")
|
||||
elif data["delta"]["stop_reason"] == "max_tokens":
|
||||
yield FinishReason("length")
|
||||
yield Usage(**data["usage"])
|
||||
elif data["type"] == "content_block_stop":
|
||||
if content_block["type"] == "tool_use":
|
||||
tool_calls.append({
|
||||
"id": content_block["id"],
|
||||
"type": "function",
|
||||
"function": { "name": content_block["name"], "arguments": partial_json.join("") }
|
||||
})
|
||||
partial_json = []
|
||||
if tool_calls:
|
||||
yield ToolCalls(tool_calls)
|
||||
|
||||
@classmethod
|
||||
def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
|
||||
return {
|
||||
"Accept": "text/event-stream" if stream else "application/json",
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{"x-api-key": api_key}
|
||||
if api_key is not None else {}
|
||||
),
|
||||
"anthropic-version": "2023-06-01",
|
||||
**({} if headers is None else headers)
|
||||
}
|
||||
|
|
@ -10,6 +10,7 @@ from ...cookies import get_cookies
|
|||
class Cerebras(OpenaiAPI):
|
||||
label = "Cerebras Inference"
|
||||
url = "https://inference.cerebras.ai/"
|
||||
login_url = "https://cloud.cerebras.ai"
|
||||
api_base = "https://api.cerebras.ai/v1"
|
||||
working = True
|
||||
default_model = "llama3.1-70b"
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from ... import debug
|
|||
class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
label = "Google Gemini API"
|
||||
url = "https://ai.google.dev"
|
||||
login_url = "https://aistudio.google.com/u/0/apikey"
|
||||
api_base = "https://generativelanguage.googleapis.com/v1beta"
|
||||
|
||||
working = True
|
||||
|
|
@ -24,7 +25,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
|
||||
default_model = "gemini-1.5-pro"
|
||||
default_vision_model = default_model
|
||||
fallback_models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
|
||||
fallback_models = [default_model, "gemini-2.0-flash-exp", "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
|
||||
model_aliases = {
|
||||
"gemini-flash": "gemini-1.5-flash",
|
||||
"gemini-flash": "gemini-1.5-flash-8b",
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from .OpenaiAPI import OpenaiAPI
|
|||
class Groq(OpenaiAPI):
|
||||
label = "Groq"
|
||||
url = "https://console.groq.com/playground"
|
||||
login_url = "https://console.groq.com/keys"
|
||||
api_base = "https://api.groq.com/openai/v1"
|
||||
working = True
|
||||
default_model = "mixtral-8x7b-32768"
|
||||
|
|
|
|||
|
|
@ -47,8 +47,8 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
proxy: str = None,
|
||||
api_base: str = "https://api-inference.huggingface.co",
|
||||
api_key: str = None,
|
||||
max_new_tokens: int = 1024,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = None,
|
||||
prompt: str = None,
|
||||
action: str = None,
|
||||
extra_data: dict = {},
|
||||
|
|
@ -84,7 +84,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
else:
|
||||
params = {
|
||||
"return_full_text": False,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"max_new_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
**extra_data
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from .HuggingChat import HuggingChat
|
|||
class HuggingFaceAPI(OpenaiAPI):
|
||||
label = "HuggingFace (Inference API)"
|
||||
url = "https://api-inference.huggingface.co"
|
||||
login_url = "https://huggingface.co/settings/tokens"
|
||||
api_base = "https://api-inference.huggingface.co/v1"
|
||||
working = True
|
||||
default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import json
|
|||
import requests
|
||||
|
||||
from ..helper import filter_none
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||
from ...typing import Union, Optional, AsyncResult, Messages, ImagesType
|
||||
from ...requests import StreamSession, raise_for_status
|
||||
from ...providers.response import FinishReason, ToolCalls, Usage
|
||||
|
|
@ -12,9 +12,10 @@ from ...errors import MissingAuthError, ResponseError
|
|||
from ...image import to_data_uri
|
||||
from ... import debug
|
||||
|
||||
class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
|
||||
label = "OpenAI API"
|
||||
url = "https://platform.openai.com"
|
||||
login_url = "https://platform.openai.com/settings/organization/api-keys"
|
||||
api_base = "https://api.openai.com/v1"
|
||||
working = True
|
||||
needs_auth = True
|
||||
|
|
@ -141,18 +142,6 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
if "finish_reason" in choice and choice["finish_reason"] is not None:
|
||||
return FinishReason(choice["finish_reason"])
|
||||
|
||||
@staticmethod
|
||||
def raise_error(data: dict):
|
||||
if "error_message" in data:
|
||||
raise ResponseError(data["error_message"])
|
||||
elif "error" in data:
|
||||
if "code" in data["error"]:
|
||||
raise ResponseError(f'Error {data["error"]["code"]}: {data["error"]["message"]}')
|
||||
elif "message" in data["error"]:
|
||||
raise ResponseError(data["error"]["message"])
|
||||
else:
|
||||
raise ResponseError(data["error"])
|
||||
|
||||
@classmethod
|
||||
def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from .OpenaiAPI import OpenaiAPI
|
|||
class PerplexityApi(OpenaiAPI):
|
||||
label = "Perplexity API"
|
||||
url = "https://www.perplexity.ai"
|
||||
login_url = "https://www.perplexity.ai/settings/api"
|
||||
working = True
|
||||
api_base = "https://api.perplexity.ai"
|
||||
default_model = "llama-3-sonar-large-32k-online"
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from ...errors import ResponseError, MissingAuthError
|
|||
|
||||
class Replicate(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
url = "https://replicate.com"
|
||||
login_url = "https://replicate.com/account/api-tokens"
|
||||
working = True
|
||||
needs_auth = True
|
||||
default_model = "meta/meta-llama-3-70b-instruct"
|
||||
|
|
@ -25,7 +26,7 @@ class Replicate(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
proxy: str = None,
|
||||
timeout: int = 180,
|
||||
system_prompt: str = None,
|
||||
max_new_tokens: int = None,
|
||||
max_tokens: int = None,
|
||||
temperature: float = None,
|
||||
top_p: float = None,
|
||||
top_k: float = None,
|
||||
|
|
@ -55,7 +56,7 @@ class Replicate(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
"prompt": format_prompt(messages),
|
||||
**filter_none(
|
||||
system_prompt=system_prompt,
|
||||
max_new_tokens=max_new_tokens,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from .gigachat import *
|
||||
|
||||
from .Anthropic import Anthropic
|
||||
from .BingCreateImages import BingCreateImages
|
||||
from .Cerebras import Cerebras
|
||||
from .CopilotAccount import CopilotAccount
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from .OpenaiAPI import OpenaiAPI
|
|||
class glhfChat(OpenaiAPI):
|
||||
label = "glhf.chat"
|
||||
url = "https://glhf.chat"
|
||||
login_url = "https://glhf.chat/users/settings/api"
|
||||
api_base = "https://glhf.chat/api/openai/v1"
|
||||
working = True
|
||||
model_aliases = {
|
||||
|
|
|
|||
|
|
@ -5,5 +5,6 @@ from .OpenaiAPI import OpenaiAPI
|
|||
class xAI(OpenaiAPI):
|
||||
label = "xAI"
|
||||
url = "https://console.x.ai"
|
||||
login_url = "https://console.x.ai"
|
||||
api_base = "https://api.x.ai/v1"
|
||||
working = True
|
||||
|
|
@ -12,7 +12,7 @@ from .errors import StreamNotSupportedError
|
|||
from .cookies import get_cookies, set_cookies
|
||||
from .providers.types import ProviderType
|
||||
from .providers.helper import concat_chunks
|
||||
from .client.service import get_model_and_provider, get_last_provider
|
||||
from .client.service import get_model_and_provider
|
||||
|
||||
#Configure "g4f" logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -47,7 +47,8 @@ class ChatCompletion:
|
|||
if ignore_stream:
|
||||
kwargs["ignore_stream"] = True
|
||||
|
||||
result = provider.create_completion(model, messages, stream=stream, **kwargs)
|
||||
create_method = provider.create_authed if hasattr(provider, "create_authed") else provider.create_completion
|
||||
result = create_method(model, messages, stream=stream, **kwargs)
|
||||
|
||||
return result if stream else concat_chunks(result)
|
||||
|
||||
|
|
@ -72,7 +73,9 @@ class ChatCompletion:
|
|||
kwargs["ignore_stream"] = True
|
||||
|
||||
if stream:
|
||||
if hasattr(provider, "create_async_generator"):
|
||||
if hasattr(provider, "create_async_authed_generator"):
|
||||
return provider.create_async_authed_generator(model, messages, **kwargs)
|
||||
elif hasattr(provider, "create_async_generator"):
|
||||
return provider.create_async_generator(model, messages, **kwargs)
|
||||
raise StreamNotSupportedError(f'{provider.__name__} does not support "stream" argument in "create_async"')
|
||||
|
||||
|
|
|
|||
|
|
@ -197,11 +197,10 @@ class Api:
|
|||
)
|
||||
|
||||
def register_routes(self):
|
||||
@self.app.get("/")
|
||||
async def read_root():
|
||||
if AppConfig.gui:
|
||||
return RedirectResponse("/chat/", 302)
|
||||
return RedirectResponse("/v1", 302)
|
||||
if not AppConfig.gui:
|
||||
@self.app.get("/")
|
||||
async def read_root():
|
||||
return RedirectResponse("/v1", 302)
|
||||
|
||||
@self.app.get("/v1")
|
||||
async def read_root_v1():
|
||||
|
|
|
|||
|
|
@ -7,9 +7,9 @@ import string
|
|||
import asyncio
|
||||
import aiohttp
|
||||
import base64
|
||||
from typing import Union, AsyncIterator, Iterator, Coroutine, Optional
|
||||
from typing import Union, AsyncIterator, Iterator, Awaitable, Optional
|
||||
|
||||
from ..image import ImageResponse, copy_images, images_dir
|
||||
from ..image import ImageResponse, copy_images
|
||||
from ..typing import Messages, ImageType
|
||||
from ..providers.types import ProviderType, BaseRetryProvider
|
||||
from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData, ToolCalls, Usage
|
||||
|
|
@ -22,7 +22,7 @@ from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
|
|||
from .image_models import ImageModels
|
||||
from .types import IterResponse, ImageProvider, Client as BaseClient
|
||||
from .service import get_model_and_provider, convert_to_provider
|
||||
from .helper import find_stop, filter_json, filter_none, safe_aclose
|
||||
from .helper import find_stop, filter_json, filter_none, safe_aclose, to_async_iterator
|
||||
from .. import debug
|
||||
|
||||
ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
|
||||
|
|
@ -220,7 +220,7 @@ class Completions:
|
|||
ignore_working: Optional[bool] = False,
|
||||
ignore_stream: Optional[bool] = False,
|
||||
**kwargs
|
||||
) -> IterResponse:
|
||||
) -> ChatCompletion:
|
||||
model, provider = get_model_and_provider(
|
||||
model,
|
||||
self.provider if provider is None else provider,
|
||||
|
|
@ -236,7 +236,7 @@ class Completions:
|
|||
kwargs["ignore_stream"] = True
|
||||
|
||||
response = iter_run_tools(
|
||||
provider.create_completion,
|
||||
provider.create_authed if hasattr(provider, "create_authed") else provider.create_completion,
|
||||
model,
|
||||
messages,
|
||||
stream=stream,
|
||||
|
|
@ -248,9 +248,6 @@ class Completions:
|
|||
),
|
||||
**kwargs
|
||||
)
|
||||
if asyncio.iscoroutinefunction(provider.create_completion):
|
||||
# Run the asynchronous function in an event loop
|
||||
response = asyncio.run(response)
|
||||
if stream and hasattr(response, '__aiter__'):
|
||||
# It's an async generator, wrap it into a sync iterator
|
||||
response = to_sync_generator(response)
|
||||
|
|
@ -264,6 +261,14 @@ class Completions:
|
|||
else:
|
||||
return next(response)
|
||||
|
||||
def stream(
|
||||
self,
|
||||
messages: Messages,
|
||||
model: str,
|
||||
**kwargs
|
||||
) -> IterResponse:
|
||||
return self.create(messages, model, stream=True, **kwargs)
|
||||
|
||||
class Chat:
|
||||
completions: Completions
|
||||
|
||||
|
|
@ -507,7 +512,7 @@ class AsyncCompletions:
|
|||
ignore_working: Optional[bool] = False,
|
||||
ignore_stream: Optional[bool] = False,
|
||||
**kwargs
|
||||
) -> Union[Coroutine[ChatCompletion], AsyncIterator[ChatCompletionChunk, BaseConversation]]:
|
||||
) -> Awaitable[ChatCompletion]:
|
||||
model, provider = get_model_and_provider(
|
||||
model,
|
||||
self.provider if provider is None else provider,
|
||||
|
|
@ -521,6 +526,8 @@ class AsyncCompletions:
|
|||
kwargs["images"] = [(image, image_name)]
|
||||
if ignore_stream:
|
||||
kwargs["ignore_stream"] = True
|
||||
if hasattr(provider, "create_async_authed_generator"):
|
||||
create_handler = provider.create_async_authed_generator
|
||||
if hasattr(provider, "create_async_generator"):
|
||||
create_handler = provider.create_async_generator
|
||||
else:
|
||||
|
|
@ -538,10 +545,20 @@ class AsyncCompletions:
|
|||
),
|
||||
**kwargs
|
||||
)
|
||||
if not hasattr(response, '__aiter__'):
|
||||
response = to_async_iterator(response)
|
||||
response = async_iter_response(response, stream, response_format, max_tokens, stop)
|
||||
response = async_iter_append_model_and_provider(response, model, provider)
|
||||
return response if stream else anext(response)
|
||||
|
||||
def stream(
|
||||
self,
|
||||
messages: Messages,
|
||||
model: str,
|
||||
**kwargs
|
||||
) -> AsyncIterator[ChatCompletionChunk, BaseConversation]:
|
||||
return self.create(messages, model, stream=True, **kwargs)
|
||||
|
||||
class AsyncImages(Images):
|
||||
def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
|
||||
self.client: AsyncClient = client
|
||||
|
|
|
|||
|
|
@ -5,6 +5,22 @@ import logging
|
|||
|
||||
from typing import AsyncIterator, Iterator, AsyncGenerator, Optional
|
||||
|
||||
def filter_markdown(text: str, allowd_types=None, default=None) -> str:
|
||||
"""
|
||||
Parses code block from a string.
|
||||
|
||||
Args:
|
||||
text (str): A string containing a code block.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary parsed from the code block.
|
||||
"""
|
||||
match = re.search(r"```(.+)\n(?P<code>[\S\s]+?)(\n```|$)", text)
|
||||
if match:
|
||||
if allowd_types is None or match.group(1) in allowd_types:
|
||||
return match.group("code")
|
||||
return default
|
||||
|
||||
def filter_json(text: str) -> str:
|
||||
"""
|
||||
Parses JSON code block from a string.
|
||||
|
|
@ -15,10 +31,7 @@ def filter_json(text: str) -> str:
|
|||
Returns:
|
||||
dict: A dictionary parsed from the JSON code block.
|
||||
"""
|
||||
match = re.search(r"```(json|)\n(?P<code>[\S\s]+?)\n```", text)
|
||||
if match:
|
||||
return match.group("code")
|
||||
return text
|
||||
return filter_markdown(text, ["", "json"], text)
|
||||
|
||||
def find_stop(stop: Optional[list[str]], content: str, chunk: str = None):
|
||||
first = -1
|
||||
|
|
|
|||
230
g4f/gui/client/home.html
Normal file
230
g4f/gui/client/home.html
Normal file
|
|
@ -0,0 +1,230 @@
|
|||
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>G4F GUI</title>
|
||||
<style>
|
||||
:root {
|
||||
--colour-1: #000000;
|
||||
--colour-2: #ccc;
|
||||
--colour-3: #e4d4ff;
|
||||
--colour-4: #f0f0f0;
|
||||
--colour-5: #181818;
|
||||
--colour-6: #242424;
|
||||
--accent: #8b3dff;
|
||||
--gradient: #1a1a1a;
|
||||
--background: #16101b;
|
||||
--size: 70vw;
|
||||
--top: 50%;
|
||||
--blur: 40px;
|
||||
--opacity: 0.6;
|
||||
}
|
||||
|
||||
@import url("https://fonts.googleapis.com/css2?family=Inter:wght@100;200;300;400;500;600;700;800;900&display=swap");
|
||||
|
||||
.gradient {
|
||||
position: absolute;
|
||||
z-index: -1;
|
||||
left: 50vw;
|
||||
border-radius: 50%;
|
||||
background: radial-gradient(circle at center, var(--accent), var(--gradient));
|
||||
width: var(--size);
|
||||
height: var(--size);
|
||||
top: var(--top);
|
||||
transform: translate(-50%, -50%);
|
||||
filter: blur(var(--blur)) opacity(var(--opacity));
|
||||
animation: zoom_gradient 6s infinite alternate;
|
||||
display: none;
|
||||
max-height: 100%;
|
||||
transition: max-height 0.25s ease-in;
|
||||
}
|
||||
|
||||
.gradient.hidden {
|
||||
max-height: 0;
|
||||
transition: max-height 0.15s ease-out;
|
||||
}
|
||||
|
||||
@media only screen and (min-width: 40em) {
|
||||
body .gradient{
|
||||
display: block;
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes zoom_gradient {
|
||||
0% {
|
||||
transform: translate(-50%, -50%) scale(1);
|
||||
}
|
||||
100% {
|
||||
transform: translate(-50%, -50%) scale(1.2);
|
||||
}
|
||||
}
|
||||
|
||||
/* Body and text color */
|
||||
body {
|
||||
background: var(--background);
|
||||
color: var(--colour-3);
|
||||
font-family: "Inter", sans-serif;
|
||||
height: 100vh;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
overflow: hidden;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
/* Container for the main content */
|
||||
.container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
height: 100%;
|
||||
text-align: center;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
header {
|
||||
font-size: 3rem;
|
||||
text-transform: uppercase;
|
||||
margin: 20px;
|
||||
color: var(--colour-4);
|
||||
}
|
||||
|
||||
iframe {
|
||||
background: transparent;
|
||||
width: 100%;
|
||||
border: none;
|
||||
}
|
||||
|
||||
#background {
|
||||
height: 100%;
|
||||
position: absolute;
|
||||
z-index: -1;
|
||||
}
|
||||
|
||||
iframe.stream {
|
||||
max-height: 0;
|
||||
transition: max-height 0.15s ease-out;
|
||||
}
|
||||
|
||||
iframe.stream.show {
|
||||
max-height: 1000px;
|
||||
height: 1000px;
|
||||
transition: max-height 0.25s ease-in;
|
||||
background: rgba(255,255,255,0.7);
|
||||
border-top: 2px solid rgba(255,255,255,0.5);
|
||||
}
|
||||
|
||||
.description {
|
||||
font-size: 1.2rem;
|
||||
margin-bottom: 30px;
|
||||
color: var(--colour-2);
|
||||
} return app
|
||||
|
||||
|
||||
.input-field {
|
||||
width: 80%;
|
||||
max-width: 400px;
|
||||
padding: 12px;
|
||||
margin: 10px 0;
|
||||
border: 2px solid var(--colour-6);
|
||||
background-color: var(--colour-5);
|
||||
color: var(--colour-3);
|
||||
border-radius: 8px;
|
||||
font-size: 1.1rem;
|
||||
}
|
||||
|
||||
.input-field:focus {
|
||||
outline: none;
|
||||
border-color: var(--accent);
|
||||
}
|
||||
|
||||
.button {
|
||||
background-color: var(--accent);
|
||||
color: var(--colour-3);
|
||||
border: none;
|
||||
padding: 15px 30px;
|
||||
font-size: 1.1rem;
|
||||
border-radius: 8px;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.3s ease;
|
||||
margin-top: 15px;
|
||||
width: 100%;
|
||||
max-width: 400px;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.button:hover {
|
||||
background-color: #7a2ccd;
|
||||
}
|
||||
|
||||
.footer {
|
||||
margin-top: 30px;
|
||||
font-size: 0.9rem;
|
||||
color: var(--colour-2);
|
||||
}
|
||||
|
||||
/* Animation for the gradient circle */
|
||||
@keyframes zoom_gradient {
|
||||
0% {
|
||||
transform: translate(-50%, -50%) scale(1);
|
||||
}
|
||||
100% {
|
||||
transform: translate(-50%, -50%) scale(1.5);
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<iframe id="background"></iframe>
|
||||
|
||||
<!-- Gradient Background Circle -->
|
||||
<div class="gradient"></div>
|
||||
|
||||
<!-- Main Content -->
|
||||
<div class="container">
|
||||
<header>
|
||||
G4F GUI
|
||||
</header>
|
||||
<div class="description">
|
||||
Welcome to the G4F GUI! <br>
|
||||
Your AI assistant is ready to assist you.
|
||||
</div>
|
||||
|
||||
<!-- Input and Button -->
|
||||
<form action="/chat/">
|
||||
<!--
|
||||
<input type="text" name="prompt" class="input-field" placeholder="Enter your query...">
|
||||
-->
|
||||
<button class="button">Open Chat</button>
|
||||
</form>
|
||||
|
||||
<!-- Footer -->
|
||||
<div class="footer">
|
||||
<p>© 2025 G4F. All Rights Reserved.</p>
|
||||
<p>Powered by the G4F framework</p>
|
||||
</div>
|
||||
|
||||
<iframe id="stream-widget" class="stream" data-src="/backend-api/v2/create?prompt=Create of overview of the news in plain text&stream=1&web_search=news in " class="" frameborder="0"></iframe>
|
||||
</div>
|
||||
<script>
|
||||
const iframe = document.getElementById('stream-widget');
|
||||
iframe.src = iframe.dataset.src + navigator.language;
|
||||
setTimeout(()=>iframe.classList.add('show'), 5000);
|
||||
|
||||
(async () => {
|
||||
const prompt = `
|
||||
Today is ${new Date().toJSON().slice(0, 10)}.
|
||||
Create a single-page HTML screensaver reflecting the current season (based on the date).
|
||||
For example, if it's Spring, it might use floral patterns or pastel colors.
|
||||
Avoid using any text. Consider a subtle animation or transition effect.`;
|
||||
const response = await fetch(`/backend-api/v2/create?prompt=${prompt}&filter_markdown=html`)
|
||||
const text = await response.text()
|
||||
background.src = `data:text/html;charset=utf-8,${encodeURIComponent(text)}`;
|
||||
const gradient = document.querySelector('.gradient');
|
||||
gradient.classList.add('hidden');
|
||||
})();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
|
|
@ -157,6 +157,10 @@
|
|||
<label for="Cerebras-api_key" class="label" title="">Cerebras Inference:</label>
|
||||
<textarea id="Cerebras-api_key" name="Cerebras[api_key]" placeholder="api_key"></textarea>
|
||||
</div>
|
||||
<div class="field box hidden">
|
||||
<label for="Anthropic-api_key" class="label" title="">Anthropic API:</label>
|
||||
<textarea id="Anthropic-api_key" name="Anthropic[api_key]" placeholder="api_key"></textarea>
|
||||
</div>
|
||||
<div class="field box hidden">
|
||||
<label for="DeepInfra-api_key" class="label" title="">DeepInfra:</label>
|
||||
<textarea id="DeepInfra-api_key" name="DeepInfra[api_key]" class="DeepInfraImage-api_key" placeholder="api_key"></textarea>
|
||||
|
|
|
|||
|
|
@ -239,7 +239,7 @@ body:not(.white) a:visited{
|
|||
max-width: 210px;
|
||||
margin-left: auto;
|
||||
margin-right: 8px;
|
||||
margin-bottom: 8px;
|
||||
margin-top: 12px;
|
||||
}
|
||||
|
||||
.convo-title {
|
||||
|
|
@ -530,6 +530,7 @@ body:not(.white) a:visited{
|
|||
z-index: 100000;
|
||||
top: 0;
|
||||
right: 0;
|
||||
animation: show_popup 0.4s;
|
||||
}
|
||||
|
||||
.stop_generating button, .toolbar .regenerate button, button.regenerate_button, button.continue_button, button.options_button {
|
||||
|
|
@ -545,7 +546,6 @@ body:not(.white) a:visited{
|
|||
align-items: center;
|
||||
gap: 12px;
|
||||
cursor: pointer;
|
||||
animation: show_popup 0.4s;
|
||||
height: 28px;
|
||||
}
|
||||
|
||||
|
|
@ -712,6 +712,8 @@ form label.toogle {
|
|||
position: relative;
|
||||
overflow: hidden;
|
||||
transition: 0.33s;
|
||||
min-width: 60px;
|
||||
margin-left: 0;
|
||||
}
|
||||
|
||||
.buttons label:after,
|
||||
|
|
@ -1280,12 +1282,18 @@ ul {
|
|||
border: 1px solid #e4d4ffc9;
|
||||
}
|
||||
|
||||
.settings textarea, form textarea {
|
||||
form textarea {
|
||||
height: 20px;
|
||||
min-height: 20px;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.settings textarea {
|
||||
height: 30px;
|
||||
min-height: 30px;
|
||||
padding: 6px;
|
||||
}
|
||||
|
||||
form .field .fa-xmark {
|
||||
line-height: 20px;
|
||||
cursor: pointer;
|
||||
|
|
@ -1346,6 +1354,7 @@ form .field.saved .fa-xmark {
|
|||
.settings .label, form .label, .settings label, form label {
|
||||
font-size: 15px;
|
||||
margin-left: var(--inner-gap);
|
||||
min-width: 120px;
|
||||
}
|
||||
|
||||
.settings .label, form .label {
|
||||
|
|
|
|||
|
|
@ -137,14 +137,13 @@ class HtmlRenderPlugin {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (window.hljs) {
|
||||
hljs.addPlugin(new HtmlRenderPlugin())
|
||||
hljs.addPlugin(new CopyButtonPlugin());
|
||||
}
|
||||
let typesetPromise = Promise.resolve();
|
||||
const highlight = (container) => {
|
||||
if (window.hljs) {
|
||||
hljs.addPlugin(new HtmlRenderPlugin())
|
||||
if (window.CopyButtonPlugin) {
|
||||
hljs.addPlugin(new CopyButtonPlugin());
|
||||
}
|
||||
container.querySelectorAll('code:not(.hljs').forEach((el) => {
|
||||
if (el.className != "hljs") {
|
||||
hljs.highlightElement(el);
|
||||
|
|
@ -518,6 +517,7 @@ const prepare_messages = (messages, message_index = -1, do_continue = false) =>
|
|||
delete new_message.synthesize;
|
||||
delete new_message.finish;
|
||||
delete new_message.conversation;
|
||||
delete new_message.continue;
|
||||
// Append message to new messages
|
||||
new_messages.push(new_message)
|
||||
}
|
||||
|
|
@ -571,7 +571,7 @@ async function load_provider_parameters(provider) {
|
|||
field_el = document.createElement("div");
|
||||
field_el.classList.add("field");
|
||||
field_el.classList.add("box");
|
||||
if (typeof value == "object") {
|
||||
if (typeof value == "object" && value != null) {
|
||||
value = JSON.stringify(value, null, 4);
|
||||
}
|
||||
if (saved_value) {
|
||||
|
|
@ -580,14 +580,15 @@ async function load_provider_parameters(provider) {
|
|||
saved_value = value;
|
||||
}
|
||||
let placeholder;
|
||||
if (key in ["api_key", "proof_token"]) {
|
||||
placeholder = value.length >= 22 ? (value.substring(0, 10) + "*".repeat(8) + value.substring(value.length-10)) : value;
|
||||
if (["api_key", "proof_token"].includes(key)) {
|
||||
placeholder = saved_value && saved_value.length >= 22 ? (saved_value.substring(0, 12) + "*".repeat(12) + saved_value.substring(saved_value.length-12)) : value;
|
||||
} else {
|
||||
placeholder = value;
|
||||
placeholder = value == null ? "null" : value;
|
||||
}
|
||||
field_el.innerHTML = `<label for="${el_id}" title="">${key}:</label>`;
|
||||
if (Number.isInteger(value)) {
|
||||
field_el.innerHTML += `<input type="range" id="${el_id}" name="${provider}[${key}]" value="${escapeHtml(value)}" class="slider" min="0" max="4096" step="1"/><output>${escapeHtml(value)}</output>`;
|
||||
if (Number.isInteger(value) && value != 1) {
|
||||
max = value >= 4096 ? 8192 : 4096;
|
||||
field_el.innerHTML += `<input type="range" id="${el_id}" name="${provider}[${key}]" value="${escapeHtml(value)}" class="slider" min="0" max="${max}" step="1"/><output>${escapeHtml(value)}</output>`;
|
||||
field_el.innerHTML += `<i class="fa-solid fa-xmark"></i>`;
|
||||
} else if (typeof value == "number") {
|
||||
field_el.innerHTML += `<input type="range" id="${el_id}" name="${provider}[${key}]" value="${escapeHtml(value)}" class="slider" min="0" max="2" step="0.1"/><output>${escapeHtml(value)}</output>`;
|
||||
|
|
@ -596,10 +597,12 @@ async function load_provider_parameters(provider) {
|
|||
field_el.innerHTML += `<textarea id="${el_id}" name="${provider}[${key}]"></textarea>`;
|
||||
field_el.innerHTML += `<i class="fa-solid fa-xmark"></i>`;
|
||||
input_el = field_el.querySelector("textarea");
|
||||
input_el.dataset.text = value;
|
||||
if (value != null) {
|
||||
input_el.dataset.text = value;
|
||||
}
|
||||
input_el.placeholder = placeholder;
|
||||
if (!key in ["api_key", "proof_token"]) {
|
||||
input_el.innerHTML = saved_value;
|
||||
if (!["api_key", "proof_token"].includes(key)) {
|
||||
input_el.value = saved_value;
|
||||
} else {
|
||||
input_el.dataset.saved_value = saved_value;
|
||||
}
|
||||
|
|
@ -610,14 +613,16 @@ async function load_provider_parameters(provider) {
|
|||
};
|
||||
input_el.onfocus = () => {
|
||||
if (input_el.dataset.saved_value) {
|
||||
input_el.innerHTML = input_el.dataset.saved_value;
|
||||
input_el.value = input_el.dataset.saved_value;
|
||||
} else if (["api_key", "proof_token"].includes(key)) {
|
||||
input_el.value = input_el.dataset.text;
|
||||
}
|
||||
input_el.style.removeProperty("height");
|
||||
input_el.style.height = (input_el.scrollHeight) + "px";
|
||||
}
|
||||
input_el.onblur = () => {
|
||||
input_el.style.removeProperty("height");
|
||||
if (key in ["api_key", "proof_token"]) {
|
||||
if (["api_key", "proof_token"].includes(key)) {
|
||||
input_el.value = "";
|
||||
}
|
||||
}
|
||||
|
|
@ -642,13 +647,14 @@ async function load_provider_parameters(provider) {
|
|||
input_el.value = input_el.dataset.value;
|
||||
input_el.nextElementSibling.value = input_el.dataset.value;
|
||||
} else if (input_el.dataset.text) {
|
||||
input_el.innerHTML = input_el.dataset.text;
|
||||
input_el.value = input_el.dataset.text;
|
||||
}
|
||||
delete input_el.dataset.saved_value;
|
||||
appStorage.removeItem(el_id);
|
||||
field_el.classList.remove("saved");
|
||||
}
|
||||
});
|
||||
provider_forms.prepend(form_el);
|
||||
provider_forms.appendChild(form_el);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1004,8 +1010,9 @@ const load_conversation = async (conversation_id, scroll=true) => {
|
|||
let lines = buffer.trim().split("\n");
|
||||
let lastLine = lines[lines.length - 1];
|
||||
let newContent = item.content;
|
||||
if (newContent.startsWith("```\n")) {
|
||||
newContent = item.content.substring(4);
|
||||
if (newContent.startsWith("```")) {
|
||||
const index = str.indexOf("\n");
|
||||
newContent = newContent.substring(index);
|
||||
}
|
||||
if (newContent.startsWith(lastLine)) {
|
||||
newContent = newContent.substring(lastLine.length);
|
||||
|
|
@ -1073,7 +1080,7 @@ const load_conversation = async (conversation_id, scroll=true) => {
|
|||
reason = "stop"
|
||||
}
|
||||
}
|
||||
if (reason == "max_tokens" || reason == "error") {
|
||||
if (reason == "length" || reason == "max_tokens" || reason == "error") {
|
||||
actions.push("continue")
|
||||
}
|
||||
}
|
||||
|
|
@ -1405,22 +1412,23 @@ function open_settings() {
|
|||
const register_settings_storage = async () => {
|
||||
const optionElements = document.querySelectorAll(optionElementsSelector);
|
||||
optionElements.forEach((element) => {
|
||||
element.name = element.name || element.id;
|
||||
if (element.type == "textarea") {
|
||||
element.addEventListener('input', async (event) => {
|
||||
appStorage.setItem(element.id, element.value);
|
||||
appStorage.setItem(element.name, element.value);
|
||||
});
|
||||
} else {
|
||||
element.addEventListener('change', async (event) => {
|
||||
switch (element.type) {
|
||||
case "checkbox":
|
||||
appStorage.setItem(element.id, element.checked);
|
||||
appStorage.setItem(element.name, element.checked);
|
||||
break;
|
||||
case "select-one":
|
||||
appStorage.setItem(element.id, element.selectedIndex);
|
||||
appStorage.setItem(element.name, element.value);
|
||||
break;
|
||||
case "text":
|
||||
case "number":
|
||||
appStorage.setItem(element.id, element.value);
|
||||
appStorage.setItem(element.name, element.value);
|
||||
break;
|
||||
default:
|
||||
console.warn("Unresolved element type");
|
||||
|
|
@ -1433,7 +1441,8 @@ const register_settings_storage = async () => {
|
|||
const load_settings_storage = async () => {
|
||||
const optionElements = document.querySelectorAll(optionElementsSelector);
|
||||
optionElements.forEach((element) => {
|
||||
if (!(value = appStorage.getItem(element.id))) {
|
||||
element.name = element.name || element.id;
|
||||
if (!(value = appStorage.getItem(element.name))) {
|
||||
return;
|
||||
}
|
||||
if (value) {
|
||||
|
|
@ -1442,7 +1451,7 @@ const load_settings_storage = async () => {
|
|||
element.checked = value === "true";
|
||||
break;
|
||||
case "select-one":
|
||||
element.selectedIndex = parseInt(value);
|
||||
element.value = value;
|
||||
break;
|
||||
case "text":
|
||||
case "number":
|
||||
|
|
@ -1683,7 +1692,7 @@ async function on_api() {
|
|||
console.error(e)
|
||||
// Redirect to show basic authenfication
|
||||
if (document.location.pathname == "/chat/") {
|
||||
document.location.href = `/chat/error`;
|
||||
//document.location.href = `/chat/error`;
|
||||
}
|
||||
}
|
||||
register_settings_storage();
|
||||
|
|
@ -1753,7 +1762,7 @@ async function load_version() {
|
|||
new_version = document.createElement("div");
|
||||
new_version.classList.add("new_version");
|
||||
const link = `<a href="${release_url}" target="_blank" title="${title}">v${versions["latest_version"]}</a>`;
|
||||
new_version.innerHTML = `g4f ${link} 🆕`;
|
||||
new_version.innerHTML = `G4F ${link} 🆕`;
|
||||
new_version.addEventListener("click", ()=>new_version.parentElement.removeChild(new_version));
|
||||
document.body.appendChild(new_version);
|
||||
} else {
|
||||
|
|
@ -1951,7 +1960,10 @@ async function api(ressource, args=null, files=null, message_id=null) {
|
|||
return read_response(response, message_id, args.provider || null);
|
||||
}
|
||||
response = await fetch(url, {headers: headers});
|
||||
return await response.json();
|
||||
if (response.status == 200) {
|
||||
return await response.json();
|
||||
}
|
||||
console.error(response);
|
||||
}
|
||||
|
||||
async function read_response(response, message_id, provider) {
|
||||
|
|
@ -1987,19 +1999,19 @@ function get_api_key_by_provider(provider) {
|
|||
return api_key;
|
||||
}
|
||||
|
||||
async function load_provider_models(providerIndex=null) {
|
||||
if (!providerIndex) {
|
||||
providerIndex = providerSelect.selectedIndex;
|
||||
async function load_provider_models(provider=null) {
|
||||
if (!provider) {
|
||||
provider = providerSelect.value;
|
||||
}
|
||||
modelProvider.innerHTML = '';
|
||||
const provider = providerSelect.options[providerIndex].value;
|
||||
modelProvider.name = `model[${provider}]`;
|
||||
if (!provider) {
|
||||
modelProvider.classList.add("hidden");
|
||||
modelSelect.classList.remove("hidden");
|
||||
return;
|
||||
}
|
||||
const models = await api('models', provider);
|
||||
if (models.length > 0) {
|
||||
if (models && models.length > 0) {
|
||||
modelSelect.classList.add("hidden");
|
||||
modelProvider.classList.remove("hidden");
|
||||
models.forEach((model) => {
|
||||
|
|
@ -2010,6 +2022,10 @@ async function load_provider_models(providerIndex=null) {
|
|||
option.selected = model.default;
|
||||
modelProvider.appendChild(option);
|
||||
});
|
||||
let value = appStorage.getItem(modelProvider.name);
|
||||
if (value) {
|
||||
modelProvider.value = value;
|
||||
}
|
||||
} else {
|
||||
modelProvider.classList.add("hidden");
|
||||
modelSelect.classList.remove("hidden");
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ class CopyButtonPlugin {
|
|||
el,
|
||||
text
|
||||
}) {
|
||||
if (el.classList.contains("language-plaintext")) {
|
||||
if (el.parentElement.tagName != "PRE") {
|
||||
return;
|
||||
}
|
||||
let button = Object.assign(document.createElement("button"), {
|
||||
|
|
|
|||
|
|
@ -64,7 +64,6 @@ class Api:
|
|||
"parent": getattr(provider, "parent", None),
|
||||
"image": getattr(provider, "image_models", None) is not None,
|
||||
"vision": getattr(provider, "default_vision_model", None) is not None,
|
||||
"webdriver": "webdriver" in provider.get_parameters(),
|
||||
"auth": provider.needs_auth,
|
||||
} for provider in __providers__ if provider.working]
|
||||
|
||||
|
|
|
|||
|
|
@ -6,4 +6,6 @@ def create_app() -> Flask:
|
|||
template_folder = os.path.join(sys._MEIPASS, "client")
|
||||
else:
|
||||
template_folder = "../client"
|
||||
return Flask(__name__, template_folder=template_folder, static_folder=f"{template_folder}/static")
|
||||
app = Flask(__name__, template_folder=template_folder, static_folder=f"{template_folder}/static")
|
||||
app.config["TEMPLATES_AUTO_RELOAD"] = True # Enable auto reload in debug mode
|
||||
return app
|
||||
|
|
@ -14,9 +14,12 @@ from werkzeug.utils import secure_filename
|
|||
from ...image import is_allowed_extension, to_image
|
||||
from ...client.service import convert_to_provider
|
||||
from ...providers.asyncio import to_sync_generator
|
||||
from ...client.helper import filter_markdown
|
||||
from ...tools.files import supports_filename, get_streaming, get_bucket_dir, get_buckets
|
||||
from ...tools.run_tools import iter_run_tools
|
||||
from ...errors import ProviderNotFoundError
|
||||
from ...cookies import get_cookies_dir
|
||||
from ... import ChatCompletion
|
||||
from .api import Api
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -101,6 +104,44 @@ class Backend_Api(Api):
|
|||
}
|
||||
}
|
||||
|
||||
@app.route('/backend-api/v2/create', methods=['GET', 'POST'])
|
||||
def create():
|
||||
try:
|
||||
tool_calls = [{
|
||||
"function": {
|
||||
"name": "bucket_tool"
|
||||
},
|
||||
"type": "function"
|
||||
}]
|
||||
web_search = request.args.get("web_search")
|
||||
if web_search:
|
||||
tool_calls.append({
|
||||
"function": {
|
||||
"name": "search_tool",
|
||||
"arguments": {"query": web_search, "instructions": ""} if web_search != "true" else {}
|
||||
},
|
||||
"type": "function"
|
||||
})
|
||||
do_filter_markdown = request.args.get("filter_markdown")
|
||||
response = iter_run_tools(
|
||||
ChatCompletion.create,
|
||||
model=request.args.get("model"),
|
||||
messages=[{"role": "user", "content": request.args.get("prompt")}],
|
||||
provider=request.args.get("provider", None),
|
||||
stream=not do_filter_markdown,
|
||||
ignore_stream=not request.args.get("stream"),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
if do_filter_markdown:
|
||||
return Response(filter_markdown(response, do_filter_markdown), mimetype='text/plain')
|
||||
def cast_str():
|
||||
for chunk in response:
|
||||
yield str(chunk)
|
||||
return Response(cast_str(), mimetype='text/plain')
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 500
|
||||
|
||||
@app.route('/backend-api/v2/buckets', methods=['GET'])
|
||||
def list_buckets():
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ class Website:
|
|||
self.app = app
|
||||
self.routes = {
|
||||
'/': {
|
||||
'function': redirect_home,
|
||||
'function': self._home,
|
||||
'methods': ['GET', 'POST']
|
||||
},
|
||||
'/chat/': {
|
||||
|
|
@ -41,3 +41,6 @@ class Website:
|
|||
|
||||
def _index(self):
|
||||
return render_template('index.html', chat_id=str(uuid.uuid4()))
|
||||
|
||||
def _home(self):
|
||||
return render_template('home.html')
|
||||
|
|
@ -98,26 +98,26 @@ gpt_35_turbo = Model(
|
|||
gpt_4 = Model(
|
||||
name = 'gpt-4',
|
||||
base_provider = 'OpenAI',
|
||||
best_provider = IterListProvider([DDG, Blackbox, ChatGptEs, PollinationsAI, Copilot, OpenaiChat, Liaobots, Airforce, Mhystical])
|
||||
best_provider = IterListProvider([DDG, Blackbox, ChatGptEs, PollinationsAI, Copilot, OpenaiChat, Liaobots, Mhystical])
|
||||
)
|
||||
|
||||
gpt_4_turbo = Model(
|
||||
name = 'gpt-4-turbo',
|
||||
base_provider = 'OpenAI',
|
||||
best_provider = Airforce
|
||||
best_provider = None
|
||||
)
|
||||
|
||||
# gpt-4o
|
||||
gpt_4o = Model(
|
||||
name = 'gpt-4o',
|
||||
base_provider = 'OpenAI',
|
||||
best_provider = IterListProvider([Blackbox, ChatGptEs, PollinationsAI, DarkAI, ChatGpt, Airforce, Liaobots, OpenaiChat])
|
||||
best_provider = IterListProvider([Blackbox, ChatGptEs, PollinationsAI, DarkAI, ChatGpt, Liaobots, OpenaiChat])
|
||||
)
|
||||
|
||||
gpt_4o_mini = Model(
|
||||
name = 'gpt-4o-mini',
|
||||
base_provider = 'OpenAI',
|
||||
best_provider = IterListProvider([DDG, ChatGptEs, Pizzagpt, ChatGpt, Airforce, RubiksAI, Liaobots, OpenaiChat])
|
||||
best_provider = IterListProvider([DDG, ChatGptEs, Pizzagpt, ChatGpt, RubiksAI, Liaobots, OpenaiChat])
|
||||
)
|
||||
|
||||
# o1
|
||||
|
|
@ -136,7 +136,7 @@ o1_preview = Model(
|
|||
o1_mini = Model(
|
||||
name = 'o1-mini',
|
||||
base_provider = 'OpenAI',
|
||||
best_provider = IterListProvider([Liaobots, Airforce])
|
||||
best_provider = IterListProvider([Liaobots])
|
||||
)
|
||||
|
||||
### GigaChat ###
|
||||
|
|
|
|||
|
|
@ -5,8 +5,10 @@ import asyncio
|
|||
from asyncio import AbstractEventLoop
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from abc import abstractmethod
|
||||
import json
|
||||
from inspect import signature, Parameter
|
||||
from typing import Optional, _GenericAlias
|
||||
from typing import Optional, Awaitable, _GenericAlias
|
||||
from pathlib import Path
|
||||
try:
|
||||
from types import NoneType
|
||||
except ImportError:
|
||||
|
|
@ -15,9 +17,10 @@ except ImportError:
|
|||
from ..typing import CreateResult, AsyncResult, Messages
|
||||
from .types import BaseProvider
|
||||
from .asyncio import get_running_loop, to_sync_generator
|
||||
from .response import BaseConversation
|
||||
from .response import BaseConversation, AuthResult
|
||||
from .helper import concat_chunks, async_concat_chunks
|
||||
from ..errors import ModelNotSupportedError
|
||||
from ..cookies import get_cookies_dir
|
||||
from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError
|
||||
from .. import debug
|
||||
|
||||
SAFE_PARAMETERS = [
|
||||
|
|
@ -33,15 +36,14 @@ SAFE_PARAMETERS = [
|
|||
]
|
||||
|
||||
BASIC_PARAMETERS = {
|
||||
"provider": None,
|
||||
"model": "",
|
||||
"messages": [],
|
||||
"provider": None,
|
||||
"stream": False,
|
||||
"timeout": 0,
|
||||
"response_format": None,
|
||||
"max_tokens": None,
|
||||
"stop": None,
|
||||
"web_search": False,
|
||||
}
|
||||
|
||||
PARAMETER_EXAMPLES = {
|
||||
|
|
@ -109,12 +111,8 @@ class AbstractProvider(BaseProvider):
|
|||
).parameters.items() if name in SAFE_PARAMETERS
|
||||
and (name != "stream" or cls.supports_stream)}
|
||||
if as_json:
|
||||
def get_type_as_var(annotation: type, key: str):
|
||||
if key == "model":
|
||||
return getattr(cls, "default_model", "")
|
||||
elif key == "stream":
|
||||
return cls.supports_stream
|
||||
elif key in PARAMETER_EXAMPLES:
|
||||
def get_type_as_var(annotation: type, key: str, default):
|
||||
if key in PARAMETER_EXAMPLES:
|
||||
if key == "messages" and not cls.supports_system_message:
|
||||
return [PARAMETER_EXAMPLES[key][-1]]
|
||||
return PARAMETER_EXAMPLES[key]
|
||||
|
|
@ -137,18 +135,21 @@ class AbstractProvider(BaseProvider):
|
|||
return {}
|
||||
elif annotation is None:
|
||||
return None
|
||||
elif isinstance(annotation, _GenericAlias) and annotation.__origin__ is Optional:
|
||||
return get_type_as_var(annotation.__args__[0])
|
||||
elif annotation == "str" or annotation == "list[str]":
|
||||
return default
|
||||
elif isinstance(annotation, _GenericAlias):
|
||||
if annotation.__origin__ is Optional:
|
||||
return get_type_as_var(annotation.__args__[0])
|
||||
else:
|
||||
return str(annotation)
|
||||
return { name: (
|
||||
param.default
|
||||
if isinstance(param, Parameter) and param.default is not Parameter.empty and param.default is not None
|
||||
else get_type_as_var(param.annotation if isinstance(param, Parameter) else type(param), name)
|
||||
else get_type_as_var(param.annotation, name, param.default) if isinstance(param, Parameter) else param
|
||||
) for name, param in {
|
||||
**BASIC_PARAMETERS,
|
||||
**{"provider": cls.__name__},
|
||||
**params
|
||||
**params,
|
||||
**{"provider": cls.__name__, "stream": cls.supports_stream, "model": getattr(cls, "default_model", "")},
|
||||
}.items()}
|
||||
return params
|
||||
|
||||
|
|
@ -310,6 +311,12 @@ class AsyncGeneratorProvider(AsyncProvider):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
create_authed = create_completion
|
||||
|
||||
create_authed_async = create_async
|
||||
|
||||
create_async_authed = create_async_generator
|
||||
|
||||
class ProviderModelMixin:
|
||||
default_model: str = None
|
||||
models: list[str] = []
|
||||
|
|
@ -335,3 +342,112 @@ class ProviderModelMixin:
|
|||
cls.last_model = model
|
||||
debug.last_model = model
|
||||
return model
|
||||
|
||||
class RaiseErrorMixin():
|
||||
|
||||
@staticmethod
|
||||
def raise_error(data: dict):
|
||||
if "error_message" in data:
|
||||
raise ResponseError(data["error_message"])
|
||||
elif "error" in data:
|
||||
if "code" in data["error"]:
|
||||
raise ResponseError(f'Error {data["error"]["code"]}: {data["error"]["message"]}')
|
||||
elif "message" in data["error"]:
|
||||
raise ResponseError(data["error"]["message"])
|
||||
else:
|
||||
raise ResponseError(data["error"])
|
||||
|
||||
class AuthedMixin():
|
||||
|
||||
@classmethod
|
||||
def on_auth(cls, **kwargs) -> Optional[AuthResult]:
|
||||
if "api_key" not in kwargs:
|
||||
raise MissingAuthError(f"API key is required for {cls.__name__}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def create_authed(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
**kwargs
|
||||
) -> CreateResult:
|
||||
auth_result = {}
|
||||
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
|
||||
if cache_file.exists():
|
||||
with cache_file.open("r") as f:
|
||||
auth_result = json.load(f)
|
||||
return cls.create_completion(model, messages, **kwargs, **auth_result)
|
||||
auth_result = cls.on_auth(**kwargs)
|
||||
try:
|
||||
return cls.create_completion(model, messages, **kwargs)
|
||||
finally:
|
||||
cache_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
cache_file.write_text(json.dumps(auth_result.get_dict()))
|
||||
|
||||
class AsyncAuthedMixin(AuthedMixin):
|
||||
@classmethod
|
||||
async def create_async_authed(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
**kwargs
|
||||
) -> str:
|
||||
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
|
||||
if cache_file.exists():
|
||||
auth_result = {}
|
||||
with cache_file.open("r") as f:
|
||||
auth_result = json.load(f)
|
||||
return cls.create_completion(model, messages, **kwargs, **auth_result)
|
||||
auth_result = cls.on_auth(**kwargs)
|
||||
try:
|
||||
return await cls.create_async(model, messages, **kwargs)
|
||||
finally:
|
||||
if auth_result is not None:
|
||||
cache_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
cache_file.write_text(json.dumps(auth_result.get_dict()))
|
||||
|
||||
class AsyncAuthedGeneratorMixin(AsyncAuthedMixin):
|
||||
|
||||
@classmethod
|
||||
async def create_async_authed(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
**kwargs
|
||||
) -> str:
|
||||
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
|
||||
if cache_file.exists():
|
||||
auth_result = {}
|
||||
with cache_file.open("r") as f:
|
||||
auth_result = json.load(f)
|
||||
return cls.create_completion(model, messages, **kwargs, **auth_result)
|
||||
auth_result = cls.on_auth(**kwargs)
|
||||
try:
|
||||
return await async_concat_chunks(cls.create_async_generator(model, messages, stream=False, **kwargs))
|
||||
finally:
|
||||
if auth_result is not None:
|
||||
cache_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
cache_file.write_text(json.dumps(auth_result.get_dict()))
|
||||
|
||||
@classmethod
|
||||
def create_async_authed_generator(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
stream: bool = True,
|
||||
**kwargs
|
||||
) -> Awaitable[AsyncResult]:
|
||||
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
|
||||
if cache_file.exists():
|
||||
auth_result = {}
|
||||
with cache_file.open("r") as f:
|
||||
auth_result = json.load(f)
|
||||
return cls.create_completion(model, messages, **kwargs, **auth_result)
|
||||
auth_result = cls.on_auth(**kwargs)
|
||||
try:
|
||||
return cls.create_async_generator(model, messages, stream=stream, **kwargs)
|
||||
finally:
|
||||
if auth_result is not None:
|
||||
cache_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
cache_file.write_text(json.dumps(auth_result.get_dict()))
|
||||
|
|
|
|||
|
|
@ -105,6 +105,10 @@ class Usage(ResponseType, JsonMixin):
|
|||
def __str__(self) -> str:
|
||||
return ""
|
||||
|
||||
class AuthResult(JsonMixin):
|
||||
def __str__(self) -> str:
|
||||
return ""
|
||||
|
||||
class TitleGeneration(ResponseType):
|
||||
def __init__(self, title: str) -> None:
|
||||
self.title = title
|
||||
|
|
@ -182,4 +186,5 @@ class ImagePreview(ImageResponse):
|
|||
return super().__str__()
|
||||
|
||||
class Parameters(ResponseType, JsonMixin):
|
||||
pass
|
||||
def __str__(self):
|
||||
return ""
|
||||
|
|
@ -40,7 +40,7 @@ async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls:
|
|||
)
|
||||
elif tool.get("function", {}).get("name") == "continue":
|
||||
last_line = messages[-1]["content"].strip().splitlines()[-1]
|
||||
content = f"Continue after this line.\n{last_line}"
|
||||
content = f"Carry on from this point:\n{last_line}"
|
||||
messages.append({"role": "user", "content": content})
|
||||
elif tool.get("function", {}).get("name") == "bucket_tool":
|
||||
def on_bucket(match):
|
||||
|
|
@ -90,7 +90,7 @@ def iter_run_tools(
|
|||
elif tool.get("function", {}).get("name") == "continue_tool":
|
||||
if provider not in ("OpenaiAccount", "HuggingFace"):
|
||||
last_line = messages[-1]["content"].strip().splitlines()[-1]
|
||||
content = f"Continue after this line:\n{last_line}"
|
||||
content = f"Carry on from this point:\n{last_line}"
|
||||
messages.append({"role": "user", "content": content})
|
||||
else:
|
||||
# Enable provider native continue
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import json
|
|||
import hashlib
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
from datetime import datetime
|
||||
import datetime
|
||||
import asyncio
|
||||
|
||||
|
|
@ -65,7 +66,7 @@ class SearchResultEntry():
|
|||
def set_text(self, text: str):
|
||||
self.text = text
|
||||
|
||||
def scrape_text(html: str, max_words: int = None) -> Iterator[str]:
|
||||
def scrape_text(html: str, max_words: int = None, add_source=True) -> Iterator[str]:
|
||||
source = BeautifulSoup(html, "html.parser")
|
||||
soup = source
|
||||
for selector in [
|
||||
|
|
@ -87,7 +88,7 @@ def scrape_text(html: str, max_words: int = None) -> Iterator[str]:
|
|||
if select:
|
||||
select.extract()
|
||||
|
||||
for paragraph in soup.select("p, table, ul, h1, h2, h3, h4, h5, h6"):
|
||||
for paragraph in soup.select("p, table:not(:has(p)), ul:not(:has(p)), h1, h2, h3, h4, h5, h6"):
|
||||
for line in paragraph.text.splitlines():
|
||||
words = [word for word in line.replace("\t", " ").split(" ") if word]
|
||||
count = len(words)
|
||||
|
|
@ -99,24 +100,25 @@ def scrape_text(html: str, max_words: int = None) -> Iterator[str]:
|
|||
break
|
||||
yield " ".join(words) + "\n"
|
||||
|
||||
canonical_link = source.find("link", rel="canonical")
|
||||
if canonical_link and "href" in canonical_link.attrs:
|
||||
link = canonical_link["href"]
|
||||
domain = urlparse(link).netloc
|
||||
yield f"\nSource: [{domain}]({link})"
|
||||
if add_source:
|
||||
canonical_link = source.find("link", rel="canonical")
|
||||
if canonical_link and "href" in canonical_link.attrs:
|
||||
link = canonical_link["href"]
|
||||
domain = urlparse(link).netloc
|
||||
yield f"\nSource: [{domain}]({link})"
|
||||
|
||||
async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = None) -> str:
|
||||
async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = None, add_source: bool = False) -> str:
|
||||
try:
|
||||
bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "fetch_and_scrape"
|
||||
bucket_dir.mkdir(parents=True, exist_ok=True)
|
||||
md5_hash = hashlib.md5(url.encode()).hexdigest()
|
||||
cache_file = bucket_dir / f"{url.split('/')[3]}.{datetime.date.today()}.{md5_hash}.txt"
|
||||
cache_file = bucket_dir / f"{url.split('?')[0].split('//')[1].replace('/', '+')[:16]}.{datetime.date.today()}.{md5_hash}.txt"
|
||||
if cache_file.exists():
|
||||
return cache_file.read_text()
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
html = await response.text()
|
||||
text = "".join(scrape_text(html, max_words))
|
||||
text = "".join(scrape_text(html, max_words, add_source))
|
||||
with open(cache_file, "w") as f:
|
||||
f.write(text)
|
||||
return text
|
||||
|
|
@ -136,6 +138,8 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backen
|
|||
max_results=max_results,
|
||||
backend=backend,
|
||||
):
|
||||
if ".google.com" in result["href"]:
|
||||
continue
|
||||
results.append(SearchResultEntry(
|
||||
result["title"],
|
||||
result["href"],
|
||||
|
|
@ -146,7 +150,7 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backen
|
|||
requests = []
|
||||
async with ClientSession(timeout=ClientTimeout(timeout)) as session:
|
||||
for entry in results:
|
||||
requests.append(fetch_and_scrape(session, entry.url, int(max_words / (max_results - 1))))
|
||||
requests.append(fetch_and_scrape(session, entry.url, int(max_words / (max_results - 1)), False))
|
||||
texts = await asyncio.gather(*requests)
|
||||
|
||||
formatted_results = []
|
||||
|
|
@ -173,7 +177,7 @@ async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_
|
|||
query = spacy_get_keywords(prompt)
|
||||
json_bytes = json.dumps({"query": query, **kwargs}, sort_keys=True).encode()
|
||||
md5_hash = hashlib.md5(json_bytes).hexdigest()
|
||||
bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "web_search"
|
||||
bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / f"web_search:{datetime.date.today()}"
|
||||
bucket_dir.mkdir(parents=True, exist_ok=True)
|
||||
cache_file = bucket_dir / f"{query[:20]}.{md5_hash}.txt"
|
||||
if cache_file.exists():
|
||||
|
|
@ -192,7 +196,9 @@ Instruction: {instructions}
|
|||
User request:
|
||||
{prompt}
|
||||
"""
|
||||
debug.log(f"Web search: '{query.strip()[:50]}...' {len(search_results.results)} Results {search_results.used_words} Words")
|
||||
debug.log(f"Web search: '{query.strip()[:50]}...'")
|
||||
if isinstance(search_results, SearchResults):
|
||||
debug.log(f"with {len(search_results.results)} Results {search_results.used_words} Words")
|
||||
return new_prompt
|
||||
|
||||
def get_search_message(prompt: str, raise_search_exceptions=False, **kwargs) -> str:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue