mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
Improve tools support in OpenaiTemplate and GeminiPro

Update models in DDG, PerplexityLabs and Gemini. Fix issues with curl_cffi in new versions.

This commit is contained in:
parent c3ed6d0f8f
commit e53483d85b

33 changed files with 300 additions and 172 deletions
.gitignore (vendored): 1 addition

@@ -66,3 +66,4 @@ bench.py
 to-reverse.txt
 g4f/Provider/OpenaiChat2.py
 generated_images/
+projects/windows/
@@ -1,5 +1,9 @@
 import unittest

+import g4f.debug
+
+g4f.debug.version_check = False
+
 from .asyncio import *
 from .backend import *
 from .main import *
@@ -6,8 +6,12 @@ from g4f.errors import VersionNotFoundError
 DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]

 class TestGetLastProvider(unittest.TestCase):

     def test_get_latest_version(self):
         current_version = g4f.version.utils.current_version
         if current_version is not None:
             self.assertIsInstance(g4f.version.utils.current_version, str)
-        self.assertIsInstance(g4f.version.utils.latest_version, str)
+        try:
+            self.assertIsInstance(g4f.version.utils.latest_version, str)
+        except VersionNotFoundError:
+            pass
@@ -42,7 +42,7 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin):
         "gpt-4": "gpt-4o-mini",
         "claude-3-haiku": "claude-3-haiku-20240307",
         "llama-3.3-70b": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
-        "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
+        "mixtral-8x7b": "mistralai/Mistral-Small-24B-Instruct-2501",
     }

     last_request_time = 0
@@ -5,7 +5,8 @@ import json

 from ..typing import AsyncResult, Messages
 from ..requests import StreamSession, raise_for_status
-from ..providers.response import FinishReason
+from ..errors import ResponseError
+from ..providers.response import FinishReason, Sources
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin

 API_URL = "https://www.perplexity.ai/socket.io/"

@@ -15,10 +16,11 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://labs.perplexity.ai"
     working = True

-    default_model = "sonar-pro"
+    default_model = "r1-1776"
     models = [
-        "sonar",
         default_model,
+        "sonar-pro",
+        "sonar",
         "sonar-reasoning",
         "sonar-reasoning-pro",
     ]

@@ -32,19 +34,10 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
         **kwargs
     ) -> AsyncResult:
         headers = {
-            "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:121.0) Gecko/20100101 Firefox/121.0",
-            "Accept": "*/*",
-            "Accept-Language": "de,en-US;q=0.7,en;q=0.3",
-            "Accept-Encoding": "gzip, deflate, br",
             "Origin": cls.url,
-            "Connection": "keep-alive",
             "Referer": f"{cls.url}/",
-            "Sec-Fetch-Dest": "empty",
-            "Sec-Fetch-Mode": "cors",
-            "Sec-Fetch-Site": "same-site",
-            "TE": "trailers",
         }
-        async with StreamSession(headers=headers, proxies={"all": proxy}) as session:
+        async with StreamSession(headers=headers, proxy=proxy, impersonate="chrome") as session:
             t = format(random.getrandbits(32), "08x")
             async with session.get(
                 f"{API_URL}?EIO=4&transport=polling&t={t}"

@@ -60,17 +53,22 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
             ) as response:
                 await raise_for_status(response)
                 assert await response.text() == "OK"
+            async with session.get(
+                f"{API_URL}?EIO=4&transport=polling&t={t}&sid={sid}",
+                data=post_data
+            ) as response:
+                await raise_for_status(response)
+                assert (await response.text()).startswith("40")
             async with session.ws_connect(f"{WS_URL}?EIO=4&transport=websocket&sid={sid}", autoping=False) as ws:
                 await ws.send_str("2probe")
                 assert(await ws.receive_str() == "3probe")
                 await ws.send_str("5")
-                assert(await ws.receive_str())
                 assert(await ws.receive_str() == "6")
                 message_data = {
-                    "version": "2.16",
+                    "version": "2.18",
                     "source": "default",
                     "model": model,
-                    "messages": messages
+                    "messages": messages,
                 }
                 await ws.send_str("42" + json.dumps(["perplexity_labs", message_data]))
                 last_message = 0

@@ -82,12 +80,15 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
                     await ws.send_str("3")
                     continue
                 try:
+                    if last_message == 0 and model == cls.default_model:
+                        yield "<think>"
                     data = json.loads(message[2:])[1]
                     yield data["output"][last_message:]
                     last_message = len(data["output"])
                     if data["final"]:
+                        if data["citations"]:
+                            yield Sources(data["citations"])
                         yield FinishReason("stop")
                         break
                 except Exception as e:
-                    print(f"Error processing message: {message} - {e}")
-                    raise RuntimeError(f"Message: {message}") from e
+                    raise ResponseError(f"Message: {message}") from e
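With the changes above, the PerplexityLabs stream yields structured objects alongside text: a Sources chunk when the final message carries citations, and a FinishReason at the end. A minimal consumer sketch, assuming the standard g4f provider interface and that the model/provider names from this diff are available in your build:

# Sketch: read the PerplexityLabs stream and separate text deltas,
# citation sources and the finish marker. Assumes g4f is installed.
import asyncio
from g4f.Provider import PerplexityLabs
from g4f.providers.response import FinishReason, Sources

async def main():
    async for chunk in PerplexityLabs.create_async_generator(
        model="r1-1776",
        messages=[{"role": "user", "content": "Hello"}],
    ):
        if isinstance(chunk, Sources):
            print("\ncitations:", chunk)   # emitted when data["final"] has citations
        elif isinstance(chunk, FinishReason):
            print("\ndone")
        else:
            print(chunk, end="")           # text; r1-1776 output starts with "<think>"

asyncio.run(main())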
@@ -122,9 +122,9 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         except ModelNotFoundError:
             if model not in cls.image_models:
                 raise

         if not cache and seed is None:
-            seed = random.randint(0, 10000)
+            seed = random.randint(1000, 999999)

         if model in cls.image_models:
             async for chunk in cls._generate_image(

@@ -182,9 +182,10 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         }
         params = {k: v for k, v in params.items() if v is not None}
         query = "&".join(f"{k}={quote_plus(v)}" for k, v in params.items())
-        url = f"{cls.image_api_endpoint}prompt/{quote_plus(prompt)}?{query}"
+        prefix = f"{model}_{seed}" if seed is not None else model
+        url = f"{cls.image_api_endpoint}prompt/{prefix}_{quote_plus(prompt)}?{query}"
         yield ImagePreview(url, prompt)

         async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
             async with session.get(url, allow_redirects=True) as response:
                 await raise_for_status(response)
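The image URL now carries a model/seed prefix ahead of the quoted prompt, so otherwise identical prompts produce distinct URLs per model and per seed. A standalone sketch of that construction; the endpoint value below is a placeholder for cls.image_api_endpoint, not the real one:

# Sketch of the new URL layout: "{model}_{seed}_" is prepended to the prompt.
from urllib.parse import quote_plus

image_api_endpoint = "https://example.invalid/"   # placeholder for cls.image_api_endpoint
model, seed, prompt = "flux", 123456, "a lighthouse at dusk"
params = {"width": "512", "height": "512"}

query = "&".join(f"{k}={quote_plus(v)}" for k, v in params.items())
prefix = f"{model}_{seed}" if seed is not None else model
url = f"{image_api_endpoint}prompt/{prefix}_{quote_plus(prompt)}?{query}"
print(url)  # .../prompt/flux_123456_a+lighthouse+at+dusk?width=512&height=512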
@@ -39,14 +39,15 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
     default_model = default_model
     model_aliases = model_aliases
     image_models = image_models
+    text_models = fallback_models

     @classmethod
     def get_models(cls):
         if not cls.models:
             try:
                 text = requests.get(cls.url).text
-                text = re.sub(r',parameters:{[^}]+?}', '', text)
                 text = re.search(r'models:(\[.+?\]),oldModels:', text).group(1)
+                text = re.sub(r',parameters:{[^}]+?}', '', text)
                 text = text.replace('void 0', 'null')
                 def add_quotation_mark(match):
                     return f'{match.group(1)}"{match.group(2)}":'

@@ -56,7 +57,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
                 cls.models = cls.text_models + cls.image_models
                 cls.vision_models = [model["id"] for model in models if model["multimodal"]]
             except Exception as e:
                 debug.error(f"{cls.__name__}: Error reading models: {type(e).__name__}: {e}")
-                debug.log(f"HuggingChat: Error reading models: {type(e).__name__}: {e}")
                 cls.models = [*fallback_models]
         return cls.models
@@ -10,8 +10,8 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_p
 from ...errors import ModelNotFoundError, ModelNotSupportedError, ResponseError
 from ...requests import StreamSession, raise_for_status
 from ...providers.response import FinishReason, ImageResponse
-from ..helper import format_image_prompt
-from .models import default_model, default_image_model, model_aliases, fallback_models
+from ..helper import format_image_prompt, get_last_user_message
+from .models import default_model, default_image_model, model_aliases, fallback_models, image_models
 from ... import debug

 class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):

@@ -22,26 +22,26 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
     default_model = default_model
     default_image_model = default_image_model
     model_aliases = model_aliases
+    image_models = image_models

     @classmethod
     def get_models(cls) -> list[str]:
         if not cls.models:
             models = fallback_models.copy()
             url = "https://huggingface.co/api/models?inference=warm&pipeline_tag=text-generation"
             response = requests.get(url)
-            response.raise_for_status()
+            if response.ok:
                 extra_models = [model["id"] for model in response.json()]
                 extra_models.sort()
                 models.extend([model for model in extra_models if model not in models])
-            if not cls.image_models:
-                url = "https://huggingface.co/api/models?pipeline_tag=text-to-image"
-                response = requests.get(url)
-                response.raise_for_status()
+            url = "https://huggingface.co/api/models?pipeline_tag=text-to-image"
+            response = requests.get(url)
+            if response.ok:
                 cls.image_models = [model["id"] for model in response.json() if model.get("trendingScore", 0) >= 20]
                 cls.image_models.sort()
-                models.extend([model for model in cls.image_models if model not in models])
-            cls.models = models
-        return cls.models
+            models.extend([model for model in cls.image_models if model not in models])
+            cls.models = models
+        return cls.models

@@ -57,6 +57,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
         prompt: str = None,
         action: str = None,
         extra_data: dict = {},
+        seed: int = None,
         **kwargs
     ) -> AsyncResult:
         try:

@@ -104,7 +105,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
             if pipeline_tag == "text-to-image":
                 stream = False
                 inputs = format_image_prompt(messages, prompt)
-                payload = {"inputs": inputs, "parameters": {"seed": random.randint(0, 2**32), **extra_data}}
+                payload = {"inputs": inputs, "parameters": {"seed": random.randint(0, 2**32) if seed is None else seed, **extra_data}}
             elif pipeline_tag in ("text-generation", "image-text-to-text"):
                 model_type = None
                 if "config" in model_data and "model_type" in model_data["config"]:

@@ -116,11 +117,13 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
                     if len(messages) > 6:
                         messages = messages[:3] + messages[-3:]
                     else:
-                        messages = [m for m in messages if m["role"] == "system"] + [messages[-1]]
+                        messages = [m for m in messages if m["role"] == "system"] + [get_last_user_message(messages)]
                     inputs = get_inputs(messages, model_data, model_type, do_continue)
                     debug.log(f"New len: {len(inputs)}")
                 if model_type == "gpt2" and max_tokens >= 1024:
                     params["max_new_tokens"] = 512
+                if seed is not None:
+                    params["seed"] = seed
                 payload = {"inputs": inputs, "parameters": params, "stream": stream}
             else:
                 raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}")
@@ -48,7 +48,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
         except Exception as e:
             if is_started:
                 raise e
-            debug.log(f"Inference failed: {e.__class__.__name__}: {e}")
+            debug.error(f"{cls.__name__} {type(e).__name__}; {e}")
         if not cls.image_models:
             cls.get_models()
         if model in cls.image_models:
@@ -3,7 +3,7 @@ from __future__ import annotations
 import json
 import uuid
 import re
-import time
+import random
 from datetime import datetime, timezone, timedelta
 import urllib.parse

@@ -88,7 +88,7 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
         prompt = format_prompt(messages) if prompt is None and conversation is None else prompt
         prompt = format_image_prompt(messages, prompt)
         if seed is None:
-            seed = int(time.time())
+            seed = random.randint(1000, 999999)

         session_hash = generate_session_hash() if conversation is None else getattr(conversation, "session_hash")
         async with StreamSession(proxy=proxy, impersonate="chrome") as session:
@@ -32,9 +32,7 @@ class CopilotAccount(AsyncAuthedProvider, Copilot):
         except NoValidHarFileError as h:
             debug.log(f"Copilot: {h}")
             if has_nodriver:
-                login_url = os.environ.get("G4F_LOGIN_URL")
-                if login_url:
-                    yield RequestLogin(cls.label, login_url)
+                yield RequestLogin(cls.label, os.environ.get("G4F_LOGIN_URL", ""))
                 Copilot._access_token, Copilot._cookies = await get_access_token_and_cookies(cls.url, proxy)
             else:
                 raise h
@@ -65,7 +65,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
     default_image_model = default_model
     default_vision_model = default_model
     image_models = [default_image_model]
-    models = [default_model, "gemini-1.5-flash", "gemini-1.5-pro"]
+    models = [default_model, "gemini-2.0"]

     synthesize_content_type = "audio/vnd.wav"

@@ -179,7 +179,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
                     yield Conversation(response_part[1][0], response_part[1][1], response_part[4][0][0])
                 content = response_part[4][0][1][0]
             except (ValueError, KeyError, TypeError, IndexError) as e:
-                debug.log(f"{cls.__name__}:{e.__class__.__name__}:{e}")
+                debug.error(f"{cls.__name__} {type(e).__name__}: {e}")
                 continue
             match = re.search(r'\[Imagen of (.*?)\]', content)
             if match:
@@ -51,7 +51,9 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
             ]
             cls.models.sort()
         except Exception as e:
-            debug.log(e)
+            debug.error(e)
+            if api_key is not None:
+                raise MissingAuthError("Invalid API key")
             return cls.fallback_models
         return cls.models

@@ -111,8 +113,18 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
                 "topK": kwargs.get("top_k"),
             },
             "tools": [{
-                "functionDeclarations": tools
-            }] if tools else None
+                "function_declarations": [{
+                    "name": tool["function"]["name"],
+                    "description": tool["function"]["description"],
+                    "parameters": {
+                        "type": "object",
+                        "properties": {key: {
+                            "type": value["type"],
+                            "description": value["title"]
+                        } for key, value in tool["function"]["parameters"]["properties"].items()}
+                    },
+                } for tool in tools]
+            }] if tools else None
         }
         system_prompt = "\n".join(
             message["content"]
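The hunk above rewrites OpenAI-style tool definitions into Gemini's function_declarations schema instead of passing them through unchanged. A standalone sketch of that mapping, mirroring the dict comprehension in the diff; the weather tool is a made-up example, not part of the commit:

# Sketch: convert OpenAI-style "tools" into the Gemini "function_declarations"
# block built in the hunk above. The weather tool below is invented for illustration.
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "title": "City name"},
            },
        },
    },
}]

data_tools = [{
    "function_declarations": [{
        "name": tool["function"]["name"],
        "description": tool["function"]["description"],
        "parameters": {
            "type": "object",
            "properties": {key: {
                "type": value["type"],
                "description": value["title"]
            } for key, value in tool["function"]["parameters"]["properties"].items()}
        },
    } for tool in tools]
}] if tools else None

print(data_tools)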
@@ -322,8 +322,8 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
                 try:
                     image_requests = await cls.upload_images(session, auth_result, images) if images else None
                 except Exception as e:
-                    debug.log("OpenaiChat: Upload image failed")
-                    debug.log(f"{e.__class__.__name__}: {e}")
+                    debug.error("OpenaiChat: Upload image failed")
+                    debug.error(e)
             model = cls.get_model(model)
             if conversation is None:
                 conversation = Conversation(conversation_id, str(uuid.uuid4()), getattr(auth_result, "cookies", {}).get("oai-did"))

@@ -360,12 +360,14 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
             # if auth_result.arkose_token is None:
             #     raise MissingAuthError("No arkose token found in .har file")
             if "proofofwork" in chat_requirements:
-                if getattr(auth_result, "proof_token", None) is None:
-                    auth_result.proof_token = get_config(auth_result.headers.get("user-agent"))
+                user_agent = getattr(auth_result, "headers", {}).get("user-agent")
+                proof_token = getattr(auth_result, "proof_token", None)
+                if proof_token is None:
+                    auth_result.proof_token = get_config(user_agent)
                 proofofwork = generate_proof_token(
                     **chat_requirements["proofofwork"],
-                    user_agent=getattr(auth_result, "headers", {}).get("user-agent"),
-                    proof_token=getattr(auth_result, "proof_token", None)
+                    user_agent=user_agent,
+                    proof_token=proof_token
                 )
             [debug.log(text) for text in (
                 #f"Arkose: {'False' if not need_arkose else auth_result.arkose_token[:12]+'...'}",

@@ -425,8 +427,8 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
             ) as response:
                 cls._update_request_args(auth_result, session)
                 if response.status == 403:
-                    auth_result.proof_token = None
                     cls.request_config.proof_token = None
+                    raise MissingAuthError("Access token is not valid")
                 await raise_for_status(response)
                 buffer = u""
                 async for line in response.iter_lines():

@@ -469,14 +471,6 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
                     await asyncio.sleep(5)
                 else:
                     break
-            yield Parameters(**{
-                "action": "continue" if conversation.finish_reason == "max_tokens" else "variant",
-                "conversation": conversation.get_dict(),
-                "proof_token": cls.request_config.proof_token,
-                "cookies": cls._cookies,
-                "headers": cls._headers,
-                "web_search": web_search,
-            })
             yield FinishReason(conversation.finish_reason)

     @classmethod
@@ -42,7 +42,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
             if cls.sort_models:
                 cls.models.sort()
         except Exception as e:
-            debug.log(e)
+            debug.error(e)
             return cls.fallback_models
         return cls.models

@@ -65,7 +65,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
         prompt: str = None,
         headers: dict = None,
         impersonate: str = None,
-        tools: Optional[list] = None,
+        extra_parameters: list[str] = ["tools", "parallel_tool_calls", "", "reasoning_effort", "logit_bias"],
         extra_data: dict = {},
         **kwargs
     ) -> AsyncResult:

@@ -112,6 +112,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
                     }
                 ]
                 messages[-1] = last_message
+        extra_parameters = {key: kwargs[key] for key in extra_parameters if key in kwargs}
         data = filter_none(
             messages=messages,
             model=model,

@@ -120,7 +121,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
             top_p=top_p,
             stop=stop,
             stream=stream,
-            tools=tools,
+            **extra_parameters,
             **extra_data
         )
         if api_endpoint is None:
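With the change above, tool-related options are no longer fixed keyword arguments; any name listed in extra_parameters is picked out of **kwargs and merged into the request body. A reduced sketch of that filtering; filter_none is a simplified stand-in for g4f's helper, and the whitelist below is trimmed to the readable entries from this diff:

# Sketch of the extra_parameters filtering: only whitelisted keys the caller
# actually passed end up in the request payload.
def filter_none(**kwargs):
    # simplified stand-in for g4f's helper of the same name
    return {k: v for k, v in kwargs.items() if v is not None}

extra_parameters = ["tools", "parallel_tool_calls", "reasoning_effort", "logit_bias"]
kwargs = {"tools": [{"type": "function", "function": {"name": "get_weather"}}], "unrelated": True}

extra_parameters = {key: kwargs[key] for key in extra_parameters if key in kwargs}
data = filter_none(
    messages=[{"role": "user", "content": "Hi"}],
    model="gpt-4o-mini",
    stream=True,
    **extra_parameters,
)
print(data)   # contains "tools" but not "unrelated"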
@@ -588,7 +588,7 @@ class Api:
                         target=target)
                     debug.log(f"Image copied from {source_url}")
                 except Exception as e:
-                    debug.log(f"{type(e).__name__}: Download failed: {source_url}\n{e}")
+                    debug.error(f"Download failed: {source_url}\n{type(e).__name__}: {e}")
                     return RedirectResponse(url=source_url)
         if not os.path.isfile(target):
             return ErrorResponse.from_message("File not found", HTTP_404_NOT_FOUND)
@@ -18,7 +18,7 @@ from ..providers.retry_provider import IterListProvider
 from ..providers.asyncio import to_sync_generator
 from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
 from ..tools.run_tools import async_iter_run_tools, iter_run_tools
-from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
+from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse, UsageModel, ToolCallModel
 from .image_models import ImageModels
 from .types import IterResponse, ImageProvider, Client as BaseClient
 from .service import get_model_and_provider, convert_to_provider

@@ -103,14 +103,14 @@ def iter_response(

         idx += 1
     if usage is None:
-        usage = Usage(prompt_tokens=0, completion_tokens=idx, total_tokens=idx)
+        usage = Usage(completion_tokens=idx, total_tokens=idx)

     finish_reason = "stop" if finish_reason is None else finish_reason

     if stream:
         yield ChatCompletionChunk.model_construct(
             None, finish_reason, completion_id, int(time.time()),
-            usage=usage.get_dict()
+            usage=usage
         )
     else:
         if response_format is not None and "type" in response_format:

@@ -118,7 +118,8 @@ def iter_response(
             content = filter_json(content)
         yield ChatCompletion.model_construct(
             content, finish_reason, completion_id, int(time.time()),
-            usage=usage.get_dict(), **filter_none(tool_calls=tool_calls)
+            usage=UsageModel.model_construct(**usage.get_dict()),
+            **filter_none(tool_calls=[ToolCallModel.model_construct(**tool_call) for tool_call in tool_calls]) if tool_calls is not None else {}
         )

 # Synchronous iter_append_model_and_provider function

@@ -186,7 +187,7 @@ async def async_iter_response(
     finish_reason = "stop" if finish_reason is None else finish_reason

     if usage is None:
-        usage = Usage(prompt_tokens=0, completion_tokens=idx, total_tokens=idx)
+        usage = Usage(completion_tokens=idx, total_tokens=idx)

     if stream:
         yield ChatCompletionChunk.model_construct(

@@ -199,7 +200,8 @@ async def async_iter_response(
             content = filter_json(content)
         yield ChatCompletion.model_construct(
             content, finish_reason, completion_id, int(time.time()),
-            usage=usage.get_dict(), **filter_none(tool_calls=tool_calls)
+            usage=UsageModel.model_construct(**usage.get_dict()),
+            **filter_none(tool_calls=[ToolCallModel.model_construct(**tool_call) for tool_call in tool_calls]) if tool_calls is not None else {}
         )
     finally:
         await safe_aclose(response)

@@ -363,7 +365,7 @@ class Images:
                         break
                     except Exception as e:
                         error = e
-                        debug.log(f"Image provider {provider.__name__}: {e}")
+                        debug.error(e, name=f"{provider.__name__} {type(e).__name__}")
             else:
                 response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)

@@ -458,7 +460,7 @@ class Images:
                         break
                     except Exception as e:
                         error = e
-                        debug.log(f"Image provider {provider.__name__}: {e}")
+                        debug.error(e, name=f"{provider.__name__} {type(e).__name__}")
             else:
                 response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)

@@ -583,7 +585,7 @@ class AsyncCompletions:
         messages: Messages,
         model: str,
         **kwargs
-    ) -> AsyncIterator[ChatCompletionChunk, BaseConversation]:
+    ) -> AsyncIterator[ChatCompletionChunk]:
         return self.create(messages, model, stream=True, **kwargs)

 class AsyncImages(Images):
@@ -5,9 +5,6 @@ from time import time

 from .helper import filter_none

-ToolCalls = Optional[List[Dict[str, Any]]]
-Usage = Optional[Dict[str, int]]
-
 try:
     from pydantic import BaseModel, Field
 except ImportError:

@@ -29,6 +26,40 @@ class BaseModel(BaseModel):
             return super().model_construct(**data)
         return cls.construct(**data)

+class UsageModel(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+    prompt_tokens_details: Optional[Dict[str, Any]]
+    completion_tokens_details: Optional[Dict[str, Any]]
+
+    @classmethod
+    def model_construct(cls, prompt_tokens=0, completion_tokens=0, total_tokens=0, prompt_tokens_details=None, completion_tokens_details=None, **kwargs):
+        return super().model_construct(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+            prompt_tokens_details=prompt_tokens_details,
+            completion_tokens_details=completion_tokens_details,
+            **kwargs
+        )
+
+class ToolFunctionModel(BaseModel):
+    name: str
+    arguments: str
+
+class ToolCallModel(BaseModel):
+    id: str
+    type: str
+    function: ToolFunctionModel
+
+    @classmethod
+    def model_construct(cls, function=None, **kwargs):
+        return super().model_construct(
+            **kwargs,
+            function=ToolFunctionModel.model_construct(**function),
+        )
+
 class ChatCompletionChunk(BaseModel):
     id: str
     object: str
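These typed stubs replace the loose Usage and ToolCalls dict aliases, and iter_response now builds them via model_construct. A minimal construction sketch; the import path assumes the stubs live in g4f.client.stubs, and the values are illustrative:

# Sketch: constructing the new typed usage and tool-call stubs.
from g4f.client.stubs import UsageModel, ToolCallModel

usage = UsageModel.model_construct(prompt_tokens=12, completion_tokens=34, total_tokens=46)
tool_call = ToolCallModel.model_construct(
    id="call_0",
    type="function",
    function={"name": "get_weather", "arguments": '{"city": "Berlin"}'},
)
print(usage.total_tokens)        # 46
print(tool_call.function.name)   # "get_weather"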
@@ -36,7 +67,7 @@ class ChatCompletionChunk(BaseModel):
     model: str
     provider: Optional[str]
     choices: List[ChatCompletionDeltaChoice]
-    usage: Usage
+    usage: UsageModel

     @classmethod
     def model_construct(

@@ -45,7 +76,7 @@ class ChatCompletionChunk(BaseModel):
         finish_reason: str,
         completion_id: str = None,
         created: int = None,
-        usage: Usage = None
+        usage: UsageModel = None
     ):
         return super().model_construct(
             id=f"chatcmpl-{completion_id}" if completion_id else None,

@@ -63,10 +94,10 @@
 class ChatCompletionMessage(BaseModel):
     role: str
     content: str
-    tool_calls: ToolCalls
+    tool_calls: list[ToolCallModel] = None

     @classmethod
-    def model_construct(cls, content: str, tool_calls: ToolCalls = None):
+    def model_construct(cls, content: str, tool_calls: list = None):
         return super().model_construct(role="assistant", content=content, **filter_none(tool_calls=tool_calls))

 class ChatCompletionChoice(BaseModel):

@@ -85,11 +116,7 @@ class ChatCompletion(BaseModel):
     model: str
     provider: Optional[str]
     choices: List[ChatCompletionChoice]
-    usage: Usage = Field(default={
-        "prompt_tokens": 0, #prompt_tokens,
-        "completion_tokens": 0, #completion_tokens,
-        "total_tokens": 0, #prompt_tokens + completion_tokens,
-    })
+    usage: UsageModel

     @classmethod
     def model_construct(

@@ -98,8 +125,8 @@ class ChatCompletion(BaseModel):
         finish_reason: str,
         completion_id: str = None,
         created: int = None,
-        tool_calls: ToolCalls = None,
-        usage: Usage = None
+        tool_calls: list[ToolCallModel] = None,
+        usage: UsageModel = None
     ):
         return super().model_construct(
             id=f"chatcmpl-{completion_id}" if completion_id else None,
@@ -1,3 +1,4 @@
+import sys
 from .providers.types import ProviderType

 logging: bool = False

@@ -8,6 +9,9 @@ version: str = None
 log_handler: callable = print
 logs: list = []

-def log(text):
+def log(text, file = None):
     if logging:
-        log_handler(text)
+        log_handler(text, file=file)
+
+def error(error, name: str = None):
+    log(error if isinstance(error, str) else f"{type(error).__name__ if name is None else name}: {error}", file=sys.stderr)
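debug.log now accepts a file argument and a new debug.error helper routes errors to stderr with the exception type (or a custom name) as prefix. A small usage sketch:

# Sketch: the new error helper logs to stderr, keeping the existing
# print-compatible log_handler contract, now with a file argument.
import g4f.debug

g4f.debug.logging = True
g4f.debug.log("regular progress message")             # stdout via print
try:
    raise ValueError("boom")
except ValueError as e:
    g4f.debug.error(e)                                 # "ValueError: boom" on stderr
    g4f.debug.error(e, name="MyProvider ValueError")   # custom prefix instead of the type name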
@@ -181,7 +181,12 @@
     <script>
         (async () => {
             const isIframe = window.self !== window.top;
+            const backendUrl = "{{backend_url}}";
             let url = new URL(window.location.href)
+            if (isIframe && backendUrl) {
+                window.location.replace(url.search ? `${backendUrl}?${url.search}` : backendUrl);
+                return;
+            }
             let params = new URLSearchParams(url.search);
             if (params.get("__sign")) {
                 localStorage.setItem("zerogpu_token", params.get("__sign"));

@@ -232,12 +237,11 @@
         import * as hub from "@huggingface/hub";
         import { init } from "@huggingface/space-header";

-        const isIframe = window.self !== window.top;
         const button = document.querySelector('form a.button');
         if (isIframe) {
             button.classList.remove('hidden');
         } else {
-            init("roxky/g4f-space");
+            init("roxky/g4f-space-new");
         }

         const form = document.querySelector("form");

@@ -282,13 +286,11 @@
         const cache_id = Math.floor(Math.random() * max);
         let prompt;
         if (cache_id % 2 == 0) {
-            prompt = `
-Today is ${new Date().toJSON().slice(0, 10)}.
+            prompt = `Today is ${new Date().toJSON().slice(0, 10)}.
 Create a single-page HTML screensaver reflecting the current season (based on the date).
 Avoid using any text.`;
         } else {
             prompt = `Create a single-page HTML screensaver. Avoid using any text.`;
-            const response = await fetch(`/backend-api/v2/create?prompt=${prompt}&filter_markdown=html&cache=${cache_id}`);
         }
         const response = await fetch(`/backend-api/v2/create?prompt=${prompt}&filter_markdown=html&cache=${cache_id}`);
         const text = await response.text()
@@ -293,7 +293,7 @@
                 </div>
                 <div class="field">
                     <select name="model" id="model">
-                        <option value="">Model: Default</option>
+                        <option value="" selected="selected">Model: Default</option>
                         <option value="gpt-4">gpt-4</option>
                         <option value="gpt-4o">gpt-4o</option>
                         <option value="gpt-4o-mini">gpt-4o-mini</option>
@@ -72,13 +72,16 @@ class Api:

     @staticmethod
     def get_version() -> dict:
+        current_version = None
+        latest_version = None
         try:
             current_version = version.utils.current_version
+            latest_version = version.utils.latest_version
         except VersionNotFoundError:
-            current_version = None
+            pass
         return {
             "version": current_version,
-            "latest_version": version.utils.latest_version,
+            "latest_version": latest_version,
         }

     def serve_images(self, name):

@@ -137,10 +140,10 @@ class Api:
         }

     def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator:
-        def decorated_log(text: str):
+        def decorated_log(text: str, file = None):
             debug.logs.append(text)
             if debug.logging:
-                debug.log_handler(text)
+                debug.log_handler(text, file)
         debug.log = decorated_log
         proxy = os.environ.get("G4F_PROXY")
         provider = kwargs.get("provider")

@@ -154,6 +157,7 @@ class Api:
             )
         except Exception as e:
             logger.exception(e)
+            debug.error(e)
             yield self._format_json('error', type(e).__name__, message=get_error_message(e))
             return
         if not isinstance(provider_handler, BaseRetryProvider):

@@ -183,6 +187,7 @@ class Api:
                     yield self._format_json("conversation_id", conversation_id)
                 elif isinstance(chunk, Exception):
                     logger.exception(chunk)
+                    debug.error(e)
                     yield self._format_json('message', get_error_message(chunk), error=type(chunk).__name__)
                 elif isinstance(chunk, PreviewResponse):
                     yield self._format_json("preview", chunk.to_string())

@@ -215,20 +220,19 @@ class Api:
                     yield self._format_json(chunk.type, **chunk.get_dict())
                 else:
                     yield self._format_json("content", str(chunk))
-                if debug.logs:
-                    for log in debug.logs:
-                        yield self._format_json("log", str(log))
-                    debug.logs = []
+                yield from self._yield_logs()
         except Exception as e:
             logger.exception(e)
-            if debug.logging:
-                debug.log_handler(get_error_message(e))
-            if debug.logs:
-                for log in debug.logs:
-                    yield self._format_json("log", str(log))
-                debug.logs = []
+            debug.error(e)
+            yield from self._yield_logs()
             yield self._format_json('error', type(e).__name__, message=get_error_message(e))

+    def _yield_logs(self):
+        if debug.logs:
+            for log in debug.logs:
+                yield self._format_json("log", log)
+            debug.logs = []
+
     def _format_json(self, response_type: str, content = None, **kwargs):
         if content is not None and isinstance(response_type, str):
             return {
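The repeated log-flushing loops are folded into a _yield_logs generator that is drained in both the success and the error branch of the stream. A condensed sketch of the pattern outside the class, with a plain formatter standing in for _format_json:

# Sketch of the refactor above: collect debug logs in one helper and drain
# them in both the success and the error branch of the response stream.
from g4f import debug

def _format_json(response_type, content):
    return {"type": response_type, response_type: content}

def _yield_logs():
    if debug.logs:
        for log in debug.logs:
            yield _format_json("log", log)
        debug.logs = []

def response_stream(chunks):
    try:
        for chunk in chunks:
            yield _format_json("content", str(chunk))
        yield from _yield_logs()
    except Exception as e:
        debug.error(e)
        yield from _yield_logs()
        yield _format_json("error", type(e).__name__)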
@@ -76,7 +76,7 @@ class Backend_Api(Api):
             @app.route('/', methods=['GET'])
             @limiter.exempt
             def home():
-                return render_template('demo.html')
+                return render_template('demo.html', backend_url=os.environ.get("G4F_BACKEND_URL", ""))
         else:
             @app.route('/', methods=['GET'])
             def home():
@@ -123,7 +123,7 @@ async def copy_images(
                     return f"/images/{url_filename}{'?url=' + quote(image) if add_url and not image.startswith('data:') else ''}"

                 except (ClientError, IOError, OSError) as e:
-                    debug.log(f"Image processing failed: {e.__class__.__name__}: {e}")
+                    debug.error(f"Image processing failed: {type(e).__name__}: {e}")
                     if target_path and os.path.exists(target_path):
                         os.unlink(target_path)
                     return get_source_url(image, image)
@@ -243,7 +243,7 @@ llama_3_3_70b = Model(
 mixtral_8x7b = Model(
     name = "mixtral-8x7b",
     base_provider = "Mistral",
-    best_provider = IterListProvider([DDG, Jmuz])
+    best_provider = Jmuz
 )
 mixtral_8x22b = Model(
     name = "mixtral-8x22b",

@@ -300,7 +300,7 @@ wizardlm_2_8x22b = Model(
 ### Google DeepMind ###
 # gemini
 gemini = Model(
-    name = 'gemini',
+    name = 'gemini-2.0',
     base_provider = 'Google',
     best_provider = Gemini
 )

@@ -316,13 +316,13 @@ gemini_exp = Model(
 gemini_1_5_flash = Model(
     name = 'gemini-1.5-flash',
     base_provider = 'Google DeepMind',
-    best_provider = IterListProvider([Blackbox, Jmuz, Gemini, GeminiPro, Liaobots])
+    best_provider = IterListProvider([Blackbox, Jmuz, GeminiPro, Liaobots])
 )

 gemini_1_5_pro = Model(
     name = 'gemini-1.5-pro',
     base_provider = 'Google DeepMind',
-    best_provider = IterListProvider([Blackbox, Jmuz, Gemini, GeminiPro, Liaobots])
+    best_provider = IterListProvider([Blackbox, Jmuz, GeminiPro, Liaobots])
 )

 # gemini-2.0

@@ -713,6 +713,7 @@ class ModelUtils:

         ### Google ###
         ### Gemini
+        "gemini": gemini,
         gemini.name: gemini,
         gemini_exp.name: gemini_exp,
         gemini_1_5_pro.name: gemini_1_5_pro,

@@ -812,7 +813,6 @@ class ModelUtils:

 demo_models = {
-    gpt_4o.name: [gpt_4o, [PollinationsAI, Blackbox]],
     "default": [llama_3_2_11b, [HuggingFace]],
     qwen_2_vl_7b.name: [qwen_2_vl_7b, [HuggingFaceAPI]],
     qvq_72b.name: [qvq_72b, [HuggingSpace]],
@ -409,6 +409,14 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
|
||||||
def get_cache_file(cls) -> Path:
|
def get_cache_file(cls) -> Path:
|
||||||
return Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
|
return Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def write_cache_file(cls, cache_file: Path, auth_result: AuthResult = None):
|
||||||
|
if auth_result is not None:
|
||||||
|
cache_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
cache_file.write_text(json.dumps(auth_result.get_dict()))
|
||||||
|
elif cache_file.exists():
|
||||||
|
cache_file.unlink()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_completion(
|
def create_completion(
|
||||||
cls,
|
cls,
|
||||||
|
|
@ -416,35 +424,25 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
|
||||||
messages: Messages,
|
messages: Messages,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> CreateResult:
|
) -> CreateResult:
|
||||||
auth_result = AuthResult()
|
auth_result: AuthResult = None
|
||||||
cache_file = cls.get_cache_file()
|
cache_file = cls.get_cache_file()
|
||||||
try:
|
try:
|
||||||
if cache_file.exists():
|
if cache_file.exists():
|
||||||
with cache_file.open("r") as f:
|
with cache_file.open("r") as f:
|
||||||
auth_result = AuthResult(**json.load(f))
|
auth_result = AuthResult(**json.load(f))
|
||||||
else:
|
else:
|
||||||
auth_result = cls.on_auth(**kwargs)
|
raise MissingAuthError
|
||||||
for chunk in auth_result:
|
|
||||||
if hasattr(chunk, "get_dict"):
|
|
||||||
auth_result = chunk
|
|
||||||
else:
|
|
||||||
yield chunk
|
|
||||||
yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
|
             yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
         except (MissingAuthError, NoValidHarFileError):
             auth_result = cls.on_auth(**kwargs)
             for chunk in auth_result:
-                if hasattr(chunk, "get_dict"):
+                if isinstance(chunk, AuthResult):
                     auth_result = chunk
                 else:
                     yield chunk
             yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
         finally:
-            if hasattr(auth_result, "get_dict"):
-                data = auth_result.get_dict()
-                cache_file.parent.mkdir(parents=True, exist_ok=True)
-                cache_file.write_text(json.dumps(data))
-            elif cache_file.exists():
-                cache_file.unlink()
+            cls.write_cache_file(cache_file, auth_result)

     @classmethod
     async def create_async_generator(
@@ -453,19 +451,14 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
         messages: Messages,
         **kwargs
     ) -> AsyncResult:
+        auth_result: AuthResult = None
+        cache_file = cls.get_cache_file()
         try:
-            auth_result = AuthResult()
-            cache_file = Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
             if cache_file.exists():
                 with cache_file.open("r") as f:
                     auth_result = AuthResult(**json.load(f))
             else:
-                auth_result = cls.on_auth_async(**kwargs)
-                async for chunk in auth_result:
-                    if hasattr(chunk, "get_dict"):
-                        auth_result = chunk
-                    else:
-                        yield chunk
+                raise MissingAuthError
             response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
             async for chunk in response:
                 yield chunk
@@ -474,16 +467,16 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
                 cache_file.unlink()
             auth_result = cls.on_auth_async(**kwargs)
             async for chunk in auth_result:
-                if hasattr(chunk, "get_dict"):
+                if isinstance(chunk, AuthResult):
                     auth_result = chunk
                 else:
                     yield chunk
             response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
             async for chunk in response:
+                if cache_file is not None:
+                    cls.write_cache_file(cache_file, auth_result)
+                    cache_file = None
                 yield chunk
         finally:
-            if hasattr(auth_result, "get_dict"):
-                cache_file.parent.mkdir(parents=True, exist_ok=True)
-                cache_file.write_text(json.dumps(auth_result.get_dict()))
-            elif cache_file.exists():
-                cache_file.unlink()
+            if cache_file is not None:
+                cls.write_cache_file(cache_file, auth_result)
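For readers following the cache refactor above: the removed inline mkdir/write_text/unlink code is folded into two helpers, get_cache_file and write_cache_file, and the async path now flushes credentials as soon as the first chunk arrives so the finally block only writes when nothing was flushed earlier. A minimal sketch of what the helpers are assumed to do, reconstructed from the removed lines (the file name pattern and cookies directory are illustrative, not taken from the repository):

# Sketch only: behaviour inferred from the removed inline code above; the real helpers live on the provider class.
import json
from pathlib import Path

def get_cache_file(provider_name: str, cookies_dir: str = "./har_and_cookies") -> Path:
    # One JSON file per provider, e.g. auth_MyProvider.json (directory name is an assumption).
    return Path(cookies_dir) / f"auth_{provider_name}.json"

def write_cache_file(cache_file: Path, auth_result=None) -> None:
    if auth_result is not None and hasattr(auth_result, "get_dict"):
        cache_file.parent.mkdir(parents=True, exist_ok=True)
        cache_file.write_text(json.dumps(auth_result.get_dict()))
    elif cache_file.exists():
        # No valid credentials: drop the stale cache so the next call re-authenticates.
        cache_file.unlink()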
@@ -19,7 +19,6 @@ def quote_url(url: str) -> str:

 def quote_title(title: str) -> str:
     if title:
-        title = title.strip()
         title = " ".join(title.split())
         return title.replace('[', '').replace(']', '')
     return ""
@@ -154,6 +153,7 @@ class Sources(ResponseType):
             self.add_source(source)

     def add_source(self, source: dict[str, str]):
+        source = source if isinstance(source, dict) else {"url": source}
         url = source.get("url", source.get("link", None))
         if url is not None:
             url = re.sub(r"[&?]utm_source=.+", "", url)
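The single added line in add_source makes the method accept either a source dict or a bare URL string. A self-contained sketch of the same normalization (the function name is illustrative):

import re

def normalize_source(source):
    # Mirrors the added line: wrap a bare URL string into a dict before reading fields.
    source = source if isinstance(source, dict) else {"url": source}
    url = source.get("url", source.get("link", None))
    if url is not None:
        # Same tracking-parameter cleanup as the context line above.
        url = re.sub(r"[&?]utm_source=.+", "", url)
    return url

print(normalize_source("https://example.com/page?utm_source=bot"))  # https://example.com/page
print(normalize_source({"link": "https://example.com/docs"}))       # https://example.com/docs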
@@ -65,7 +65,7 @@ class IterListProvider(BaseRetryProvider):
                     return
             except Exception as e:
                 exceptions[provider.__name__] = e
-                debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+                debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
                 if started:
                     raise e
                 yield e
@@ -105,7 +105,7 @@ class IterListProvider(BaseRetryProvider):
                     return
             except Exception as e:
                 exceptions[provider.__name__] = e
-                debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+                debug.error(name=f"{provider.__name__} {type(e).__name__}: {e}")
                 if started:
                     raise e
                 yield e
@@ -2,12 +2,12 @@ from __future__ import annotations

 from curl_cffi.requests import AsyncSession, Response
 try:
-    from curl_cffi.requests import CurlMime
+    from curl_cffi import CurlMime
     has_curl_mime = True
 except ImportError:
     has_curl_mime = False
 try:
-    from curl_cffi.requests import CurlWsFlag
+    from curl_cffi import CurlWsFlag
     has_curl_ws = True
 except ImportError:
     has_curl_ws = False
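Newer curl_cffi releases export CurlMime and CurlWsFlag from the package top level rather than from curl_cffi.requests, which is why the imports move. If both old and new layouts had to be supported, a both-ways fallback along these lines would also work (a sketch, not what the commit does):

# Illustrative fallback covering old and new curl_cffi layouts; the commit targets the new one only.
try:
    from curl_cffi import CurlMime, CurlWsFlag              # newer curl_cffi
except ImportError:
    try:
        from curl_cffi.requests import CurlMime, CurlWsFlag  # older curl_cffi
    except ImportError:
        CurlMime = CurlWsFlag = None                          # optional dependency not installed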
@@ -73,7 +73,7 @@ class StreamSession(AsyncSession):
     def request(
         self, method: str, url: str, ssl = None, **kwargs
     ) -> StreamResponse:
-        if isinstance(kwargs.get("data"), CurlMime):
+        if kwargs.get("data") and isinstance(kwargs.get("data"), CurlMime):
             kwargs["multipart"] = kwargs.pop("data")
         """Create and return a StreamResponse object for the given HTTP request."""
         return StreamResponse(super().request(method, url, stream=True, verify=ssl, **kwargs))
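The extra kwargs.get("data") check simply avoids the isinstance test when no body was passed. For context, a hedged usage sketch of the wrapper (the impersonate argument, iter_lines and the async context manager protocol are assumed from how this session is used elsewhere in the project, not stated in the hunk):

import asyncio
from g4f.requests import StreamSession, raise_for_status

async def main():
    # StreamSession passes kwargs such as impersonate straight through to curl_cffi's AsyncSession.
    async with StreamSession(impersonate="chrome") as session:
        async with session.get("https://example.com") as response:
            await raise_for_status(response)
            async for line in response.iter_lines():
                print(line.decode(errors="ignore"))

asyncio.run(main())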
@@ -100,12 +100,12 @@ if has_curl_mime:
 else:
     class FormData():
         def __init__(self) -> None:
-            raise RuntimeError("CurlMimi in curl_cffi is missing | pip install -U g4f[curl_cffi]")
+            raise RuntimeError("CurlMimi in curl_cffi is missing | pip install -U curl_cffi")

 class WebSocket():
     def __init__(self, session, url, **kwargs) -> None:
         if not has_curl_ws:
-            raise RuntimeError("CurlWsFlag in curl_cffi is missing | pip install -U g4f[curl_cffi]")
+            raise RuntimeError("CurlWsFlag in curl_cffi is missing | pip install -U curl_cffi")
         self.session: StreamSession = session
         self.url: str = url
         del kwargs["autoping"]
@@ -116,11 +116,13 @@ class WebSocket():
         return self

     async def __aexit__(self, *args):
-        await self.inner.aclose()
+        await self.inner.aclose() if hasattr(self.inner, "aclose") else await self.inner.close()

     async def receive_str(self, **kwargs) -> str:
-        bytes, _ = await self.inner.arecv()
+        method = self.inner.arecv if hasattr(self.inner, "arecv") else self.inner.recv
+        bytes, _ = await method()
         return bytes.decode(errors="ignore")

     async def send_str(self, data: str):
-        await self.inner.asend(data.encode(), CurlWsFlag.TEXT)
+        method = self.inner.asend if hasattr(self.inner, "asend") else self.inner.send
+        await method(data.encode(), CurlWsFlag.TEXT)
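The aclose/arecv/asend fallbacks all follow one pattern: call whichever method name this curl_cffi version exposes. The same idea as a tiny reusable helper (illustrative only, not part of the commit):

def pick_method(obj, *names):
    # Return the first attribute that exists, e.g. pick_method(ws, "arecv", "recv").
    for name in names:
        if hasattr(obj, name):
            return getattr(obj, name)
    raise AttributeError(f"none of {names} found on {type(obj).__name__}")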
71	g4f/tools/pydantic_ai.py	Normal file
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+from typing import Optional
+from functools import partial
+from dataclasses import dataclass, field
+
+from pydantic_ai.models import Model, KnownModelName, infer_model
+from pydantic_ai.models.openai import OpenAIModel, OpenAISystemPromptRole
+
+from ..client import AsyncClient
+
+@dataclass(init=False)
+class AIModel(OpenAIModel):
+    """A model that uses the G4F API."""
+
+    client: AsyncClient = field(repr=False)
+    system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
+
+    _model_name: str = field(repr=False)
+    _provider: str = field(repr=False)
+    _system: Optional[str] = field(repr=False)
+
+    def __init__(
+        self,
+        model_name: str,
+        provider: str | None = None,
+        *,
+        system_prompt_role: OpenAISystemPromptRole | None = None,
+        system: str | None = 'openai',
+        **kwargs
+    ):
+        """Initialize an AI model.
+
+        Args:
+            model_name: The name of the AI model to use. List of model names available
+                [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
+                (Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API).
+            system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
+                In the future, this may be inferred from the model name.
+            system: The model provider used, defaults to `openai`. This is for observability purposes, you must
+                customize the `base_url` and `api_key` to use a different provider.
+        """
+        self._model_name = model_name
+        self._provider = provider
+        self.client = AsyncClient(provider=provider, **kwargs)
+        self.system_prompt_role = system_prompt_role
+        self._system = system
+
+    def name(self) -> str:
+        if self._provider:
+            return f'g4f:{self._provider}:{self._model_name}'
+        return f'g4f:{self._model_name}'
+
+def new_infer_model(model: Model | KnownModelName, api_key: str = None) -> Model:
+    if isinstance(model, Model):
+        return model
+    if model.startswith("g4f:"):
+        model = model[4:]
+        if ":" in model:
+            provider, model = model.split(":", 1)
+            return AIModel(model, provider=provider, api_key=api_key)
+        return AIModel(model)
+    return infer_model(model)
+
+def apply_patch(api_key: str | None = None):
+    import pydantic_ai.models
+    import pydantic_ai.models.openai
+
+    pydantic_ai.models.infer_model = partial(new_infer_model, api_key=api_key)
+    pydantic_ai.models.AIModel = AIModel
+    pydantic_ai.models.openai.NOT_GIVEN = None
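The new module patches pydantic_ai so that model names prefixed with g4f: are routed through AIModel and the g4f AsyncClient. A minimal usage sketch (pydantic_ai's Agent/run_sync API, the result attribute and the example model name are assumptions, not taken from this commit):

from pydantic_ai import Agent
from g4f.tools.pydantic_ai import apply_patch

apply_patch()  # replace pydantic_ai.models.infer_model with new_infer_model

# "g4f:<Provider>:<model>" selects a provider explicitly; "g4f:<model>" lets g4f pick one.
agent = Agent("g4f:OpenaiChat:gpt-4o")
result = agent.run_sync("Say hello in one word.")
print(result.data)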
@@ -44,7 +44,7 @@ async def async_iter_run_tools(provider: ProviderType, model: str, messages, too
             web_search = web_search if isinstance(web_search, str) and web_search != "true" else None
             messages[-1]["content"] = await do_search(messages[-1]["content"], web_search)
         except Exception as e:
-            debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
+            debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
             # Keep web_search in kwargs for provider native support
             pass

@@ -82,7 +82,8 @@ async def async_iter_run_tools(provider: ProviderType, model: str, messages, too
                    has_bucket = True
            message["content"] = new_message_content
    if has_bucket and isinstance(messages[-1]["content"], str):
-        messages[-1]["content"] += BUCKET_INSTRUCTIONS
+        if "\nSource: " in messages[-1]["content"]:
+            messages[-1]["content"] += BUCKET_INSTRUCTIONS
    create_function = provider.get_async_create_function()
    response = to_async_iterator(create_function(model=model, messages=messages, **kwargs))
    async for chunk in response:
@@ -149,7 +150,7 @@ def iter_run_tools(
             web_search = web_search if isinstance(web_search, str) and web_search != "true" else None
             messages[-1]["content"] = asyncio.run(do_search(messages[-1]["content"], web_search))
         except Exception as e:
-            debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
+            debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
             # Keep web_search in kwargs for provider native support
             pass

@@ -192,7 +193,8 @@ def iter_run_tools(
                    has_bucket = True
            message["content"] = new_message_content
    if has_bucket and isinstance(messages[-1]["content"], str):
-        messages[-1]["content"] += BUCKET_INSTRUCTIONS
+        if "\nSource: " in messages[-1]["content"]:
+            messages[-1]["content"] = messages[-1]["content"]["content"] + BUCKET_INSTRUCTIONS

    thinking_start_time = 0
    for chunk in iter_callback(model=model, messages=messages, provider=provider, **kwargs):
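Both bucket hunks change when BUCKET_INSTRUCTIONS is appended: only if the expanded message already carries a "\nSource: " block. A small sketch of the new gating (the instruction text below is a placeholder assumption; only the condition comes from the diff):

BUCKET_INSTRUCTIONS = "\nAdd the sources of cited words to the end of your response."  # placeholder text

def maybe_add_bucket_instructions(content, has_bucket: bool):
    # After this change the instructions are only appended when sources were actually inlined.
    if has_bucket and isinstance(content, str) and "\nSource: " in content:
        content += BUCKET_INSTRUCTIONS
    return content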
@@ -237,7 +237,7 @@ def get_search_message(prompt: str, raise_search_exceptions=False, **kwargs) ->
    except (DuckDuckGoSearchException, MissingRequirementsError) as e:
        if raise_search_exceptions:
            raise e
-        debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
+        debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
        return prompt

def spacy_get_keywords(text: str):
@@ -44,8 +44,9 @@ def get_github_version(repo: str) -> str:
        VersionNotFoundError: If there is an error in fetching the version from GitHub.
    """
    try:
-        response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest").json()
-        return response["tag_name"]
+        response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest")
+        response.raise_for_status()
+        return response.json()["tag_name"]
    except requests.RequestException as e:
        raise VersionNotFoundError(f"Failed to get GitHub release version: {e}")
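With raise_for_status in place, a non-200 answer from the GitHub API surfaces as a RequestException and therefore as VersionNotFoundError, instead of a KeyError on the error JSON. An illustrative call site (import paths assumed from this repository's layout):

from g4f.errors import VersionNotFoundError
from g4f.version import get_github_version

try:
    print(get_github_version("xtekky/gpt4free"))
except VersionNotFoundError as e:
    print(f"Could not determine the latest release: {e}")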