Merge pull request #2590 from hlohaus/16Jan

Support TitleGeneration, Reasoning in HuggingChat
H Lohaus 2025-01-24 03:23:25 +01:00 committed by GitHub
commit a9fde5bf88
19 changed files with 411 additions and 280 deletions

g4f/Provider/Jmuz.py

@@ -5,7 +5,7 @@ from .needs_auth.OpenaiAPI import OpenaiAPI
class Jmuz(OpenaiAPI):
label = "Jmuz"
url = "https://discord.gg/qXfu24JmsB"
url = "https://discord.gg/Ew6JzjA2NR"
login_url = None
api_base = "https://jmuz.me/gpt/api/v2"
api_key = "prod"
@@ -18,12 +18,14 @@ class Jmuz(OpenaiAPI):
default_model = "gpt-4o"
model_aliases = {
"gemini": "gemini-exp",
"deepseek-chat": "deepseek-2.5",
"qwq-32b": "qwq-32b-preview"
"gemini-1.5-pro": "gemini-pro",
"gemini-1.5-flash": "gemini-thinking",
"deepseek-chat": "deepseek-v3",
"qwq-32b": "qwq-32b-preview",
}
@classmethod
def get_models(cls):
def get_models(cls, **kwargs):
if not cls.models:
cls.models = super().get_models(api_key=cls.api_key, api_base=cls.api_base)
return cls.models
@@ -47,6 +49,7 @@ class Jmuz(OpenaiAPI):
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36"
}
started = False
buffer = ""
async for chunk in super().create_async_generator(
model=model,
messages=messages,
@@ -56,10 +59,25 @@ class Jmuz(OpenaiAPI):
headers=headers,
**kwargs
):
if isinstance(chunk, str) and cls.url in chunk:
continue
if isinstance(chunk, str) and not started:
chunk = chunk.lstrip()
if chunk:
started = True
if isinstance(chunk, str):
buffer += chunk
if "Join for free".startswith(buffer) or buffer.startswith("Join for free"):
if buffer.endswith("\n"):
buffer = ""
continue
if "https://discord.gg/".startswith(buffer) or "https://discord.gg/" in buffer:
if "..." in buffer:
buffer = ""
continue
if "o1-preview".startswith(buffer) or buffer.startswith("o1-preview"):
if "\n" in buffer:
buffer = ""
continue
if not started:
buffer = buffer.lstrip()
if buffer:
started = True
yield buffer
buffer = ""
else:
yield chunk
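
Note on the hunk above: the old handler only skipped chunks that contained the provider URL; the rewritten one accumulates chunks in a buffer so that injected promo lines ("Join for free", Discord invites, the "o1-preview" banner) are dropped even when they arrive split across several chunks. A minimal, self-contained sketch of that prefix-buffering idea (the names here are illustrative, not part of the g4f API):

AD_PREFIXES = ("Join for free", "https://discord.gg/", "o1-preview")

def filter_stream(chunks):
    buffer = ""
    for chunk in chunks:
        buffer += chunk
        # While the buffer could still become (or already is) a known ad
        # line, hold it back; discard it once the line is complete.
        if any(p.startswith(buffer) or buffer.startswith(p) for p in AD_PREFIXES):
            if buffer.endswith("\n"):
                buffer = ""
            continue
        yield buffer
        buffer = ""

print("".join(filter_stream(["Jo", "in for free\n", "Hello ", "world"])))  # Hello world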

g4f/Provider/PollinationsAI.py

@@ -3,42 +3,45 @@ from __future__ import annotations
import json
import random
import requests
from urllib.parse import quote
from urllib.parse import quote_plus
from typing import Optional
from aiohttp import ClientSession
from .helper import filter_none
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..typing import AsyncResult, Messages, ImagesType
from ..image import to_data_uri
from ..requests.raise_for_status import raise_for_status
from ..typing import AsyncResult, Messages
from ..image import ImageResponse
from ..requests.aiohttp import get_connector
from ..providers.response import ImageResponse, FinishReason, Usage
DEFAULT_HEADERS = {
'Accept': '*/*',
'Accept-Language': 'en-US,en;q=0.9',
'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
}
class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
label = "Pollinations AI"
url = "https://pollinations.ai"
working = True
supports_stream = False
supports_system_message = True
supports_message_history = True
# API endpoints base
api_base = "https://text.pollinations.ai/openai"
# API endpoints
text_api_endpoint = "https://text.pollinations.ai/"
text_api_endpoint = "https://text.pollinations.ai/openai"
image_api_endpoint = "https://image.pollinations.ai/"
# Models configuration
default_model = "openai"
default_image_model = "flux"
image_models = []
models = []
additional_models_image = ["midjourney", "dall-e-3"]
additional_models_text = ["claude", "karma", "command-r", "llamalight", "mistral-large", "sur", "sur-mistral"]
default_vision_model = "gpt-4o"
extra_image_models = ["midjourney", "dall-e-3", "flux-pro", "flux-realism", "flux-cablyai", "flux-anime", "flux-3d"]
vision_models = [default_vision_model, "gpt-4o-mini"]
extra_text_models = [*vision_models, "claude", "karma", "command-r", "llamalight", "mistral-large", "sur", "sur-mistral", "any-dark"]
model_aliases = {
"gpt-4o": default_model,
"qwen-2-72b": "qwen",
"qwen-2.5-coder-32b": "qwen-coder",
"llama-3.3-70b": "llama",
@@ -50,22 +53,17 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
"deepseek-chat": "deepseek",
"llama-3.2-3b": "llamalight",
}
text_models = []
@classmethod
def get_models(cls, **kwargs):
# Initialize model lists if not exists
if not hasattr(cls, 'image_models'):
cls.image_models = []
if not hasattr(cls, 'text_models'):
cls.text_models = []
# Fetch image models if not cached
if not cls.image_models:
url = "https://image.pollinations.ai/models"
response = requests.get(url)
raise_for_status(response)
cls.image_models = response.json()
cls.image_models.extend(cls.additional_models_image)
cls.image_models.extend(cls.extra_image_models)
# Fetch text models if not cached
if not cls.text_models:
@@ -73,7 +71,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
response = requests.get(url)
raise_for_status(response)
cls.text_models = [model.get("name") for model in response.json()]
cls.text_models.extend(cls.additional_models_text)
cls.text_models.extend(cls.extra_text_models)
# Return combined models
return cls.text_models + cls.image_models
@@ -94,22 +92,27 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
enhance: bool = False,
safe: bool = False,
# Text specific parameters
temperature: float = 0.5,
presence_penalty: float = 0,
images: ImagesType = None,
temperature: float = None,
presence_penalty: float = None,
top_p: float = 1,
frequency_penalty: float = 0,
stream: bool = False,
frequency_penalty: float = None,
response_format: Optional[dict] = None,
cache: bool = False,
**kwargs
) -> AsyncResult:
if images is not None and not model:
model = cls.default_vision_model
model = cls.get_model(model)
if not cache and seed is None:
seed = random.randint(0, 100000)
# Check if models
# Image generation
if model in cls.image_models:
async for result in cls._generate_image(
yield await cls._generate_image(
model=model,
messages=messages,
prompt=prompt,
prompt=messages[-1]["content"] if prompt is None else prompt,
proxy=proxy,
width=width,
height=height,
@@ -118,19 +121,21 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
private=private,
enhance=enhance,
safe=safe
):
yield result
)
else:
# Text generation
async for result in cls._generate_text(
model=model,
messages=messages,
images=images,
proxy=proxy,
temperature=temperature,
presence_penalty=presence_penalty,
top_p=top_p,
frequency_penalty=frequency_penalty,
stream=stream
response_format=response_format,
seed=seed,
cache=cache,
):
yield result
@@ -138,7 +143,6 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
async def _generate_image(
cls,
model: str,
messages: Messages,
prompt: str,
proxy: str,
width: int,
@@ -148,16 +152,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
private: bool,
enhance: bool,
safe: bool
) -> AsyncResult:
if seed is None:
seed = random.randint(0, 10000)
headers = {
'Accept': '*/*',
'Accept-Language': 'en-US,en;q=0.9',
'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
}
) -> ImageResponse:
params = {
"seed": seed,
"width": width,
@@ -168,42 +163,47 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
"enhance": enhance,
"safe": safe
}
params = {k: v for k, v in params.items() if v is not None}
async with ClientSession(headers=headers) as session:
prompt = messages[-1]["content"] if prompt is None else prompt
param_string = "&".join(f"{k}={v}" for k, v in params.items())
url = f"{cls.image_api_endpoint}/prompt/{quote(prompt)}?{param_string}"
async with session.head(url, proxy=proxy) as response:
if response.status == 200:
image_response = ImageResponse(images=url, alt=prompt)
yield image_response
params = {k: json.dumps(v) if isinstance(v, bool) else v for k, v in params.items() if v is not None}
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
async with session.head(f"{cls.image_api_endpoint}prompt/{quote_plus(prompt)}", params=params) as response:
await raise_for_status(response)
return ImageResponse(str(response.url), prompt)
@classmethod
async def _generate_text(
cls,
model: str,
messages: Messages,
images: Optional[ImagesType],
proxy: str,
temperature: float,
presence_penalty: float,
top_p: float,
frequency_penalty: float,
stream: bool,
seed: Optional[int] = None
) -> AsyncResult:
headers = {
"accept": "*/*",
"accept-language": "en-US,en;q=0.9",
"content-type": "application/json",
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
}
if seed is None:
seed = random.randint(0, 10000)
async with ClientSession(headers=headers) as session:
response_format: Optional[dict],
seed: Optional[int],
cache: bool
) -> AsyncResult:
jsonMode = False
if response_format is not None and "type" in response_format:
if response_format["type"] == "json_object":
jsonMode = True
if images is not None and messages:
last_message = messages[-1].copy()
last_message["content"] = [
*[{
"type": "image_url",
"image_url": {"url": to_data_uri(image)}
} for image, _ in images],
{
"type": "text",
"text": messages[-1]["content"]
}
]
messages[-1] = last_message
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
data = {
"messages": messages,
"model": model,
@@ -211,42 +211,33 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
"presence_penalty": presence_penalty,
"top_p": top_p,
"frequency_penalty": frequency_penalty,
"jsonMode": False,
"stream": stream,
"jsonMode": jsonMode,
"stream": False, # To get more informations like Usage and FinishReason
"seed": seed,
"cache": False
"cache": cache
}
async with session.post(cls.text_api_endpoint, json=data, proxy=proxy) as response:
response.raise_for_status()
async for chunk in response.content:
if chunk:
decoded_chunk = chunk.decode()
# Skip [DONE].
if "data: [DONE]" in decoded_chunk:
continue
# Processing plain text
if not decoded_chunk.startswith("data:"):
clean_text = decoded_chunk.strip()
if clean_text:
yield clean_text
continue
# Processing JSON format
try:
# Remove the prefix "data: " and parse JSON
json_str = decoded_chunk.replace("data:", "").strip()
json_response = json.loads(json_str)
if "choices" in json_response and json_response["choices"]:
if "delta" in json_response["choices"][0]:
content = json_response["choices"][0]["delta"].get("content")
if content:
# Remove escaped slashes before parentheses
clean_content = content.replace("\\(", "(").replace("\\)", ")")
yield clean_content
except json.JSONDecodeError:
# If JSON could not be parsed, skip
continue
async with session.post(cls.text_api_endpoint, json=filter_none(**data)) as response:
await raise_for_status(response)
async for line in response.content:
decoded_chunk = line.decode(errors="replace")
# Stop at [DONE].
if "data: [DONE]" in decoded_chunk:
break
# Processing JSON format
try:
# Remove the prefix "data: " and parse JSON
json_str = decoded_chunk.replace("data:", "").strip()
data = json.loads(json_str)
choice = data["choices"][0]
if "usage" in data:
yield Usage(**data["usage"])
if "message" in choice and "content" in choice["message"] and choice["message"]["content"]:
yield choice["message"]["content"].replace("\\(", "(").replace("\\)", ")")
elif "delta" in choice and "content" in choice["delta"] and choice["delta"]["content"]:
yield choice["delta"]["content"].replace("\\(", "(").replace("\\)", ")")
if "finish_reason" in choice and choice["finish_reason"] is not None:
yield FinishReason(choice["finish_reason"])
break
except json.JSONDecodeError:
yield decoded_chunk.strip()
continue
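
For reference, the image request above is now assembled with quote_plus for the prompt and JSON-encoded booleans in the query string, and the text request goes non-streaming through the OpenAI-compatible endpoint so Usage and FinishReason can be yielded. A standalone sketch of the URL construction (the real code lets aiohttp build the query from params; the values here are made up):

import json
from urllib.parse import quote_plus

params = {"seed": 42, "width": 1024, "height": 1024, "nologo": True, "private": None}
# Booleans become "true"/"false" and None entries are dropped, as in the hunk above.
params = {k: json.dumps(v) if isinstance(v, bool) else v for k, v in params.items() if v is not None}
query = "&".join(f"{k}={v}" for k, v in params.items())
print(f"https://image.pollinations.ai/prompt/{quote_plus('a red fox, watercolor')}?{query}")
# https://image.pollinations.ai/prompt/a+red+fox%2C+watercolor?seed=42&width=1024&height=1024&nologo=true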

g4f/Provider/hf_space/Qwen_QVQ_72B.py

@@ -18,6 +18,7 @@ class Qwen_QVQ_72B(AsyncGeneratorProvider, ProviderModelMixin):
default_model = "qwen-qvq-72b-preview"
models = [default_model]
vision_models = models
model_aliases = {"qwq-32b": default_model}
@classmethod

g4f/Provider/hf_space/__init__.py

@@ -33,12 +33,18 @@ class HuggingSpace(AsyncGeneratorProvider, ProviderModelMixin):
def get_models(cls, **kwargs) -> list[str]:
if not cls.models:
models = []
image_models = []
vision_models = []
for provider in cls.providers:
models.extend(provider.get_models(**kwargs))
models.extend(provider.model_aliases.keys())
image_models.extend(provider.image_models)
vision_models.extend(provider.vision_models)
models = list(set(models))
models.sort()
cls.models = models
cls.image_models = list(set(image_models))
cls.vision_models = list(set(vision_models))
return cls.models
@classmethod
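
The aggregation added above merges the model lists of all sub-providers and collects image and vision capabilities separately, de-duplicating with sets. Reduced sketch with stand-in provider classes:

class SpaceA:
    models = ["flux-dev"]
    image_models = ["flux-dev"]
    vision_models = []

class SpaceB:
    models = ["qwen-qvq-72b-preview"]
    image_models = []
    vision_models = ["qwen-qvq-72b-preview"]

providers = [SpaceA, SpaceB]
models = sorted({m for p in providers for m in p.models})
image_models = list({m for p in providers for m in p.image_models})
vision_models = list({m for p in providers for m in p.vision_models})
print(models)  # ['flux-dev', 'qwen-qvq-72b-preview']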

g4f/Provider/needs_auth/Ollama.py

@@ -9,15 +9,19 @@ from ...typing import AsyncResult, Messages
class Ollama(OpenaiAPI):
label = "Ollama"
url = "https://ollama.com"
login_url = None
needs_auth = False
working = True
@classmethod
def get_models(cls):
def get_models(cls, api_base: str = None, **kwargs):
if not cls.models:
host = os.getenv("OLLAMA_HOST", "127.0.0.1")
port = os.getenv("OLLAMA_PORT", "11434")
url = f"http://{host}:{port}/api/tags"
if api_base is None:
host = os.getenv("OLLAMA_HOST", "127.0.0.1")
port = os.getenv("OLLAMA_PORT", "11434")
url = f"http://{host}:{port}/api/tags"
else:
url = api_base.replace("/v1", "/api/tags")
models = requests.get(url).json()["models"]
cls.models = [model["name"] for model in models]
cls.default_model = cls.models[0]
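
With the change above, get_models can point at a remote Ollama instance: an OpenAI-style api_base is rewritten to the native tag-listing endpoint, with the environment variables kept as the fallback. Sketch of the mapping (function name illustrative):

import os

def ollama_tags_url(api_base: str = None) -> str:
    if api_base is None:
        host = os.getenv("OLLAMA_HOST", "127.0.0.1")
        port = os.getenv("OLLAMA_PORT", "11434")
        return f"http://{host}:{port}/api/tags"
    # e.g. http://localhost:11434/v1 -> http://localhost:11434/api/tags
    return api_base.replace("/v1", "/api/tags")

print(ollama_tags_url("http://localhost:11434/v1"))  # http://localhost:11434/api/tags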

g4f/Provider/needs_auth/HuggingChat.py

@@ -1,6 +1,11 @@
from __future__ import annotations
import json
import re
import os
import requests
import base64
from typing import AsyncIterator
try:
from curl_cffi.requests import Session, CurlMime
@@ -8,21 +13,22 @@ try:
except ImportError:
has_curl_cffi = False
from ..base_provider import ProviderModelMixin, AbstractProvider
from ..base_provider import ProviderModelMixin, AsyncAuthedProvider, AuthResult
from ..helper import format_prompt
from ...typing import CreateResult, Messages, Cookies
from ...errors import MissingRequirementsError
from ...typing import AsyncResult, Messages, Cookies, ImagesType
from ...errors import MissingRequirementsError, MissingAuthError, ResponseError
from ...image import to_bytes
from ...requests import get_args_from_nodriver, DEFAULT_HEADERS
from ...requests.raise_for_status import raise_for_status
from ...providers.response import JsonConversation, ImageResponse, Sources
from ...providers.response import JsonConversation, ImageResponse, Sources, TitleGeneration, Reasoning, RequestLogin
from ...cookies import get_cookies
from ... import debug
class Conversation(JsonConversation):
def __init__(self, conversation_id: str, message_id: str):
self.conversation_id: str = conversation_id
self.message_id: str = message_id
def __init__(self, models: dict):
self.models: dict = models
class HuggingChat(AbstractProvider, ProviderModelMixin):
class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
url = "https://huggingface.co/chat"
working = True
@@ -32,11 +38,11 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
default_model = "Qwen/Qwen2.5-72B-Instruct"
default_image_model = "black-forest-labs/FLUX.1-dev"
image_models = [
"black-forest-labs/FLUX.1-dev",
default_image_model,
"black-forest-labs/FLUX.1-schnell",
]
models = [
'Qwen/Qwen2.5-Coder-32B-Instruct',
fallback_models = [
default_model,
'meta-llama/Llama-3.3-70B-Instruct',
'CohereForAI/c4ai-command-r-plus-08-2024',
'Qwen/QwQ-32B-Preview',
@@ -64,57 +70,86 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
}
@classmethod
def create_completion(
def get_models(cls):
if not cls.models:
try:
text = requests.get(cls.url).text
text = re.sub(r',parameters:{[^}]+?}', '', text)
text = re.search(r'models:(\[.+?\]),oldModels:', text).group(1)
text = text.replace('void 0', 'null')
def add_quotation_mark(match):
return f'{match.group(1)}"{match.group(2)}":'
text = re.sub(r'([{,])([A-Za-z0-9_]+?):', add_quotation_mark, text)
models = json.loads(text)
cls.text_models = [model["id"] for model in models]
cls.models = cls.text_models + cls.image_models
cls.vision_models = [model["id"] for model in models if model["multimodal"]]
except Exception as e:
debug.log(f"HuggingChat: Error reading models: {type(e).__name__}: {e}")
cls.models = [*cls.fallback_models]
return cls.models
@classmethod
async def on_auth_async(cls, cookies: Cookies = None, proxy: str = None, **kwargs) -> AsyncIterator:
if cookies is None:
cookies = get_cookies("huggingface.co", single_browser=True)
if "hf-chat" in cookies:
yield AuthResult(
cookies=cookies,
impersonate="chrome",
headers=DEFAULT_HEADERS
)
return
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
yield RequestLogin(cls.__name__, login_url)
yield AuthResult(
**await get_args_from_nodriver(
cls.url,
proxy=proxy,
wait_for='form[action="/chat/logout"]'
)
)
@classmethod
async def create_authed(
cls,
model: str,
messages: Messages,
stream: bool,
auth_result: AuthResult,
prompt: str = None,
images: ImagesType = None,
return_conversation: bool = False,
conversation: Conversation = None,
web_search: bool = False,
cookies: Cookies = None,
**kwargs
) -> CreateResult:
) -> AsyncResult:
if not has_curl_cffi:
raise MissingRequirementsError('Install "curl_cffi" package | pip install -U curl_cffi')
model = cls.get_model(model)
if cookies is None:
cookies = get_cookies("huggingface.co")
session = Session(cookies=cookies)
session.headers = {
'accept': '*/*',
'accept-language': 'en',
'cache-control': 'no-cache',
'origin': 'https://huggingface.co',
'pragma': 'no-cache',
'priority': 'u=1, i',
'referer': 'https://huggingface.co/chat/',
'sec-ch-ua': '"Not)A;Brand";v="99", "Google Chrome";v="127", "Chromium";v="127"',
'sec-ch-ua-mobile': '?0',
'sec-ch-ua-platform': '"macOS"',
'sec-fetch-dest': 'empty',
'sec-fetch-mode': 'cors',
'sec-fetch-site': 'same-origin',
'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36',
}
session = Session(**auth_result.get_dict())
if conversation is None:
if conversation is None or not hasattr(conversation, "models"):
conversation = Conversation({})
if model not in conversation.models:
conversationId = cls.create_conversation(session, model)
messageId = cls.fetch_message_id(session, conversationId)
conversation = Conversation(conversationId, messageId)
conversation.models[model] = {"conversationId": conversationId, "messageId": messageId}
if return_conversation:
yield conversation
inputs = format_prompt(messages)
else:
conversation.message_id = cls.fetch_message_id(session, conversation.conversation_id)
conversationId = conversation.models[model]["conversationId"]
conversation.models[model]["messageId"] = cls.fetch_message_id(session, conversationId)
inputs = messages[-1]["content"]
debug.log(f"Use conversation: {conversation.conversation_id} Use message: {conversation.message_id}")
debug.log(f"Use: {json.dumps(conversation.models[model])}")
settings = {
"inputs": inputs,
"id": conversation.message_id,
"id": conversation.models[model]["messageId"],
"is_retry": False,
"is_continue": False,
"web_search": web_search,
@@ -123,34 +158,27 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
headers = {
'accept': '*/*',
'accept-language': 'en',
'cache-control': 'no-cache',
'origin': 'https://huggingface.co',
'pragma': 'no-cache',
'priority': 'u=1, i',
'referer': f'https://huggingface.co/chat/conversation/{conversation.conversation_id}',
'sec-ch-ua': '"Not)A;Brand";v="99", "Google Chrome";v="127", "Chromium";v="127"',
'sec-ch-ua-mobile': '?0',
'sec-ch-ua-platform': '"macOS"',
'sec-fetch-dest': 'empty',
'sec-fetch-mode': 'cors',
'sec-fetch-site': 'same-origin',
'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36',
'referer': f'https://huggingface.co/chat/conversation/{conversationId}',
}
data = CurlMime()
data.addpart('data', data=json.dumps(settings, separators=(',', ':')))
if images is not None:
for image, filename in images:
data.addpart(
"files",
filename=f"base64;{filename}",
data=base64.b64encode(to_bytes(image))
)
response = session.post(
f'https://huggingface.co/chat/conversation/{conversation.conversation_id}',
cookies=session.cookies,
f'https://huggingface.co/chat/conversation/{conversationId}',
headers=headers,
multipart=data,
stream=True
)
raise_for_status(response)
full_response = ""
sources = None
for line in response.iter_lines():
if not line:
@@ -163,21 +191,20 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
if "type" not in line:
raise RuntimeError(f"Response: {line}")
elif line["type"] == "stream":
token = line["token"].replace('\u0000', '')
full_response += token
if stream:
yield token
yield line["token"].replace('\u0000', '')
elif line["type"] == "finalAnswer":
break
elif line["type"] == "file":
url = f"https://huggingface.co/chat/conversation/{conversation.conversation_id}/output/{line['sha']}"
yield ImageResponse(url, alt=messages[-1]["content"], options={"cookies": cookies})
url = f"https://huggingface.co/chat/conversation/{conversationId}/output/{line['sha']}"
prompt = messages[-1]["content"] if prompt is None else prompt
yield ImageResponse(url, alt=prompt, options={"cookies": auth_result.cookies})
elif line["type"] == "webSearch" and "sources" in line:
sources = Sources(line["sources"])
elif line["type"] == "title":
yield TitleGeneration(line["title"])
elif line["type"] == "reasoning":
yield Reasoning(line.get("token"), line.get("status"))
full_response = full_response.replace('<|im_end|', '').strip()
if not stream:
yield full_response
if sources is not None:
yield sources
@@ -189,8 +216,9 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
'model': model,
}
response = session.post('https://huggingface.co/chat/conversation', json=json_data)
if response.status_code == 401:
raise MissingAuthError(response.text)
raise_for_status(response)
return response.json().get('conversationId')
@classmethod
@@ -215,6 +243,11 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
if not json_data:
raise RuntimeError("Failed to parse response data")
if json_data["nodes"][-1]["type"] == "error":
if json_data["nodes"][-1]["status"] == 403:
raise MissingAuthError(json_data["nodes"][-1]["error"]["message"])
raise ResponseError(json.dumps(json_data["nodes"][-1]))
data = json_data["nodes"][1]["data"]
keys = data[data[0]["messages"]]
message_keys = data[keys[-1]]
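
A note on get_models above: HuggingChat embeds its model list in the chat page as a JavaScript object literal, not JSON, so the code strips the parameters objects, quotes the bare keys, and maps "void 0" to null before json.loads can parse it. Reduced sketch with a made-up literal:

import json
import re

js = '[{id:"Qwen/Qwen2.5-72B-Instruct",multimodal:false,logoUrl:void 0}]'
text = js.replace("void 0", "null")
text = re.sub(r'([{,])([A-Za-z0-9_]+?):', lambda m: f'{m.group(1)}"{m.group(2)}":', text)
models = json.loads(text)
print([m["id"] for m in models])  # ['Qwen/Qwen2.5-72B-Instruct']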

g4f/Provider/needs_auth/HuggingFace.py

@@ -143,7 +143,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
else:
is_special = True
debug.log(f"Special token: {is_special}")
yield FinishReason("stop" if is_special else "length", actions=["variant"] if is_special else ["continue", "variant"])
yield FinishReason("stop" if is_special else "length")
else:
if response.headers["content-type"].startswith("image/"):
base64_data = base64.b64encode(b"".join([chunk async for chunk in response.iter_content()]))

g4f/Provider/needs_auth/HuggingFaceAPI.py

@@ -2,6 +2,7 @@ from __future__ import annotations
from .OpenaiAPI import OpenaiAPI
from .HuggingChat import HuggingChat
from ...providers.types import Messages
class HuggingFaceAPI(OpenaiAPI):
label = "HuggingFace (Inference API)"
@@ -11,6 +12,23 @@ class HuggingFaceAPI(OpenaiAPI):
working = True
default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
default_vision_model = default_model
models = [
*HuggingChat.models
]
@classmethod
def get_models(cls, **kwargs):
HuggingChat.get_models()
cls.models = HuggingChat.text_models
cls.vision_models = HuggingChat.vision_models
return cls.models
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
api_base: str = None,
**kwargs
):
if api_base is None:
api_base = f"https://api-inference.huggingface.co/models/{model}/v1"
async for chunk in super().create_async_generator(model, messages, api_base=api_base, **kwargs):
yield chunk

g4f/Provider/needs_auth/OpenaiAPI.py

@@ -73,10 +73,11 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
raise MissingAuthError('Add a "api_key"')
if api_base is None:
api_base = cls.api_base
if images is not None:
if images is not None and messages:
if not model and hasattr(cls, "default_vision_model"):
model = cls.default_vision_model
messages[-1]["content"] = [
last_message = messages[-1].copy()
last_message["content"] = [
*[{
"type": "image_url",
"image_url": {"url": to_data_uri(image)}
@@ -86,6 +87,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
"text": messages[-1]["content"]
}
]
messages[-1] = last_message
async with StreamSession(
proxy=proxy,
headers=cls.get_headers(stream, api_key, headers),
@@ -106,10 +108,10 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
if api_endpoint is None:
api_endpoint = f"{api_base.rstrip('/')}/chat/completions"
async with session.post(api_endpoint, json=data) as response:
await raise_for_status(response)
if not stream:
if not stream or response.headers.get("content-type") == "application/json":
data = await response.json()
cls.raise_error(data)
await raise_for_status(response)
choice = data["choices"][0]
if "content" in choice["message"] and choice["message"]["content"]:
yield choice["message"]["content"].strip()
@@ -117,10 +119,11 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
yield ToolCalls(choice["message"]["tool_calls"])
if "usage" in data:
yield Usage(**data["usage"])
finish = cls.read_finish_reason(choice)
if finish is not None:
yield finish
if "finish_reason" in choice and choice["finish_reason"] is not None:
yield FinishReason(choice["finish_reason"])
return
else:
await raise_for_status(response)
first = True
async for line in response.iter_lines():
if line.startswith(b"data: "):
@@ -137,16 +140,10 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
if delta:
first = False
yield delta
finish = cls.read_finish_reason(choice)
if finish is not None:
yield finish
if "finish_reason" in choice and choice["finish_reason"] is not None:
yield FinishReason(choice["finish_reason"])
break
@staticmethod
def read_finish_reason(choice: dict) -> Optional[FinishReason]:
if "finish_reason" in choice and choice["finish_reason"] is not None:
return FinishReason(choice["finish_reason"])
@classmethod
def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
return {
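
The image handling above now copies the last message before rewriting it: assigning the multimodal list straight into messages[-1]["content"] mutated the caller's history, so a retry would wrap the payload a second time. Reduced sketch of the fixed pattern (helper name illustrative):

def attach_images(messages: list, image_urls: list) -> None:
    last_message = messages[-1].copy()  # shallow copy keeps the caller's dict intact
    last_message["content"] = [
        *[{"type": "image_url", "image_url": {"url": url}} for url in image_urls],
        {"type": "text", "text": messages[-1]["content"]},
    ]
    messages[-1] = last_message  # the list slot is updated, the original dict is not

history = [{"role": "user", "content": "What is in this picture?"}]
original = history[-1]
attach_images(history, ["data:image/png;base64,..."])
print(original["content"])  # unchanged: What is in this picture?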

g4f/Provider/needs_auth/OpenaiChat.py

@@ -495,8 +495,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
"headers": cls._headers,
"web_search": web_search,
})
actions = ["variant", "continue"] if conversation.finish_reason == "max_tokens" else ["variant"]
yield FinishReason(conversation.finish_reason, actions=actions)
yield FinishReason(conversation.finish_reason)
@classmethod
async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: Conversation, sources: Sources) -> AsyncIterator:

g4f/Provider/needs_auth/ThebApi.py

@@ -1,61 +1,50 @@
from __future__ import annotations
from ...typing import CreateResult, Messages
from ..helper import filter_none
from .OpenaiAPI import OpenaiAPI
models = {
"theb-ai": "TheB.AI",
"gpt-3.5-turbo": "GPT-3.5",
"gpt-3.5-turbo-16k": "GPT-3.5-16K",
"gpt-4-turbo": "GPT-4 Turbo",
"gpt-4": "GPT-4",
"gpt-4-32k": "GPT-4 32K",
"claude-2": "Claude 2",
"claude-1": "Claude",
"claude-1-100k": "Claude 100K",
"claude-instant-1": "Claude Instant",
"claude-instant-1-100k": "Claude Instant 100K",
"palm-2": "PaLM 2",
"palm-2-codey": "Codey",
"vicuna-13b-v1.5": "Vicuna v1.5 13B",
"claude-3.5-sonnet": "Claude",
"llama-2-7b-chat": "Llama 2 7B",
"llama-2-13b-chat": "Llama 2 13B",
"llama-2-70b-chat": "Llama 2 70B",
"code-llama-7b": "Code Llama 7B",
"code-llama-13b": "Code Llama 13B",
"code-llama-34b": "Code Llama 34B",
"qwen-7b-chat": "Qwen 7B"
"qwen-2-72b": "Qwen"
}
class ThebApi(OpenaiAPI):
label = "TheB.AI API"
url = "https://theb.ai"
login_url = "https://beta.theb.ai/home"
working = True
api_base = "https://api.theb.ai/v1"
needs_auth = True
default_model = "gpt-3.5-turbo"
models = list(models)
default_model = "theb-ai"
fallback_models = list(models)
@classmethod
def create_async_generator(
cls,
model: str,
messages: Messages,
temperature: float = 1,
top_p: float = 1,
temperature: float = None,
top_p: float = None,
**kwargs
) -> CreateResult:
if "auth" in kwargs:
kwargs["api_key"] = kwargs["auth"]
system_message = "\n".join([message["content"] for message in messages if message["role"] == "system"])
if not system_message:
system_message = "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-3.5 architecture."
messages = [message for message in messages if message["role"] != "system"]
data = {
"model_params": {
"system_prompt": system_message,
"temperature": temperature,
"top_p": top_p,
}
"model_params": filter_none(
system_prompt=system_message,
temperature=temperature,
top_p=top_p,
)
}
return super().create_async_generator(model, messages, extra_data=data, **kwargs)
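
temperature and top_p above now default to None and pass through filter_none, so model_params only carries values the caller actually set. filter_none is roughly the one-liner below (re-implemented here for illustration):

def filter_none(**kwargs) -> dict:
    return {key: value for key, value in kwargs.items() if value is not None}

print(filter_none(system_prompt="You are helpful.", temperature=None, top_p=0.9))
# {'system_prompt': 'You are helpful.', 'top_p': 0.9}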

g4f/gui/client/static/css/style.css

@@ -376,6 +376,29 @@ body:not(.white) a:visited{
display: flex;
}
.message .reasoning_text.final:not(.hidden), .message .reasoning_title {
margin-bottom: var(--inner-gap);
padding-bottom: var(--inner-gap);
border-bottom: 1px solid var(--colour-3);
overflow: hidden;
}
.message .reasoning_text.final {
max-height: 1000px;
transition: max-height 0.25s ease-in;
}
.message .reasoning_text.final.hidden {
transition: max-height 0.15s ease-out;
max-height: 0;
display: block;
overflow: hidden;
}
.message .reasoning_title {
cursor: pointer;
}
.message .user i {
position: absolute;
bottom: -6px;

g4f/gui/client/static/js/chat.v1.js

@@ -35,6 +35,7 @@ let title_storage = {};
let parameters_storage = {};
let finish_storage = {};
let usage_storage = {};
let reasoning_storage = {};
messageInput.addEventListener("blur", () => {
window.scrollTo(0, 0);
@@ -70,6 +71,17 @@ if (window.markdownit) {
}
}
function render_reasoning(reasoning, final = false) {
return `<div class="reasoning_body">
<div class="reasoning_title">
<strong>Reasoning <i class="fa-solid fa-brain"></i>:</strong> ${escapeHtml(reasoning.status)}
</div>
<div class="reasoning_text${final ? " final hidden" : ""}">
${markdown_render(reasoning.text)}
</div>
</div>`;
}
function filter_message(text) {
return text.replaceAll(
/<!-- generated images start -->[\s\S]+<!-- generated images end -->/gm, ""
@@ -169,7 +181,7 @@ const get_message_el = (el) => {
}
const register_message_buttons = async () => {
document.querySelectorAll(".message .content .provider").forEach(async (el) => {
message_box.querySelectorAll(".message .content .provider").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
const provider_forms = document.querySelector(".provider_forms");
@@ -192,7 +204,7 @@
}
});
document.querySelectorAll(".message .fa-xmark").forEach(async (el) => {
message_box.querySelectorAll(".message .fa-xmark").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
@@ -203,7 +215,7 @@
}
});
document.querySelectorAll(".message .fa-clipboard").forEach(async (el) => {
message_box.querySelectorAll(".message .fa-clipboard").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
@@ -226,7 +238,7 @@
}
});
document.querySelectorAll(".message .fa-file-export").forEach(async (el) => {
message_box.querySelectorAll(".message .fa-file-export").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
@@ -244,7 +256,7 @@
}
});
document.querySelectorAll(".message .fa-volume-high").forEach(async (el) => {
message_box.querySelectorAll(".message .fa-volume-high").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
@@ -270,7 +282,7 @@
}
});
document.querySelectorAll(".message .regenerate_button").forEach(async (el) => {
message_box.querySelectorAll(".message .regenerate_button").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
@@ -282,7 +294,7 @@
}
});
document.querySelectorAll(".message .continue_button").forEach(async (el) => {
message_box.querySelectorAll(".message .continue_button").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
@@ -297,7 +309,7 @@
}
});
document.querySelectorAll(".message .fa-whatsapp").forEach(async (el) => {
message_box.querySelectorAll(".message .fa-whatsapp").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
@@ -307,7 +319,7 @@
}
});
document.querySelectorAll(".message .fa-print").forEach(async (el) => {
message_box.querySelectorAll(".message .fa-print").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
@@ -323,6 +335,16 @@
})
}
});
message_box.querySelectorAll(".message .reasoning_title").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
let text_el = el.parentElement.querySelector(".reasoning_text");
text_el.classList[text_el.classList.contains("hidden") ? "remove" : "add"]("hidden");
})
}
});
}
const delete_conversations = async () => {
@@ -469,7 +491,7 @@ const prepare_messages = (messages, message_index = -1, do_continue = false, do_
messages.forEach((message) => {
message_copy = { ...message };
if (last_message) {
if (last_message["role"] == message["role"]) {
if (last_message["role"] == message["role"] && message["role"] == "assistant") {
message_copy["content"] = last_message["content"] + message_copy["content"];
new_messages.pop();
}
@@ -515,6 +537,7 @@ const prepare_messages = (messages, message_index = -1, do_continue = false, do_
delete new_message.synthesize;
delete new_message.finish;
delete new_message.usage;
delete new_message.reasoning;
delete new_message.conversation;
delete new_message.continue;
// Append message to new messages
@@ -711,11 +734,21 @@ async function add_message_chunk(message, message_id, provider, scroll) {
} else if (message.type == "title") {
title_storage[message_id] = message.title;
} else if (message.type == "login") {
update_message(content_map, message_id, message.login, scroll);
update_message(content_map, message_id, markdown_render(message.login), scroll);
} else if (message.type == "finish") {
finish_storage[message_id] = message.finish;
} else if (message.type == "usage") {
usage_storage[message_id] = message.usage;
} else if (message.type == "reasoning") {
if (!reasoning_storage[message_id]) {
reasoning_storage[message_id] = message;
reasoning_storage[message_id].text = "";
} else if (message.status) {
reasoning_storage[message_id].status = message.status;
} else if (message.token) {
reasoning_storage[message_id].text += message.token;
}
update_message(content_map, message_id, render_reasoning(reasoning_storage[message_id]), scroll);
} else if (message.type == "parameters") {
if (!parameters_storage[provider]) {
parameters_storage[provider] = {};
@@ -846,6 +879,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
title_storage[message_id],
finish_storage[message_id],
usage_storage[message_id],
reasoning_storage[message_id],
action=="continue"
);
delete controller_storage[message_id];
@@ -1042,6 +1076,7 @@ function merge_messages(message1, message2) {
const load_conversation = async (conversation_id, scroll=true) => {
let conversation = await get_conversation(conversation_id);
let messages = conversation?.items || [];
console.debug("Conversation:", conversation)
if (!conversation) {
return;
@@ -1098,11 +1133,8 @@ const load_conversation = async (conversation_id, scroll=true) => {
let add_buttons = [];
// Find buttons to add
actions = ["variant"]
if (item.finish && item.finish.actions) {
actions = item.finish.actions
}
// Add continue button if possible
if (item.role == "assistant" && !actions.includes("continue")) {
if (item.role == "assistant") {
let reason = "stop";
// Read finish reason from conversation
if (item.finish && item.finish.reason) {
@@ -1167,7 +1199,10 @@ const load_conversation = async (conversation_id, scroll=true) => {
</div>
<div class="content">
${provider}
<div class="content_inner">${markdown_render(buffer)}</div>
<div class="content_inner">
${item.reasoning ? render_reasoning(item.reasoning, true): ""}
${markdown_render(buffer)}
</div>
<div class="count">
${count_words_and_tokens(buffer, next_provider?.model, completion_tokens, prompt_tokens)}
${add_buttons.join("")}
@@ -1298,6 +1333,7 @@ const add_message = async (
title = null,
finish = null,
usage = null,
reasoning = null,
do_continue = false
) => {
const conversation = await get_conversation(conversation_id);
@@ -1329,6 +1365,9 @@ const add_message = async (
if (usage) {
new_message.usage = usage;
}
if (reasoning) {
new_message.reasoning = reasoning;
}
if (do_continue) {
new_message.continue = true;
}
@@ -1604,23 +1643,24 @@ function count_words_and_tokens(text, model, completion_tokens, prompt_tokens) {
function update_message(content_map, message_id, content = null, scroll = true) {
content_map.update_timeouts.push(setTimeout(() => {
if (!content) content = message_storage[message_id];
html = markdown_render(content);
let lastElement, lastIndex = null;
for (element of ['</p>', '</code></pre>', '</p>\n</li>\n</ol>', '</li>\n</ol>', '</li>\n</ul>']) {
const index = html.lastIndexOf(element)
if (index - element.length > lastIndex) {
lastElement = element;
lastIndex = index;
if (!content) {
content = markdown_render(message_storage[message_id]);
let lastElement, lastIndex = null;
for (element of ['</p>', '</code></pre>', '</p>\n</li>\n</ol>', '</li>\n</ol>', '</li>\n</ul>']) {
const index = content.lastIndexOf(element)
if (index - element.length > lastIndex) {
lastElement = element;
lastIndex = index;
}
}
if (lastIndex) {
content = content.substring(0, lastIndex) + '<span class="cursor"></span>' + lastElement;
}
}
if (lastIndex) {
html = html.substring(0, lastIndex) + '<span class="cursor"></span>' + lastElement;
}
content_map.inner.innerHTML = content;
if (error_storage[message_id]) {
content_map.inner.innerHTML += markdown_render(`**An error occurred:** ${error_storage[message_id]}`);
}
content_map.inner.innerHTML = html;
content_map.count.innerText = count_words_and_tokens(message_storage[message_id], provider_storage[message_id]?.model);
highlight(content_map.inner);
if (scroll) {
@@ -2132,9 +2172,9 @@ async function read_response(response, message_id, provider, scroll) {
function get_api_key_by_provider(provider) {
let api_key = null;
if (provider) {
api_key = document.getElementById(`${provider}-api_key`)?.id || null;
api_key = document.querySelector(`.${provider}-api_key`)?.id || null;
if (api_key == null) {
api_key = document.querySelector(`.${provider}-api_key`)?.id || null;
api_key = document.getElementById(`${provider}-api_key`)?.id || null;
}
if (api_key) {
api_key = appStorage.getItem(api_key);

g4f/gui/server/api.py

@@ -13,7 +13,7 @@ from ...tools.run_tools import iter_run_tools
from ...Provider import ProviderUtils, __providers__
from ...providers.base_provider import ProviderModelMixin
from ...providers.retry_provider import IterListProvider
from ...providers.response import BaseConversation, JsonConversation, FinishReason, Usage
from ...providers.response import BaseConversation, JsonConversation, FinishReason, Usage, Reasoning
from ...providers.response import SynthesizeData, TitleGeneration, RequestLogin, Parameters
from ... import version, models
from ... import ChatCompletion, get_model_and_provider
@@ -207,6 +207,8 @@ class Api:
yield self._format_json("finish", chunk.get_dict())
elif isinstance(chunk, Usage):
yield self._format_json("usage", chunk.get_dict())
elif isinstance(chunk, Reasoning):
yield self._format_json("reasoning", token=chunk.token, status=chunk.status)
else:
yield self._format_json("content", str(chunk))
if debug.logs:
@@ -219,10 +221,15 @@ class Api:
if first:
yield self.handle_provider(provider_handler, model)
def _format_json(self, response_type: str, content):
def _format_json(self, response_type: str, content = None, **kwargs):
if content is not None:
return {
'type': response_type,
response_type: content,
}
return {
'type': response_type,
response_type: content
**kwargs
}
def handle_provider(self, provider_handler, model):
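
_format_json above keeps the old {type, <type>: content} event shape while allowing keyword fields to be flattened into the event, which the new "reasoning" events use for their token/status pair. Behavior sketch (re-implemented standalone):

def format_json(response_type: str, content=None, **kwargs) -> dict:
    if content is not None:
        return {"type": response_type, response_type: content}
    return {"type": response_type, **kwargs}

print(format_json("content", "Hello"))
# {'type': 'content', 'content': 'Hello'}
print(format_json("reasoning", token="First, ", status="Thinking..."))
# {'type': 'reasoning', 'token': 'First, ', 'status': 'Thinking...'}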

g4f/gui/server/backend_api.py

@@ -309,7 +309,7 @@ class Backend_Api(Api):
return "Provider not found", 404
return models
def _format_json(self, response_type: str, content) -> str:
def _format_json(self, response_type: str, content = None, **kwargs) -> str:
"""
Formats and returns a JSON response.
@@ -320,4 +320,4 @@ class Backend_Api(Api):
Returns:
str: A JSON formatted string.
"""
return json.dumps(super()._format_json(response_type, content)) + "\n"
return json.dumps(super()._format_json(response_type, content, **kwargs)) + "\n"

g4f/providers/base_provider.py

@@ -340,7 +340,8 @@ class ProviderModelMixin:
default_model: str = None
models: list[str] = []
model_aliases: dict[str, str] = {}
image_models: list = None
image_models: list = []
vision_models: list = []
last_model: str = None
@classmethod

g4f/providers/response.py

@@ -89,9 +89,8 @@ class JsonMixin:
self.__dict__ = {}
class FinishReason(ResponseType, JsonMixin):
def __init__(self, reason: str, actions: list[str] = None) -> None:
def __init__(self, reason: str) -> None:
self.reason = reason
self.actions = actions
def __str__(self) -> str:
return ""
@@ -121,6 +120,14 @@ class TitleGeneration(ResponseType):
def __str__(self) -> str:
return ""
class Reasoning(ResponseType):
def __init__(self, token: str = None, status: str = None) -> None:
self.token = token
self.status = status
def __str__(self) -> str:
return "" if self.token is None else self.token
class Sources(ResponseType):
def __init__(self, sources: list[dict[str, str]]) -> None:
self.list = []
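
Consumer sketch for the new Reasoning type defined above: reasoning tokens can be collected separately from answer text, since Reasoning stringifies to its token (or to an empty string for pure status updates). The chunk list is made-up demo data:

stream = [
    Reasoning(status="Thinking..."),
    Reasoning(token="step 1, "),
    Reasoning(token="step 2"),
    "The answer is 42.",
    FinishReason("stop"),
]
thoughts, answer = [], []
for chunk in stream:
    if isinstance(chunk, Reasoning):
        if chunk.token:
            thoughts.append(chunk.token)
    elif isinstance(chunk, FinishReason):
        break
    else:
        answer.append(str(chunk))
print("".join(thoughts))  # step 1, step 2
print("".join(answer))    # The answer is 42.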

g4f/requests/__init__.py

@@ -78,25 +78,22 @@ async def get_args_from_nodriver(
url: str,
proxy: str = None,
timeout: int = 120,
wait_for: str = None,
cookies: Cookies = None
) -> dict:
if not has_nodriver:
raise MissingRequirementsError('Install "nodriver" package | pip install -U nodriver')
browser = await get_nodriver(proxy=proxy)
if debug.logging:
print(f"Open nodriver with url: {url}")
browser = await nodriver.start(
browser_args=None if proxy is None else [f"--proxy-server={proxy}"],
)
domain = urlparse(url).netloc
if cookies is None:
cookies = {}
else:
await browser.cookies.set_all(get_cookie_params_from_dict(cookies, url=url, domain=domain))
page = await browser.get(url)
for c in await page.send(nodriver.cdp.network.get_cookies([url])):
cookies[c.name] = c.value
user_agent = await page.evaluate("window.navigator.userAgent")
await page.wait_for("body:not(.no-js)", timeout=timeout)
if wait_for is not None:
await page.wait_for(wait_for, timeout=timeout)
for c in await page.send(nodriver.cdp.network.get_cookies([url])):
cookies[c.name] = c.value
await page.close()
@@ -120,13 +117,13 @@ def merge_cookies(cookies: Iterator[Morsel], response: Response) -> Cookies:
async def get_nodriver(proxy: str = None, user_data_dir = "nodriver", browser_executable_path=None, **kwargs)-> Browser:
if not has_nodriver:
raise MissingRequirementsError('Install "nodriver" package | pip install -U nodriver')
raise MissingRequirementsError('Install "nodriver" and "platformdirs" package | pip install -U nodriver platformdirs')
user_data_dir = user_config_dir(f"g4f-{user_data_dir}") if has_platformdirs else None
if browser_executable_path is None:
try:
browser_executable_path = find_chrome_executable()
except FileNotFoundError:
# Default to Edge if Chrome is not found
# Default to Edge if Chrome is not available.
browser_executable_path = "C:\\Program Files (x86)\\Microsoft\\Edge\\Application\\msedge.exe"
if not os.path.exists(browser_executable_path):
browser_executable_path = None
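
The new wait_for parameter above lets a caller block until a login-dependent element exists before the cookies are read back; HuggingChat passes the logout-form selector for exactly this. Usage sketch (requires the optional nodriver dependency and an interactive browser session):

import asyncio
from g4f.requests import get_args_from_nodriver

async def main():
    args = await get_args_from_nodriver(
        "https://huggingface.co/chat",
        wait_for='form[action="/chat/logout"]',  # resolves once the user is logged in
    )
    print(sorted(args))  # request arguments (cookies, headers, ...) for later calls

# asyncio.run(main())  # uncomment to run; this opens a browser window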

g4f/requests/raise_for_status.py

@@ -25,7 +25,7 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
return
text = await response.text()
if message is None:
message = "HTML content" if response.headers.get("content-type").startswith("text/html") else text
message = "HTML content" if response.headers.get("content-type", "").startswith("text/html") else text
if message == "HTML content":
if response.status == 520:
message = "Unknown error (Cloudflare)"
@@ -46,7 +46,7 @@ def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, R
if response.ok:
return
if message is None:
message = "HTML content" if response.headers.get("content-type").startswith("text/html") else response.text
message = "HTML content" if response.headers.get("content-type", "").startswith("text/html") else response.text
if message == "HTML content":
if response.status_code == 520:
message = "Unknown error (Cloudflare)"