mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
Merge pull request #2590 from hlohaus/16Jan
Support TitleGeneration, Reasoning in HuggingChat
This commit is contained in:
commit
a9fde5bf88
19 changed files with 411 additions and 280 deletions
|
|
@ -5,7 +5,7 @@ from .needs_auth.OpenaiAPI import OpenaiAPI
|
|||
|
||||
class Jmuz(OpenaiAPI):
|
||||
label = "Jmuz"
|
||||
url = "https://discord.gg/qXfu24JmsB"
|
||||
url = "https://discord.gg/Ew6JzjA2NR"
|
||||
login_url = None
|
||||
api_base = "https://jmuz.me/gpt/api/v2"
|
||||
api_key = "prod"
|
||||
|
|
@ -18,12 +18,14 @@ class Jmuz(OpenaiAPI):
|
|||
default_model = "gpt-4o"
|
||||
model_aliases = {
|
||||
"gemini": "gemini-exp",
|
||||
"deepseek-chat": "deepseek-2.5",
|
||||
"qwq-32b": "qwq-32b-preview"
|
||||
"gemini-1.5-pro": "gemini-pro",
|
||||
"gemini-1.5-flash": "gemini-thinking",
|
||||
"deepseek-chat": "deepseek-v3",
|
||||
"qwq-32b": "qwq-32b-preview",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_models(cls):
|
||||
def get_models(cls, **kwargs):
|
||||
if not cls.models:
|
||||
cls.models = super().get_models(api_key=cls.api_key, api_base=cls.api_base)
|
||||
return cls.models
|
||||
|
|
@ -47,6 +49,7 @@ class Jmuz(OpenaiAPI):
|
|||
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36"
|
||||
}
|
||||
started = False
|
||||
buffer = ""
|
||||
async for chunk in super().create_async_generator(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
|
@ -56,10 +59,25 @@ class Jmuz(OpenaiAPI):
|
|||
headers=headers,
|
||||
**kwargs
|
||||
):
|
||||
if isinstance(chunk, str) and cls.url in chunk:
|
||||
if isinstance(chunk, str):
|
||||
buffer += chunk
|
||||
if "Join for free".startswith(buffer) or buffer.startswith("Join for free"):
|
||||
if buffer.endswith("\n"):
|
||||
buffer = ""
|
||||
continue
|
||||
if isinstance(chunk, str) and not started:
|
||||
chunk = chunk.lstrip()
|
||||
if chunk:
|
||||
if "https://discord.gg/".startswith(buffer) or "https://discord.gg/" in buffer:
|
||||
if "..." in buffer:
|
||||
buffer = ""
|
||||
continue
|
||||
if "o1-preview".startswith(buffer) or buffer.startswith("o1-preview"):
|
||||
if "\n" in buffer:
|
||||
buffer = ""
|
||||
continue
|
||||
if not started:
|
||||
buffer = buffer.lstrip()
|
||||
if buffer:
|
||||
started = True
|
||||
yield buffer
|
||||
buffer = ""
|
||||
else:
|
||||
yield chunk
|
||||
|
|
|
|||
|
|
@ -3,14 +3,23 @@ from __future__ import annotations
|
|||
import json
|
||||
import random
|
||||
import requests
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import quote_plus
|
||||
from typing import Optional
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from .helper import filter_none
|
||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..typing import AsyncResult, Messages, ImagesType
|
||||
from ..image import to_data_uri
|
||||
from ..requests.raise_for_status import raise_for_status
|
||||
from ..typing import AsyncResult, Messages
|
||||
from ..image import ImageResponse
|
||||
from ..requests.aiohttp import get_connector
|
||||
from ..providers.response import ImageResponse, FinishReason, Usage
|
||||
|
||||
DEFAULT_HEADERS = {
|
||||
'Accept': '*/*',
|
||||
'Accept-Language': 'en-US,en;q=0.9',
|
||||
'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
|
||||
}
|
||||
|
||||
class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
label = "Pollinations AI"
|
||||
|
|
@ -21,24 +30,18 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
supports_system_message = True
|
||||
supports_message_history = True
|
||||
|
||||
# API endpoints base
|
||||
api_base = "https://text.pollinations.ai/openai"
|
||||
|
||||
# API endpoints
|
||||
text_api_endpoint = "https://text.pollinations.ai/"
|
||||
text_api_endpoint = "https://text.pollinations.ai/openai"
|
||||
image_api_endpoint = "https://image.pollinations.ai/"
|
||||
|
||||
# Models configuration
|
||||
default_model = "openai"
|
||||
default_image_model = "flux"
|
||||
|
||||
image_models = []
|
||||
models = []
|
||||
|
||||
additional_models_image = ["midjourney", "dall-e-3"]
|
||||
additional_models_text = ["claude", "karma", "command-r", "llamalight", "mistral-large", "sur", "sur-mistral"]
|
||||
default_vision_model = "gpt-4o"
|
||||
extra_image_models = ["midjourney", "dall-e-3", "flux-pro", "flux-realism", "flux-cablyai", "flux-anime", "flux-3d"]
|
||||
vision_models = [default_vision_model, "gpt-4o-mini"]
|
||||
extra_text_models = [*vision_models, "claude", "karma", "command-r", "llamalight", "mistral-large", "sur", "sur-mistral", "any-dark"]
|
||||
model_aliases = {
|
||||
"gpt-4o": default_model,
|
||||
"qwen-2-72b": "qwen",
|
||||
"qwen-2.5-coder-32b": "qwen-coder",
|
||||
"llama-3.3-70b": "llama",
|
||||
|
|
@ -50,22 +53,17 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
"deepseek-chat": "deepseek",
|
||||
"llama-3.2-3b": "llamalight",
|
||||
}
|
||||
text_models = []
|
||||
|
||||
@classmethod
|
||||
def get_models(cls, **kwargs):
|
||||
# Initialize model lists if not exists
|
||||
if not hasattr(cls, 'image_models'):
|
||||
cls.image_models = []
|
||||
if not hasattr(cls, 'text_models'):
|
||||
cls.text_models = []
|
||||
|
||||
# Fetch image models if not cached
|
||||
if not cls.image_models:
|
||||
url = "https://image.pollinations.ai/models"
|
||||
response = requests.get(url)
|
||||
raise_for_status(response)
|
||||
cls.image_models = response.json()
|
||||
cls.image_models.extend(cls.additional_models_image)
|
||||
cls.image_models.extend(cls.extra_image_models)
|
||||
|
||||
# Fetch text models if not cached
|
||||
if not cls.text_models:
|
||||
|
|
@ -73,7 +71,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
response = requests.get(url)
|
||||
raise_for_status(response)
|
||||
cls.text_models = [model.get("name") for model in response.json()]
|
||||
cls.text_models.extend(cls.additional_models_text)
|
||||
cls.text_models.extend(cls.extra_text_models)
|
||||
|
||||
# Return combined models
|
||||
return cls.text_models + cls.image_models
|
||||
|
|
@ -94,22 +92,27 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
enhance: bool = False,
|
||||
safe: bool = False,
|
||||
# Text specific parameters
|
||||
temperature: float = 0.5,
|
||||
presence_penalty: float = 0,
|
||||
images: ImagesType = None,
|
||||
temperature: float = None,
|
||||
presence_penalty: float = None,
|
||||
top_p: float = 1,
|
||||
frequency_penalty: float = 0,
|
||||
stream: bool = False,
|
||||
frequency_penalty: float = None,
|
||||
response_format: Optional[dict] = None,
|
||||
cache: bool = False,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
if images is not None and not model:
|
||||
model = cls.default_vision_model
|
||||
model = cls.get_model(model)
|
||||
if not cache and seed is None:
|
||||
seed = random.randint(0, 100000)
|
||||
|
||||
# Check if models
|
||||
# Image generation
|
||||
if model in cls.image_models:
|
||||
async for result in cls._generate_image(
|
||||
yield await cls._generate_image(
|
||||
model=model,
|
||||
messages=messages,
|
||||
prompt=prompt,
|
||||
prompt=messages[-1]["content"] if prompt is None else prompt,
|
||||
proxy=proxy,
|
||||
width=width,
|
||||
height=height,
|
||||
|
|
@ -118,19 +121,21 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
private=private,
|
||||
enhance=enhance,
|
||||
safe=safe
|
||||
):
|
||||
yield result
|
||||
)
|
||||
else:
|
||||
# Text generation
|
||||
async for result in cls._generate_text(
|
||||
model=model,
|
||||
messages=messages,
|
||||
images=images,
|
||||
proxy=proxy,
|
||||
temperature=temperature,
|
||||
presence_penalty=presence_penalty,
|
||||
top_p=top_p,
|
||||
frequency_penalty=frequency_penalty,
|
||||
stream=stream
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
cache=cache,
|
||||
):
|
||||
yield result
|
||||
|
||||
|
|
@ -138,7 +143,6 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
async def _generate_image(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
prompt: str,
|
||||
proxy: str,
|
||||
width: int,
|
||||
|
|
@ -148,16 +152,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
private: bool,
|
||||
enhance: bool,
|
||||
safe: bool
|
||||
) -> AsyncResult:
|
||||
if seed is None:
|
||||
seed = random.randint(0, 10000)
|
||||
|
||||
headers = {
|
||||
'Accept': '*/*',
|
||||
'Accept-Language': 'en-US,en;q=0.9',
|
||||
'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
|
||||
}
|
||||
|
||||
) -> ImageResponse:
|
||||
params = {
|
||||
"seed": seed,
|
||||
"width": width,
|
||||
|
|
@ -168,42 +163,47 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
"enhance": enhance,
|
||||
"safe": safe
|
||||
}
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
async with ClientSession(headers=headers) as session:
|
||||
prompt = messages[-1]["content"] if prompt is None else prompt
|
||||
param_string = "&".join(f"{k}={v}" for k, v in params.items())
|
||||
url = f"{cls.image_api_endpoint}/prompt/{quote(prompt)}?{param_string}"
|
||||
|
||||
async with session.head(url, proxy=proxy) as response:
|
||||
if response.status == 200:
|
||||
image_response = ImageResponse(images=url, alt=prompt)
|
||||
yield image_response
|
||||
params = {k: json.dumps(v) if isinstance(v, bool) else v for k, v in params.items() if v is not None}
|
||||
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
|
||||
async with session.head(f"{cls.image_api_endpoint}prompt/{quote_plus(prompt)}", params=params) as response:
|
||||
await raise_for_status(response)
|
||||
return ImageResponse(str(response.url), prompt)
|
||||
|
||||
@classmethod
|
||||
async def _generate_text(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
images: Optional[ImagesType],
|
||||
proxy: str,
|
||||
temperature: float,
|
||||
presence_penalty: float,
|
||||
top_p: float,
|
||||
frequency_penalty: float,
|
||||
stream: bool,
|
||||
seed: Optional[int] = None
|
||||
response_format: Optional[dict],
|
||||
seed: Optional[int],
|
||||
cache: bool
|
||||
) -> AsyncResult:
|
||||
headers = {
|
||||
"accept": "*/*",
|
||||
"accept-language": "en-US,en;q=0.9",
|
||||
"content-type": "application/json",
|
||||
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
|
||||
jsonMode = False
|
||||
if response_format is not None and "type" in response_format:
|
||||
if response_format["type"] == "json_object":
|
||||
jsonMode = True
|
||||
|
||||
if images is not None and messages:
|
||||
last_message = messages[-1].copy()
|
||||
last_message["content"] = [
|
||||
*[{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": to_data_uri(image)}
|
||||
} for image, _ in images],
|
||||
{
|
||||
"type": "text",
|
||||
"text": messages[-1]["content"]
|
||||
}
|
||||
]
|
||||
messages[-1] = last_message
|
||||
|
||||
if seed is None:
|
||||
seed = random.randint(0, 10000)
|
||||
|
||||
async with ClientSession(headers=headers) as session:
|
||||
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
|
||||
data = {
|
||||
"messages": messages,
|
||||
"model": model,
|
||||
|
|
@ -211,42 +211,33 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
"presence_penalty": presence_penalty,
|
||||
"top_p": top_p,
|
||||
"frequency_penalty": frequency_penalty,
|
||||
"jsonMode": False,
|
||||
"stream": stream,
|
||||
"jsonMode": jsonMode,
|
||||
"stream": False, # To get more informations like Usage and FinishReason
|
||||
"seed": seed,
|
||||
"cache": False
|
||||
"cache": cache
|
||||
}
|
||||
|
||||
async with session.post(cls.text_api_endpoint, json=data, proxy=proxy) as response:
|
||||
response.raise_for_status()
|
||||
async for chunk in response.content:
|
||||
if chunk:
|
||||
decoded_chunk = chunk.decode()
|
||||
|
||||
# Skip [DONE].
|
||||
async with session.post(cls.text_api_endpoint, json=filter_none(**data)) as response:
|
||||
await raise_for_status(response)
|
||||
async for line in response.content:
|
||||
decoded_chunk = line.decode(errors="replace")
|
||||
# If [DONE].
|
||||
if "data: [DONE]" in decoded_chunk:
|
||||
continue
|
||||
|
||||
# Processing plain text
|
||||
if not decoded_chunk.startswith("data:"):
|
||||
clean_text = decoded_chunk.strip()
|
||||
if clean_text:
|
||||
yield clean_text
|
||||
continue
|
||||
|
||||
break
|
||||
# Processing JSON format
|
||||
try:
|
||||
# Remove the prefix “data: “ and parse JSON
|
||||
json_str = decoded_chunk.replace("data:", "").strip()
|
||||
json_response = json.loads(json_str)
|
||||
|
||||
if "choices" in json_response and json_response["choices"]:
|
||||
if "delta" in json_response["choices"][0]:
|
||||
content = json_response["choices"][0]["delta"].get("content")
|
||||
if content:
|
||||
# Remove escaped slashes before parentheses
|
||||
clean_content = content.replace("\\(", "(").replace("\\)", ")")
|
||||
yield clean_content
|
||||
data = json.loads(json_str)
|
||||
choice = data["choices"][0]
|
||||
if "usage" in data:
|
||||
yield Usage(**data["usage"])
|
||||
if "message" in choice and "content" in choice["message"] and choice["message"]["content"]:
|
||||
yield choice["message"]["content"].replace("\\(", "(").replace("\\)", ")")
|
||||
elif "delta" in choice and "content" in choice["delta"] and choice["delta"]["content"]:
|
||||
yield choice["delta"]["content"].replace("\\(", "(").replace("\\)", ")")
|
||||
if "finish_reason" in choice and choice["finish_reason"] is not None:
|
||||
yield FinishReason(choice["finish_reason"])
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
# If JSON could not be parsed, skip
|
||||
yield decoded_chunk.strip()
|
||||
continue
|
||||
|
|
@ -18,6 +18,7 @@ class Qwen_QVQ_72B(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
|
||||
default_model = "qwen-qvq-72b-preview"
|
||||
models = [default_model]
|
||||
vision_models = models
|
||||
model_aliases = {"qwq-32b": default_model}
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -33,12 +33,18 @@ class HuggingSpace(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
def get_models(cls, **kwargs) -> list[str]:
|
||||
if not cls.models:
|
||||
models = []
|
||||
image_models = []
|
||||
vision_models = []
|
||||
for provider in cls.providers:
|
||||
models.extend(provider.get_models(**kwargs))
|
||||
models.extend(provider.model_aliases.keys())
|
||||
image_models.extend(provider.image_models)
|
||||
vision_models.extend(provider.vision_models)
|
||||
models = list(set(models))
|
||||
models.sort()
|
||||
cls.models = models
|
||||
cls.image_models = list(set(image_models))
|
||||
cls.vision_models = list(set(vision_models))
|
||||
return cls.models
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -9,15 +9,19 @@ from ...typing import AsyncResult, Messages
|
|||
class Ollama(OpenaiAPI):
|
||||
label = "Ollama"
|
||||
url = "https://ollama.com"
|
||||
login_url = None
|
||||
needs_auth = False
|
||||
working = True
|
||||
|
||||
@classmethod
|
||||
def get_models(cls):
|
||||
def get_models(cls, api_base: str = None, **kwargs):
|
||||
if not cls.models:
|
||||
if api_base is None:
|
||||
host = os.getenv("OLLAMA_HOST", "127.0.0.1")
|
||||
port = os.getenv("OLLAMA_PORT", "11434")
|
||||
url = f"http://{host}:{port}/api/tags"
|
||||
else:
|
||||
url = api_base.replace("/v1", "/api/tags")
|
||||
models = requests.get(url).json()["models"]
|
||||
cls.models = [model["name"] for model in models]
|
||||
cls.default_model = cls.models[0]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import requests
|
||||
import base64
|
||||
from typing import AsyncIterator
|
||||
|
||||
try:
|
||||
from curl_cffi.requests import Session, CurlMime
|
||||
|
|
@ -8,21 +13,22 @@ try:
|
|||
except ImportError:
|
||||
has_curl_cffi = False
|
||||
|
||||
from ..base_provider import ProviderModelMixin, AbstractProvider
|
||||
from ..base_provider import ProviderModelMixin, AsyncAuthedProvider, AuthResult
|
||||
from ..helper import format_prompt
|
||||
from ...typing import CreateResult, Messages, Cookies
|
||||
from ...errors import MissingRequirementsError
|
||||
from ...typing import AsyncResult, Messages, Cookies, ImagesType
|
||||
from ...errors import MissingRequirementsError, MissingAuthError, ResponseError
|
||||
from ...image import to_bytes
|
||||
from ...requests import get_args_from_nodriver, DEFAULT_HEADERS
|
||||
from ...requests.raise_for_status import raise_for_status
|
||||
from ...providers.response import JsonConversation, ImageResponse, Sources
|
||||
from ...providers.response import JsonConversation, ImageResponse, Sources, TitleGeneration, Reasoning, RequestLogin
|
||||
from ...cookies import get_cookies
|
||||
from ... import debug
|
||||
|
||||
class Conversation(JsonConversation):
|
||||
def __init__(self, conversation_id: str, message_id: str):
|
||||
self.conversation_id: str = conversation_id
|
||||
self.message_id: str = message_id
|
||||
def __init__(self, models: dict):
|
||||
self.models: dict = models
|
||||
|
||||
class HuggingChat(AbstractProvider, ProviderModelMixin):
|
||||
class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
|
||||
url = "https://huggingface.co/chat"
|
||||
|
||||
working = True
|
||||
|
|
@ -32,11 +38,11 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
|
|||
default_model = "Qwen/Qwen2.5-72B-Instruct"
|
||||
default_image_model = "black-forest-labs/FLUX.1-dev"
|
||||
image_models = [
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
default_image_model,
|
||||
"black-forest-labs/FLUX.1-schnell",
|
||||
]
|
||||
models = [
|
||||
'Qwen/Qwen2.5-Coder-32B-Instruct',
|
||||
fallback_models = [
|
||||
default_model,
|
||||
'meta-llama/Llama-3.3-70B-Instruct',
|
||||
'CohereForAI/c4ai-command-r-plus-08-2024',
|
||||
'Qwen/QwQ-32B-Preview',
|
||||
|
|
@ -64,57 +70,86 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def create_completion(
|
||||
def get_models(cls):
|
||||
if not cls.models:
|
||||
try:
|
||||
text = requests.get(cls.url).text
|
||||
text = re.sub(r',parameters:{[^}]+?}', '', text)
|
||||
text = re.search(r'models:(\[.+?\]),oldModels:', text).group(1)
|
||||
text = text.replace('void 0', 'null')
|
||||
def add_quotation_mark(match):
|
||||
return f'{match.group(1)}"{match.group(2)}":'
|
||||
text = re.sub(r'([{,])([A-Za-z0-9_]+?):', add_quotation_mark, text)
|
||||
models = json.loads(text)
|
||||
cls.text_models = [model["id"] for model in models]
|
||||
cls.models = cls.text_models + cls.image_models
|
||||
cls.vision_models = [model["id"] for model in models if model["multimodal"]]
|
||||
except Exception as e:
|
||||
debug.log(f"HuggingChat: Error reading models: {type(e).__name__}: {e}")
|
||||
cls.models = [*cls.fallback_models]
|
||||
return cls.models
|
||||
|
||||
@classmethod
|
||||
async def on_auth_async(cls, cookies: Cookies = None, proxy: str = None, **kwargs) -> AsyncIterator:
|
||||
if cookies is None:
|
||||
cookies = get_cookies("huggingface.co", single_browser=True)
|
||||
if "hf-chat" in cookies:
|
||||
yield AuthResult(
|
||||
cookies=cookies,
|
||||
impersonate="chrome",
|
||||
headers=DEFAULT_HEADERS
|
||||
)
|
||||
return
|
||||
login_url = os.environ.get("G4F_LOGIN_URL")
|
||||
if login_url:
|
||||
yield RequestLogin(cls.__name__, login_url)
|
||||
yield AuthResult(
|
||||
**await get_args_from_nodriver(
|
||||
cls.url,
|
||||
proxy=proxy,
|
||||
wait_for='form[action="/chat/logout"]'
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def create_authed(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
stream: bool,
|
||||
auth_result: AuthResult,
|
||||
prompt: str = None,
|
||||
images: ImagesType = None,
|
||||
return_conversation: bool = False,
|
||||
conversation: Conversation = None,
|
||||
web_search: bool = False,
|
||||
cookies: Cookies = None,
|
||||
**kwargs
|
||||
) -> CreateResult:
|
||||
) -> AsyncResult:
|
||||
if not has_curl_cffi:
|
||||
raise MissingRequirementsError('Install "curl_cffi" package | pip install -U curl_cffi')
|
||||
model = cls.get_model(model)
|
||||
if cookies is None:
|
||||
cookies = get_cookies("huggingface.co")
|
||||
|
||||
session = Session(cookies=cookies)
|
||||
session.headers = {
|
||||
'accept': '*/*',
|
||||
'accept-language': 'en',
|
||||
'cache-control': 'no-cache',
|
||||
'origin': 'https://huggingface.co',
|
||||
'pragma': 'no-cache',
|
||||
'priority': 'u=1, i',
|
||||
'referer': 'https://huggingface.co/chat/',
|
||||
'sec-ch-ua': '"Not)A;Brand";v="99", "Google Chrome";v="127", "Chromium";v="127"',
|
||||
'sec-ch-ua-mobile': '?0',
|
||||
'sec-ch-ua-platform': '"macOS"',
|
||||
'sec-fetch-dest': 'empty',
|
||||
'sec-fetch-mode': 'cors',
|
||||
'sec-fetch-site': 'same-origin',
|
||||
'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36',
|
||||
}
|
||||
session = Session(**auth_result.get_dict())
|
||||
|
||||
if conversation is None:
|
||||
if conversation is None or not hasattr(conversation, "models"):
|
||||
conversation = Conversation({})
|
||||
|
||||
if model not in conversation.models:
|
||||
conversationId = cls.create_conversation(session, model)
|
||||
messageId = cls.fetch_message_id(session, conversationId)
|
||||
conversation = Conversation(conversationId, messageId)
|
||||
conversation.models[model] = {"conversationId": conversationId, "messageId": messageId}
|
||||
if return_conversation:
|
||||
yield conversation
|
||||
inputs = format_prompt(messages)
|
||||
else:
|
||||
conversation.message_id = cls.fetch_message_id(session, conversation.conversation_id)
|
||||
conversationId = conversation.models[model]["conversationId"]
|
||||
conversation.models[model]["messageId"] = cls.fetch_message_id(session, conversationId)
|
||||
inputs = messages[-1]["content"]
|
||||
|
||||
debug.log(f"Use conversation: {conversation.conversation_id} Use message: {conversation.message_id}")
|
||||
debug.log(f"Use: {json.dumps(conversation.models[model])}")
|
||||
|
||||
settings = {
|
||||
"inputs": inputs,
|
||||
"id": conversation.message_id,
|
||||
"id": conversation.models[model]["messageId"],
|
||||
"is_retry": False,
|
||||
"is_continue": False,
|
||||
"web_search": web_search,
|
||||
|
|
@ -123,34 +158,27 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
|
|||
|
||||
headers = {
|
||||
'accept': '*/*',
|
||||
'accept-language': 'en',
|
||||
'cache-control': 'no-cache',
|
||||
'origin': 'https://huggingface.co',
|
||||
'pragma': 'no-cache',
|
||||
'priority': 'u=1, i',
|
||||
'referer': f'https://huggingface.co/chat/conversation/{conversation.conversation_id}',
|
||||
'sec-ch-ua': '"Not)A;Brand";v="99", "Google Chrome";v="127", "Chromium";v="127"',
|
||||
'sec-ch-ua-mobile': '?0',
|
||||
'sec-ch-ua-platform': '"macOS"',
|
||||
'sec-fetch-dest': 'empty',
|
||||
'sec-fetch-mode': 'cors',
|
||||
'sec-fetch-site': 'same-origin',
|
||||
'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36',
|
||||
'referer': f'https://huggingface.co/chat/conversation/{conversationId}',
|
||||
}
|
||||
|
||||
data = CurlMime()
|
||||
data.addpart('data', data=json.dumps(settings, separators=(',', ':')))
|
||||
if images is not None:
|
||||
for image, filename in images:
|
||||
data.addpart(
|
||||
"files",
|
||||
filename=f"base64;{filename}",
|
||||
data=base64.b64encode(to_bytes(image))
|
||||
)
|
||||
|
||||
response = session.post(
|
||||
f'https://huggingface.co/chat/conversation/{conversation.conversation_id}',
|
||||
cookies=session.cookies,
|
||||
f'https://huggingface.co/chat/conversation/{conversationId}',
|
||||
headers=headers,
|
||||
multipart=data,
|
||||
stream=True
|
||||
)
|
||||
raise_for_status(response)
|
||||
|
||||
full_response = ""
|
||||
sources = None
|
||||
for line in response.iter_lines():
|
||||
if not line:
|
||||
|
|
@ -163,21 +191,20 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
|
|||
if "type" not in line:
|
||||
raise RuntimeError(f"Response: {line}")
|
||||
elif line["type"] == "stream":
|
||||
token = line["token"].replace('\u0000', '')
|
||||
full_response += token
|
||||
if stream:
|
||||
yield token
|
||||
yield line["token"].replace('\u0000', '')
|
||||
elif line["type"] == "finalAnswer":
|
||||
break
|
||||
elif line["type"] == "file":
|
||||
url = f"https://huggingface.co/chat/conversation/{conversation.conversation_id}/output/{line['sha']}"
|
||||
yield ImageResponse(url, alt=messages[-1]["content"], options={"cookies": cookies})
|
||||
url = f"https://huggingface.co/chat/conversation/{conversationId}/output/{line['sha']}"
|
||||
prompt = messages[-1]["content"] if prompt is None else prompt
|
||||
yield ImageResponse(url, alt=prompt, options={"cookies": auth_result.cookies})
|
||||
elif line["type"] == "webSearch" and "sources" in line:
|
||||
sources = Sources(line["sources"])
|
||||
elif line["type"] == "title":
|
||||
yield TitleGeneration(line["title"])
|
||||
elif line["type"] == "reasoning":
|
||||
yield Reasoning(line.get("token"), line.get("status"))
|
||||
|
||||
full_response = full_response.replace('<|im_end|', '').strip()
|
||||
if not stream:
|
||||
yield full_response
|
||||
if sources is not None:
|
||||
yield sources
|
||||
|
||||
|
|
@ -189,8 +216,9 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
|
|||
'model': model,
|
||||
}
|
||||
response = session.post('https://huggingface.co/chat/conversation', json=json_data)
|
||||
if response.status_code == 401:
|
||||
raise MissingAuthError(response.text)
|
||||
raise_for_status(response)
|
||||
|
||||
return response.json().get('conversationId')
|
||||
|
||||
@classmethod
|
||||
|
|
@ -215,6 +243,11 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
|
|||
if not json_data:
|
||||
raise RuntimeError("Failed to parse response data")
|
||||
|
||||
if json_data["nodes"][-1]["type"] == "error":
|
||||
if json_data["nodes"][-1]["status"] == 403:
|
||||
raise MissingAuthError(json_data["nodes"][-1]["error"]["message"])
|
||||
raise ResponseError(json.dumps(json_data["nodes"][-1]))
|
||||
|
||||
data = json_data["nodes"][1]["data"]
|
||||
keys = data[data[0]["messages"]]
|
||||
message_keys = data[keys[-1]]
|
||||
|
|
|
|||
|
|
@ -143,7 +143,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
else:
|
||||
is_special = True
|
||||
debug.log(f"Special token: {is_special}")
|
||||
yield FinishReason("stop" if is_special else "length", actions=["variant"] if is_special else ["continue", "variant"])
|
||||
yield FinishReason("stop" if is_special else "length")
|
||||
else:
|
||||
if response.headers["content-type"].startswith("image/"):
|
||||
base64_data = base64.b64encode(b"".join([chunk async for chunk in response.iter_content()]))
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||
|
||||
from .OpenaiAPI import OpenaiAPI
|
||||
from .HuggingChat import HuggingChat
|
||||
from ...providers.types import Messages
|
||||
|
||||
class HuggingFaceAPI(OpenaiAPI):
|
||||
label = "HuggingFace (Inference API)"
|
||||
|
|
@ -11,6 +12,23 @@ class HuggingFaceAPI(OpenaiAPI):
|
|||
working = True
|
||||
default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
|
||||
default_vision_model = default_model
|
||||
models = [
|
||||
*HuggingChat.models
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_models(cls, **kwargs):
|
||||
HuggingChat.get_models()
|
||||
cls.models = HuggingChat.text_models
|
||||
cls.vision_models = HuggingChat.vision_models
|
||||
return cls.models
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
api_base: str = None,
|
||||
**kwargs
|
||||
):
|
||||
if api_base is None:
|
||||
api_base = f"https://api-inference.huggingface.co/models/{model}/v1"
|
||||
async for chunk in super().create_async_generator(model, messages, api_base=api_base, **kwargs):
|
||||
yield chunk
|
||||
|
|
@ -73,10 +73,11 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
|
|||
raise MissingAuthError('Add a "api_key"')
|
||||
if api_base is None:
|
||||
api_base = cls.api_base
|
||||
if images is not None:
|
||||
if images is not None and messages:
|
||||
if not model and hasattr(cls, "default_vision_model"):
|
||||
model = cls.default_vision_model
|
||||
messages[-1]["content"] = [
|
||||
last_message = messages[-1].copy()
|
||||
last_message["content"] = [
|
||||
*[{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": to_data_uri(image)}
|
||||
|
|
@ -86,6 +87,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
|
|||
"text": messages[-1]["content"]
|
||||
}
|
||||
]
|
||||
messages[-1] = last_message
|
||||
async with StreamSession(
|
||||
proxy=proxy,
|
||||
headers=cls.get_headers(stream, api_key, headers),
|
||||
|
|
@ -106,10 +108,10 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
|
|||
if api_endpoint is None:
|
||||
api_endpoint = f"{api_base.rstrip('/')}/chat/completions"
|
||||
async with session.post(api_endpoint, json=data) as response:
|
||||
await raise_for_status(response)
|
||||
if not stream:
|
||||
if not stream or response.headers.get("content-type") == "application/json":
|
||||
data = await response.json()
|
||||
cls.raise_error(data)
|
||||
await raise_for_status(response)
|
||||
choice = data["choices"][0]
|
||||
if "content" in choice["message"] and choice["message"]["content"]:
|
||||
yield choice["message"]["content"].strip()
|
||||
|
|
@ -117,10 +119,11 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
|
|||
yield ToolCalls(choice["message"]["tool_calls"])
|
||||
if "usage" in data:
|
||||
yield Usage(**data["usage"])
|
||||
finish = cls.read_finish_reason(choice)
|
||||
if finish is not None:
|
||||
yield finish
|
||||
if "finish_reason" in choice and choice["finish_reason"] is not None:
|
||||
yield FinishReason(choice["finish_reason"])
|
||||
return
|
||||
else:
|
||||
await raise_for_status(response)
|
||||
first = True
|
||||
async for line in response.iter_lines():
|
||||
if line.startswith(b"data: "):
|
||||
|
|
@ -137,15 +140,9 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
|
|||
if delta:
|
||||
first = False
|
||||
yield delta
|
||||
finish = cls.read_finish_reason(choice)
|
||||
if finish is not None:
|
||||
yield finish
|
||||
break
|
||||
|
||||
@staticmethod
|
||||
def read_finish_reason(choice: dict) -> Optional[FinishReason]:
|
||||
if "finish_reason" in choice and choice["finish_reason"] is not None:
|
||||
return FinishReason(choice["finish_reason"])
|
||||
yield FinishReason(choice["finish_reason"])
|
||||
break
|
||||
|
||||
@classmethod
|
||||
def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
|
||||
|
|
|
|||
|
|
@ -495,8 +495,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
"headers": cls._headers,
|
||||
"web_search": web_search,
|
||||
})
|
||||
actions = ["variant", "continue"] if conversation.finish_reason == "max_tokens" else ["variant"]
|
||||
yield FinishReason(conversation.finish_reason, actions=actions)
|
||||
yield FinishReason(conversation.finish_reason)
|
||||
|
||||
@classmethod
|
||||
async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: Conversation, sources: Sources) -> AsyncIterator:
|
||||
|
|
|
|||
|
|
@ -1,61 +1,50 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from ...typing import CreateResult, Messages
|
||||
from ..helper import filter_none
|
||||
from .OpenaiAPI import OpenaiAPI
|
||||
|
||||
models = {
|
||||
"theb-ai": "TheB.AI",
|
||||
"gpt-3.5-turbo": "GPT-3.5",
|
||||
"gpt-3.5-turbo-16k": "GPT-3.5-16K",
|
||||
"gpt-4-turbo": "GPT-4 Turbo",
|
||||
"gpt-4": "GPT-4",
|
||||
"gpt-4-32k": "GPT-4 32K",
|
||||
"claude-2": "Claude 2",
|
||||
"claude-1": "Claude",
|
||||
"claude-1-100k": "Claude 100K",
|
||||
"claude-instant-1": "Claude Instant",
|
||||
"claude-instant-1-100k": "Claude Instant 100K",
|
||||
"palm-2": "PaLM 2",
|
||||
"palm-2-codey": "Codey",
|
||||
"vicuna-13b-v1.5": "Vicuna v1.5 13B",
|
||||
"claude-3.5-sonnet": "Claude",
|
||||
"llama-2-7b-chat": "Llama 2 7B",
|
||||
"llama-2-13b-chat": "Llama 2 13B",
|
||||
"llama-2-70b-chat": "Llama 2 70B",
|
||||
"code-llama-7b": "Code Llama 7B",
|
||||
"code-llama-13b": "Code Llama 13B",
|
||||
"code-llama-34b": "Code Llama 34B",
|
||||
"qwen-7b-chat": "Qwen 7B"
|
||||
"qwen-2-72b": "Qwen"
|
||||
}
|
||||
|
||||
class ThebApi(OpenaiAPI):
|
||||
label = "TheB.AI API"
|
||||
url = "https://theb.ai"
|
||||
login_url = "https://beta.theb.ai/home"
|
||||
working = True
|
||||
api_base = "https://api.theb.ai/v1"
|
||||
needs_auth = True
|
||||
default_model = "gpt-3.5-turbo"
|
||||
models = list(models)
|
||||
default_model = "theb-ai"
|
||||
fallback_models = list(models)
|
||||
|
||||
@classmethod
|
||||
def create_async_generator(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
temperature: float = 1,
|
||||
top_p: float = 1,
|
||||
temperature: float = None,
|
||||
top_p: float = None,
|
||||
**kwargs
|
||||
) -> CreateResult:
|
||||
if "auth" in kwargs:
|
||||
kwargs["api_key"] = kwargs["auth"]
|
||||
system_message = "\n".join([message["content"] for message in messages if message["role"] == "system"])
|
||||
if not system_message:
|
||||
system_message = "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-3.5 architecture."
|
||||
messages = [message for message in messages if message["role"] != "system"]
|
||||
data = {
|
||||
"model_params": {
|
||||
"system_prompt": system_message,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
"model_params": filter_none(
|
||||
system_prompt=system_message,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
)
|
||||
}
|
||||
return super().create_async_generator(model, messages, extra_data=data, **kwargs)
|
||||
|
|
|
|||
|
|
@ -376,6 +376,29 @@ body:not(.white) a:visited{
|
|||
display: flex;
|
||||
}
|
||||
|
||||
.message .reasoning_text.final:not(.hidden), .message .reasoning_title {
|
||||
margin-bottom: var(--inner-gap);
|
||||
padding-bottom: var(--inner-gap);
|
||||
border-bottom: 1px solid var(--colour-3);
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.message .reasoning_text.final {
|
||||
max-height: 1000px;
|
||||
transition: max-height 0.25s ease-in;
|
||||
}
|
||||
|
||||
.message .reasoning_text.final.hidden {
|
||||
transition: max-height 0.15s ease-out;
|
||||
max-height: 0;
|
||||
display: block;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.message .reasoning_title {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.message .user i {
|
||||
position: absolute;
|
||||
bottom: -6px;
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ let title_storage = {};
|
|||
let parameters_storage = {};
|
||||
let finish_storage = {};
|
||||
let usage_storage = {};
|
||||
let reasoning_storage = {}
|
||||
|
||||
messageInput.addEventListener("blur", () => {
|
||||
window.scrollTo(0, 0);
|
||||
|
|
@ -70,6 +71,17 @@ if (window.markdownit) {
|
|||
}
|
||||
}
|
||||
|
||||
function render_reasoning(reasoning, final = false) {
|
||||
return `<div class="reasoning_body">
|
||||
<div class="reasoning_title">
|
||||
<strong>Reasoning <i class="fa-solid fa-brain"></i>:</strong> ${escapeHtml(reasoning.status)}
|
||||
</div>
|
||||
<div class="reasoning_text${final ? " final hidden" : ""}">
|
||||
${markdown_render(reasoning.text)}
|
||||
</div>
|
||||
</div>`;
|
||||
}
|
||||
|
||||
function filter_message(text) {
|
||||
return text.replaceAll(
|
||||
/<!-- generated images start -->[\s\S]+<!-- generated images end -->/gm, ""
|
||||
|
|
@ -169,7 +181,7 @@ const get_message_el = (el) => {
|
|||
}
|
||||
|
||||
const register_message_buttons = async () => {
|
||||
document.querySelectorAll(".message .content .provider").forEach(async (el) => {
|
||||
message_box.querySelectorAll(".message .content .provider").forEach(async (el) => {
|
||||
if (!("click" in el.dataset)) {
|
||||
el.dataset.click = "true";
|
||||
const provider_forms = document.querySelector(".provider_forms");
|
||||
|
|
@ -192,7 +204,7 @@ const register_message_buttons = async () => {
|
|||
}
|
||||
});
|
||||
|
||||
document.querySelectorAll(".message .fa-xmark").forEach(async (el) => {
|
||||
message_box.querySelectorAll(".message .fa-xmark").forEach(async (el) => {
|
||||
if (!("click" in el.dataset)) {
|
||||
el.dataset.click = "true";
|
||||
el.addEventListener("click", async () => {
|
||||
|
|
@ -203,7 +215,7 @@ const register_message_buttons = async () => {
|
|||
}
|
||||
});
|
||||
|
||||
document.querySelectorAll(".message .fa-clipboard").forEach(async (el) => {
|
||||
message_box.querySelectorAll(".message .fa-clipboard").forEach(async (el) => {
|
||||
if (!("click" in el.dataset)) {
|
||||
el.dataset.click = "true";
|
||||
el.addEventListener("click", async () => {
|
||||
|
|
@ -226,7 +238,7 @@ const register_message_buttons = async () => {
|
|||
}
|
||||
});
|
||||
|
||||
document.querySelectorAll(".message .fa-file-export").forEach(async (el) => {
|
||||
message_box.querySelectorAll(".message .fa-file-export").forEach(async (el) => {
|
||||
if (!("click" in el.dataset)) {
|
||||
el.dataset.click = "true";
|
||||
el.addEventListener("click", async () => {
|
||||
|
|
@ -244,7 +256,7 @@ const register_message_buttons = async () => {
|
|||
}
|
||||
});
|
||||
|
||||
document.querySelectorAll(".message .fa-volume-high").forEach(async (el) => {
|
||||
message_box.querySelectorAll(".message .fa-volume-high").forEach(async (el) => {
|
||||
if (!("click" in el.dataset)) {
|
||||
el.dataset.click = "true";
|
||||
el.addEventListener("click", async () => {
|
||||
|
|
@ -270,7 +282,7 @@ const register_message_buttons = async () => {
|
|||
}
|
||||
});
|
||||
|
||||
document.querySelectorAll(".message .regenerate_button").forEach(async (el) => {
|
||||
message_box.querySelectorAll(".message .regenerate_button").forEach(async (el) => {
|
||||
if (!("click" in el.dataset)) {
|
||||
el.dataset.click = "true";
|
||||
el.addEventListener("click", async () => {
|
||||
|
|
@ -282,7 +294,7 @@ const register_message_buttons = async () => {
|
|||
}
|
||||
});
|
||||
|
||||
document.querySelectorAll(".message .continue_button").forEach(async (el) => {
|
||||
message_box.querySelectorAll(".message .continue_button").forEach(async (el) => {
|
||||
if (!("click" in el.dataset)) {
|
||||
el.dataset.click = "true";
|
||||
el.addEventListener("click", async () => {
|
||||
|
|
@ -297,7 +309,7 @@ const register_message_buttons = async () => {
|
|||
}
|
||||
});
|
||||
|
||||
document.querySelectorAll(".message .fa-whatsapp").forEach(async (el) => {
|
||||
message_box.querySelectorAll(".message .fa-whatsapp").forEach(async (el) => {
|
||||
if (!("click" in el.dataset)) {
|
||||
el.dataset.click = "true";
|
||||
el.addEventListener("click", async () => {
|
||||
|
|
@ -307,7 +319,7 @@ const register_message_buttons = async () => {
|
|||
}
|
||||
});
|
||||
|
||||
document.querySelectorAll(".message .fa-print").forEach(async (el) => {
|
||||
message_box.querySelectorAll(".message .fa-print").forEach(async (el) => {
|
||||
if (!("click" in el.dataset)) {
|
||||
el.dataset.click = "true";
|
||||
el.addEventListener("click", async () => {
|
||||
|
|
@ -323,6 +335,16 @@ const register_message_buttons = async () => {
|
|||
})
|
||||
}
|
||||
});
|
||||
|
||||
message_box.querySelectorAll(".message .reasoning_title").forEach(async (el) => {
|
||||
if (!("click" in el.dataset)) {
|
||||
el.dataset.click = "true";
|
||||
el.addEventListener("click", async () => {
|
||||
let text_el = el.parentElement.querySelector(".reasoning_text");
|
||||
text_el.classList[text_el.classList.contains("hidden") ? "remove" : "add"]("hidden");
|
||||
})
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const delete_conversations = async () => {
|
||||
|
|
@ -469,7 +491,7 @@ const prepare_messages = (messages, message_index = -1, do_continue = false, do_
|
|||
messages.forEach((message) => {
|
||||
message_copy = { ...message };
|
||||
if (last_message) {
|
||||
if (last_message["role"] == message["role"]) {
|
||||
if (last_message["role"] == message["role"] && message["role"] == "assistant") {
|
||||
message_copy["content"] = last_message["content"] + message_copy["content"];
|
||||
new_messages.pop();
|
||||
}
|
||||
|
|
@ -515,6 +537,7 @@ const prepare_messages = (messages, message_index = -1, do_continue = false, do_
|
|||
delete new_message.synthesize;
|
||||
delete new_message.finish;
|
||||
delete new_message.usage;
|
||||
delete new_message.reasoning;
|
||||
delete new_message.conversation;
|
||||
delete new_message.continue;
|
||||
// Append message to new messages
|
||||
|
|
@ -711,11 +734,21 @@ async function add_message_chunk(message, message_id, provider, scroll) {
|
|||
} else if (message.type == "title") {
|
||||
title_storage[message_id] = message.title;
|
||||
} else if (message.type == "login") {
|
||||
update_message(content_map, message_id, message.login, scroll);
|
||||
update_message(content_map, message_id, markdown_render(message.login), scroll);
|
||||
} else if (message.type == "finish") {
|
||||
finish_storage[message_id] = message.finish;
|
||||
} else if (message.type == "usage") {
|
||||
usage_storage[message_id] = message.usage;
|
||||
} else if (message.type == "reasoning") {
|
||||
if (!reasoning_storage[message_id]) {
|
||||
reasoning_storage[message_id] = message;
|
||||
reasoning_storage[message_id].text = "";
|
||||
} else if (message.status) {
|
||||
reasoning_storage[message_id].status = message.status;
|
||||
} else if (message.token) {
|
||||
reasoning_storage[message_id].text += message.token;
|
||||
}
|
||||
update_message(content_map, message_id, render_reasoning(reasoning_storage[message_id]), scroll);
|
||||
} else if (message.type == "parameters") {
|
||||
if (!parameters_storage[provider]) {
|
||||
parameters_storage[provider] = {};
|
||||
|
|
@ -846,6 +879,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
|
|||
title_storage[message_id],
|
||||
finish_storage[message_id],
|
||||
usage_storage[message_id],
|
||||
reasoning_storage[message_id],
|
||||
action=="continue"
|
||||
);
|
||||
delete controller_storage[message_id];
|
||||
|
|
@ -1042,6 +1076,7 @@ function merge_messages(message1, message2) {
|
|||
const load_conversation = async (conversation_id, scroll=true) => {
|
||||
let conversation = await get_conversation(conversation_id);
|
||||
let messages = conversation?.items || [];
|
||||
console.debug("Conversation:", conversation)
|
||||
|
||||
if (!conversation) {
|
||||
return;
|
||||
|
|
@ -1098,11 +1133,8 @@ const load_conversation = async (conversation_id, scroll=true) => {
|
|||
let add_buttons = [];
|
||||
// Find buttons to add
|
||||
actions = ["variant"]
|
||||
if (item.finish && item.finish.actions) {
|
||||
actions = item.finish.actions
|
||||
}
|
||||
// Add continue button if possible
|
||||
if (item.role == "assistant" && !actions.includes("continue")) {
|
||||
if (item.role == "assistant") {
|
||||
let reason = "stop";
|
||||
// Read finish reason from conversation
|
||||
if (item.finish && item.finish.reason) {
|
||||
|
|
@ -1167,7 +1199,10 @@ const load_conversation = async (conversation_id, scroll=true) => {
|
|||
</div>
|
||||
<div class="content">
|
||||
${provider}
|
||||
<div class="content_inner">${markdown_render(buffer)}</div>
|
||||
<div class="content_inner">
|
||||
${item.reasoning ? render_reasoning(item.reasoning, true): ""}
|
||||
${markdown_render(buffer)}
|
||||
</div>
|
||||
<div class="count">
|
||||
${count_words_and_tokens(buffer, next_provider?.model, completion_tokens, prompt_tokens)}
|
||||
${add_buttons.join("")}
|
||||
|
|
@ -1298,6 +1333,7 @@ const add_message = async (
|
|||
title = null,
|
||||
finish = null,
|
||||
usage = null,
|
||||
reasoning = null,
|
||||
do_continue = false
|
||||
) => {
|
||||
const conversation = await get_conversation(conversation_id);
|
||||
|
|
@ -1329,6 +1365,9 @@ const add_message = async (
|
|||
if (usage) {
|
||||
new_message.usage = usage;
|
||||
}
|
||||
if (reasoning) {
|
||||
new_message.reasoning = reasoning;
|
||||
}
|
||||
if (do_continue) {
|
||||
new_message.continue = true;
|
||||
}
|
||||
|
|
@ -1604,23 +1643,24 @@ function count_words_and_tokens(text, model, completion_tokens, prompt_tokens) {
|
|||
|
||||
function update_message(content_map, message_id, content = null, scroll = true) {
|
||||
content_map.update_timeouts.push(setTimeout(() => {
|
||||
if (!content) content = message_storage[message_id];
|
||||
html = markdown_render(content);
|
||||
if (!content) {
|
||||
content = markdown_render(message_storage[message_id]);
|
||||
let lastElement, lastIndex = null;
|
||||
for (element of ['</p>', '</code></pre>', '</p>\n</li>\n</ol>', '</li>\n</ol>', '</li>\n</ul>']) {
|
||||
const index = html.lastIndexOf(element)
|
||||
const index = content.lastIndexOf(element)
|
||||
if (index - element.length > lastIndex) {
|
||||
lastElement = element;
|
||||
lastIndex = index;
|
||||
}
|
||||
}
|
||||
if (lastIndex) {
|
||||
html = html.substring(0, lastIndex) + '<span class="cursor"></span>' + lastElement;
|
||||
content = content.substring(0, lastIndex) + '<span class="cursor"></span>' + lastElement;
|
||||
}
|
||||
}
|
||||
content_map.inner.innerHTML = content;
|
||||
if (error_storage[message_id]) {
|
||||
content_map.inner.innerHTML += markdown_render(`**An error occured:** ${error_storage[message_id]}`);
|
||||
}
|
||||
content_map.inner.innerHTML = html;
|
||||
content_map.count.innerText = count_words_and_tokens(message_storage[message_id], provider_storage[message_id]?.model);
|
||||
highlight(content_map.inner);
|
||||
if (scroll) {
|
||||
|
|
@ -2132,9 +2172,9 @@ async function read_response(response, message_id, provider, scroll) {
|
|||
function get_api_key_by_provider(provider) {
|
||||
let api_key = null;
|
||||
if (provider) {
|
||||
api_key = document.getElementById(`${provider}-api_key`)?.id || null;
|
||||
if (api_key == null) {
|
||||
api_key = document.querySelector(`.${provider}-api_key`)?.id || null;
|
||||
if (api_key == null) {
|
||||
api_key = document.getElementById(`${provider}-api_key`)?.id || null;
|
||||
}
|
||||
if (api_key) {
|
||||
api_key = appStorage.getItem(api_key);
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from ...tools.run_tools import iter_run_tools
|
|||
from ...Provider import ProviderUtils, __providers__
|
||||
from ...providers.base_provider import ProviderModelMixin
|
||||
from ...providers.retry_provider import IterListProvider
|
||||
from ...providers.response import BaseConversation, JsonConversation, FinishReason, Usage
|
||||
from ...providers.response import BaseConversation, JsonConversation, FinishReason, Usage, Reasoning
|
||||
from ...providers.response import SynthesizeData, TitleGeneration, RequestLogin, Parameters
|
||||
from ... import version, models
|
||||
from ... import ChatCompletion, get_model_and_provider
|
||||
|
|
@ -207,6 +207,8 @@ class Api:
|
|||
yield self._format_json("finish", chunk.get_dict())
|
||||
elif isinstance(chunk, Usage):
|
||||
yield self._format_json("usage", chunk.get_dict())
|
||||
elif isinstance(chunk, Reasoning):
|
||||
yield self._format_json("reasoning", token=chunk.token, status=chunk.status)
|
||||
else:
|
||||
yield self._format_json("content", str(chunk))
|
||||
if debug.logs:
|
||||
|
|
@ -219,10 +221,15 @@ class Api:
|
|||
if first:
|
||||
yield self.handle_provider(provider_handler, model)
|
||||
|
||||
def _format_json(self, response_type: str, content):
|
||||
def _format_json(self, response_type: str, content = None, **kwargs):
|
||||
if content is not None:
|
||||
return {
|
||||
'type': response_type,
|
||||
response_type: content
|
||||
response_type: content,
|
||||
}
|
||||
return {
|
||||
'type': response_type,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
def handle_provider(self, provider_handler, model):
|
||||
|
|
|
|||
|
|
@ -309,7 +309,7 @@ class Backend_Api(Api):
|
|||
return "Provider not found", 404
|
||||
return models
|
||||
|
||||
def _format_json(self, response_type: str, content) -> str:
|
||||
def _format_json(self, response_type: str, content = None, **kwargs) -> str:
|
||||
"""
|
||||
Formats and returns a JSON response.
|
||||
|
||||
|
|
@ -320,4 +320,4 @@ class Backend_Api(Api):
|
|||
Returns:
|
||||
str: A JSON formatted string.
|
||||
"""
|
||||
return json.dumps(super()._format_json(response_type, content)) + "\n"
|
||||
return json.dumps(super()._format_json(response_type, content, **kwargs)) + "\n"
|
||||
|
|
@ -340,7 +340,8 @@ class ProviderModelMixin:
|
|||
default_model: str = None
|
||||
models: list[str] = []
|
||||
model_aliases: dict[str, str] = {}
|
||||
image_models: list = None
|
||||
image_models: list = []
|
||||
vision_models: list = []
|
||||
last_model: str = None
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -89,9 +89,8 @@ class JsonMixin:
|
|||
self.__dict__ = {}
|
||||
|
||||
class FinishReason(ResponseType, JsonMixin):
|
||||
def __init__(self, reason: str, actions: list[str] = None) -> None:
|
||||
def __init__(self, reason: str) -> None:
|
||||
self.reason = reason
|
||||
self.actions = actions
|
||||
|
||||
def __str__(self) -> str:
|
||||
return ""
|
||||
|
|
@ -121,6 +120,14 @@ class TitleGeneration(ResponseType):
|
|||
def __str__(self) -> str:
|
||||
return ""
|
||||
|
||||
class Reasoning(ResponseType):
|
||||
def __init__(self, token: str = None, status: str = None) -> None:
|
||||
self.token = token
|
||||
self.status = status
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "" if self.token is None else self.token
|
||||
|
||||
class Sources(ResponseType):
|
||||
def __init__(self, sources: list[dict[str, str]]) -> None:
|
||||
self.list = []
|
||||
|
|
|
|||
|
|
@ -78,25 +78,22 @@ async def get_args_from_nodriver(
|
|||
url: str,
|
||||
proxy: str = None,
|
||||
timeout: int = 120,
|
||||
wait_for: str = None,
|
||||
cookies: Cookies = None
|
||||
) -> dict:
|
||||
if not has_nodriver:
|
||||
raise MissingRequirementsError('Install "nodriver" package | pip install -U nodriver')
|
||||
browser = await get_nodriver(proxy=proxy)
|
||||
if debug.logging:
|
||||
print(f"Open nodriver with url: {url}")
|
||||
browser = await nodriver.start(
|
||||
browser_args=None if proxy is None else [f"--proxy-server={proxy}"],
|
||||
)
|
||||
domain = urlparse(url).netloc
|
||||
if cookies is None:
|
||||
cookies = {}
|
||||
else:
|
||||
await browser.cookies.set_all(get_cookie_params_from_dict(cookies, url=url, domain=domain))
|
||||
page = await browser.get(url)
|
||||
for c in await page.send(nodriver.cdp.network.get_cookies([url])):
|
||||
cookies[c.name] = c.value
|
||||
user_agent = await page.evaluate("window.navigator.userAgent")
|
||||
await page.wait_for("body:not(.no-js)", timeout=timeout)
|
||||
if wait_for is not None:
|
||||
await page.wait_for(wait_for, timeout=timeout)
|
||||
for c in await page.send(nodriver.cdp.network.get_cookies([url])):
|
||||
cookies[c.name] = c.value
|
||||
await page.close()
|
||||
|
|
@ -120,13 +117,13 @@ def merge_cookies(cookies: Iterator[Morsel], response: Response) -> Cookies:
|
|||
|
||||
async def get_nodriver(proxy: str = None, user_data_dir = "nodriver", browser_executable_path=None, **kwargs)-> Browser:
|
||||
if not has_nodriver:
|
||||
raise MissingRequirementsError('Install "nodriver" package | pip install -U nodriver')
|
||||
raise MissingRequirementsError('Install "nodriver" and "platformdirs" package | pip install -U nodriver platformdirs')
|
||||
user_data_dir = user_config_dir(f"g4f-{user_data_dir}") if has_platformdirs else None
|
||||
if browser_executable_path is None:
|
||||
try:
|
||||
browser_executable_path = find_chrome_executable()
|
||||
except FileNotFoundError:
|
||||
# Default to Edge if Chrome is not found
|
||||
# Default to Edge if Chrome is not available.
|
||||
browser_executable_path = "C:\\Program Files (x86)\\Microsoft\\Edge\\Application\\msedge.exe"
|
||||
if not os.path.exists(browser_executable_path):
|
||||
browser_executable_path = None
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
|
|||
return
|
||||
text = await response.text()
|
||||
if message is None:
|
||||
message = "HTML content" if response.headers.get("content-type").startswith("text/html") else text
|
||||
message = "HTML content" if response.headers.get("content-type", "").startswith("text/html") else text
|
||||
if message == "HTML content":
|
||||
if response.status == 520:
|
||||
message = "Unknown error (Cloudflare)"
|
||||
|
|
@ -46,7 +46,7 @@ def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, R
|
|||
if response.ok:
|
||||
return
|
||||
if message is None:
|
||||
message = "HTML content" if response.headers.get("content-type").startswith("text/html") else response.text
|
||||
message = "HTML content" if response.headers.get("content-type", "").startswith("text/html") else response.text
|
||||
if message == "HTML content":
|
||||
if response.status_code == 520:
|
||||
message = "Unknown error (Cloudflare)"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue