Add langchain integration

hlohaus 2025-02-27 12:25:41 +01:00
parent 65265f3e51
commit 4e12f048b1
11 changed files with 133 additions and 88 deletions
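In brief: a new g4f/tools/langchain.py module plugs g4f's Client into LangChain's ChatOpenAI; OpenaiChat moves its scattered _api_key/_headers/_cookies class state onto the shared request_config; PollinationsAI simplifies its image-URL construction and adds a gemini-2.0 alias; the client stubs build usage objects uniformly via UsageModel.model_construct; and error reporting, the web UI, and tool handling get smaller fixes.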


@@ -14,7 +14,7 @@ from ..image import to_data_uri
from ..errors import ModelNotFoundError
from ..requests.raise_for_status import raise_for_status
from ..requests.aiohttp import get_connector
from ..providers.response import ImageResponse, ImagePreview, FinishReason, Usage, Reasoning
from ..providers.response import ImageResponse, ImagePreview, FinishReason, Usage
DEFAULT_HEADERS = {
'Accept': '*/*',
@@ -63,6 +63,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
"gpt-4o-mini": "claude",
"deepseek-chat": "claude-email",
"deepseek-r1": "deepseek-reasoner",
"gemini-2.0": "gemini",
"gemini-2.0-flash": "gemini",
"gemini-2.0-flash-thinking": "gemini-thinking",
@@ -208,10 +209,8 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
"enhance": str(enhance).lower(),
"safe": str(safe).lower()
}
params = {k: v for k, v in params.items() if v is not None}
query = "&".join(f"{k}={quote_plus(v)}" for k, v in params.items())
prefix = f"{model}_{seed}" if seed is not None else model
url = f"{cls.image_api_endpoint}prompt/{prefix}_{quote_plus(prompt)}?{query}"
query = "&".join(f"{k}={quote_plus(v)}" for k, v in params.items() if v is not None)
url = f"{cls.image_api_endpoint}prompt/{quote_plus(prompt)}?{query}"
yield ImagePreview(url, prompt)
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
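The rewritten image path folds the None-filter into the query-string join and drops the old {model}_{seed} prefix from the URL path. A minimal sketch of the resulting URL, with an assumed endpoint value and made-up parameters:

    from urllib.parse import quote_plus

    image_api_endpoint = "https://image.pollinations.ai/"      # assumption for cls.image_api_endpoint
    params = {"model": "flux", "seed": None, "safe": "false"}  # None entries are skipped
    prompt = "a red fox"

    query = "&".join(f"{k}={quote_plus(v)}" for k, v in params.items() if v is not None)
    url = f"{image_api_endpoint}prompt/{quote_plus(prompt)}?{query}"
    # -> https://image.pollinations.ai/prompt/a+red+fox?model=flux&safe=false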
@@ -266,7 +265,8 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
"seed": seed,
"cache": cache
})
if "gemimi" in model:
data.pop("seed")
async with session.post(cls.text_api_endpoint, json=data) as response:
await raise_for_status(response)
result = await response.json()


@@ -36,7 +36,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
messages: Messages,
**kwargs
) -> AsyncResult:
if "images" not in kwargs and "deepseek" in model or random.random() >= 0.5:
if "tools" not in kwargs and "images" not in kwargs and "deepseek" in model or random.random() >= 0.5:
try:
is_started = False
async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):
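The added "tools" check keeps tool-calling requests off the inference branch. Note the operator precedence: and binds tighter than or, so the guard is equivalent to this fully parenthesized sketch:

    import random

    def use_inference(model: str, kwargs: dict) -> bool:
        # fully parenthesized form of the guard above
        return (("tools" not in kwargs and "images" not in kwargs and "deepseek" in model)
                or random.random() >= 0.5)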


@@ -8,7 +8,7 @@ import json
import base64
import time
import random
from typing import AsyncIterator, Iterator, Optional, Generator, Dict, List
from typing import AsyncIterator, Iterator, Optional, Generator, Dict
from copy import copy
try:
@@ -104,19 +104,16 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
synthesize_content_type = "audio/mpeg"
request_config = RequestConfig()
_api_key: str = None
_headers: dict = None
_cookies: Cookies = None
_expires: int = None
@classmethod
async def on_auth_async(cls, **kwargs) -> AsyncIterator:
async for chunk in cls.login():
async def on_auth_async(cls, proxy: str = None, **kwargs) -> AsyncIterator:
async for chunk in cls.login(proxy=proxy):
yield chunk
yield AuthResult(
api_key=cls._api_key,
cookies=cls._cookies or cls.request_config.cookies or {},
headers=cls._headers or cls.request_config.headers or cls.get_default_headers(),
api_key=cls.request_config.access_token,
cookies=cls.request_config.cookies or {},
headers=cls.request_config.headers or cls.get_default_headers(),
expires=cls._expires,
proof_token=cls.request_config.proof_token,
turnstile_token=cls.request_config.turnstile_token
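For orientation: the deleted _api_key/_headers/_cookies class attributes now live on the shared request_config. A minimal sketch of the fields this diff touches (names taken from the diff; the dataclass shape is an assumption, not the project's actual definition):

    from dataclasses import dataclass

    @dataclass
    class RequestConfig:
        cookies: dict = None           # browser/HAR cookies
        headers: dict = None           # request headers, including "authorization"
        access_token: str = None       # replaces the old cls._api_key
        proof_token: list = None       # sentinel proof token
        turnstile_token: str = None    # Turnstile token
        arkose_request: object = None  # arkReq captured during nodriver auth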
@@ -306,17 +303,17 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
) as session:
image_requests = None
if not cls.needs_auth:
if cls._headers is None:
cls._create_request_args(cls._cookies)
if not cls.request_config.headers:
cls._create_request_args(cls.request_config.cookies)
async with session.get(cls.url, headers=INIT_HEADERS) as response:
cls._update_request_args(auth_result, session)
await raise_for_status(response)
else:
if cls._headers is None and getattr(auth_result, "cookies", None):
if not cls.request_config.headers and getattr(auth_result, "cookies", None):
cls._create_request_args(auth_result.cookies, auth_result.headers)
if not cls._set_api_key(getattr(auth_result, "api_key", None)):
raise MissingAuthError("Access token is not valid")
async with session.get(cls.url, headers=cls._headers) as response:
async with session.get(cls.url, headers=cls.request_config.headers) as response:
cls._update_request_args(auth_result, session)
await raise_for_status(response)
try:
@@ -331,17 +328,17 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
conversation = copy(conversation)
if getattr(auth_result, "cookies", {}).get("oai-did") != getattr(conversation, "user_id", None):
conversation = Conversation(None, str(uuid.uuid4()))
if cls._api_key is None:
if cls.request_config.access_token is None:
auto_continue = False
conversation.finish_reason = None
sources = Sources([])
while conversation.finish_reason is None:
async with session.post(
f"{cls.url}/backend-anon/sentinel/chat-requirements"
if cls._api_key is None else
if cls.request_config.access_token is None else
f"{cls.url}/backend-api/sentinel/chat-requirements",
json={"p": None if not getattr(auth_result, "proof_token", None) else get_requirements_token(getattr(auth_result, "proof_token", None))},
headers=cls._headers
headers=cls.request_config.headers
) as response:
if response.status in (401, 403):
auth_result.reset()
@@ -407,7 +404,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
messages = messages if conversation_id is None else [messages[-1]]
data["messages"] = cls.create_messages(messages, image_requests, ["search"] if web_search else None)
headers = {
**cls._headers,
**cls.request_config.headers,
"accept": "text/event-stream",
"content-type": "application/json",
"openai-sentinel-chat-requirements-token": chat_token,
@@ -420,7 +417,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
headers['openai-sentinel-turnstile-token'] = auth_result.turnstile_token
async with session.post(
f"{cls.url}/backend-anon/conversation"
if cls._api_key is None else
if cls.request_config.access_token is None else
f"{cls.url}/backend-api/conversation",
json=data,
headers=headers
@@ -550,7 +547,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
async with session.get(
f"{cls.url}/backend-api/synthesize",
params=params,
headers=cls._headers
headers=cls.request_config.headers
) as response:
await raise_for_status(response)
async for chunk in response.iter_content():
@@ -560,44 +557,29 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
async def login(
cls,
proxy: str = None,
api_key: str = None,
proof_token: str = None,
cookies: Cookies = None,
headers: dict = None,
**kwargs
) -> AsyncIterator:
if cls._expires is not None and (cls._expires - 60*10) < time.time():
cls._headers = cls._api_key = None
if cls._headers is None or headers is not None:
cls._headers = {} if headers is None else headers
if proof_token is not None:
cls.request_config.proof_token = proof_token
if cookies is not None:
cls.request_config.cookies = cookies
if api_key is not None:
cls.request_config.headers = cls.request_config.access_token = None
if cls.request_config.headers is None:
cls.request_config.headers = {}
if cls.request_config.access_token is not None:
cls._create_request_args(cls.request_config.cookies, cls.request_config.headers)
cls._set_api_key(api_key)
cls._set_api_key(cls.request_config.access_token)
else:
try:
await get_request_config(cls.request_config, proxy)
cls.request_config = await get_request_config(cls.request_config, proxy)
cls._create_request_args(cls.request_config.cookies, cls.request_config.headers)
if cls.request_config.access_token is not None or cls.needs_auth:
if not cls._set_api_key(cls.request_config.access_token):
raise NoValidHarFileError(f"Access token is not valid: {cls.request_config.access_token}")
except NoValidHarFileError:
if has_nodriver:
if cls._api_key is None:
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
yield RequestLogin(cls.label, login_url)
if cls.request_config.access_token is None:
yield RequestLogin(cls.label, os.environ.get("G4F_LOGIN_URL", ""))
await cls.nodriver_auth(proxy)
else:
raise
yield Parameters(**{
"api_key": cls._api_key,
"proof_token": cls.request_config.proof_token,
"cookies": cls.request_config.cookies,
})
@classmethod
async def nodriver_auth(cls, proxy: str = None):
@@ -615,7 +597,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
if "OpenAI-Sentinel-Turnstile-Token" in event.request.headers:
cls.request_config.turnstile_token = event.request.headers["OpenAI-Sentinel-Turnstile-Token"]
if "Authorization" in event.request.headers:
cls._api_key = event.request.headers["Authorization"].split()[-1]
cls.request_config.access_token = event.request.headers["Authorization"].split()[-1]
elif event.request.url == arkose_url:
cls.request_config.arkose_request = arkReq(
arkURL=event.request.url,
@@ -632,13 +614,13 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
await page.evaluate("document.getElementById('prompt-textarea').innerText = 'Hello'")
await page.evaluate("document.querySelector('[data-testid=\"send-button\"]').click()")
while True:
if cls._api_key is not None or not cls.needs_auth:
if cls.request_config.access_token is not None or not cls.needs_auth:
break
body = await page.evaluate("JSON.stringify(window.__remixContext)")
if body:
match = re.search(r'"accessToken":"(.*?)"', body)
if match:
cls._api_key = match.group(1)
cls.request_config.access_token = match.group(1)
break
await asyncio.sleep(1)
while True:
@@ -649,7 +631,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
cls.request_config.cookies = await page.send(get_cookies([cls.url]))
await page.close()
cls._create_request_args(cls.request_config.cookies, cls.request_config.headers, user_agent=user_agent)
cls._set_api_key(cls._api_key)
cls._set_api_key(cls.request_config.access_token)
finally:
stop_browser()
@@ -662,10 +644,10 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
@classmethod
def _create_request_args(cls, cookies: Cookies = None, headers: dict = None, user_agent: str = None):
cls._headers = cls.get_default_headers() if headers is None else headers
cls.request_config.headers = cls.get_default_headers() if headers is None else headers
if user_agent is not None:
cls._headers["user-agent"] = user_agent
cls._cookies = {} if cookies is None else cookies
cls.request_config.headers["user-agent"] = user_agent
cls.request_config.cookies = {} if cookies is None else cookies
cls._update_cookie_header()
@classmethod
@@ -673,7 +655,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
if hasattr(auth_result, "cookies"):
for c in session.cookie_jar if hasattr(session, "cookie_jar") else session.cookies.jar:
auth_result.cookies[getattr(c, "key", getattr(c, "name", ""))] = c.value
cls._cookies = auth_result.cookies
cls.request_config.cookies = auth_result.cookies
cls._update_cookie_header()
@classmethod
@@ -686,15 +668,15 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
if time.time() > cls._expires:
debug.log(f"OpenaiChat: API key is expired")
else:
cls._api_key = api_key
cls._headers["authorization"] = f"Bearer {api_key}"
cls.request_config.access_token = api_key
cls.request_config.headers["authorization"] = f"Bearer {api_key}"
return True
return False
@classmethod
def _update_cookie_header(cls):
if cls._cookies:
cls._headers["cookie"] = format_cookies(cls._cookies)
if cls.request_config.cookies:
cls.request_config.headers["cookie"] = format_cookies(cls.request_config.cookies)
class Conversation(JsonConversation):
"""


@@ -111,23 +111,24 @@ def iter_response(
break
idx += 1
if usage is None:
usage = Usage(completion_tokens=idx, total_tokens=idx)
usage = UsageModel.model_construct(completion_tokens=idx, total_tokens=idx)
else:
usage = UsageModel.model_construct(**usage.get_dict())
finish_reason = "stop" if finish_reason is None else finish_reason
if stream:
chat_completion = ChatCompletionChunk.model_construct(
None, finish_reason, completion_id, int(time.time()),
usage=usage
None, finish_reason, completion_id, int(time.time()), usage=usage
)
else:
if response_format is not None and "type" in response_format:
if response_format["type"] == "json_object":
content = filter_json(content)
chat_completion = ChatCompletion.model_construct(
content, finish_reason, completion_id, int(time.time()),
usage=UsageModel.model_construct(**usage.get_dict()),
content, finish_reason, completion_id, int(time.time()), usage=usage,
**filter_none(tool_calls=[ToolCallModel.model_construct(**tool_call) for tool_call in tool_calls]) if tool_calls is not None else {}
)
if provider is not None:
@@ -211,21 +212,23 @@ async def async_iter_response(
finish_reason = "stop" if finish_reason is None else finish_reason
if usage is None:
usage = Usage(completion_tokens=idx, total_tokens=idx)
usage = UsageModel.model_construct(completion_tokens=idx, total_tokens=idx)
else:
usage = UsageModel.model_construct(**usage.get_dict())
if stream:
chat_completion = ChatCompletionChunk.model_construct(
None, finish_reason, completion_id, int(time.time()),
usage=usage.get_dict()
None, finish_reason, completion_id, int(time.time()), usage=usage
)
else:
if response_format is not None and "type" in response_format:
if response_format["type"] == "json_object":
content = filter_json(content)
chat_completion = ChatCompletion.model_construct(
content, finish_reason, completion_id, int(time.time()),
usage=UsageModel.model_construct(**usage.get_dict()),
**filter_none(tool_calls=[ToolCallModel.model_construct(**tool_call) for tool_call in tool_calls]) if tool_calls is not None else {}
content, finish_reason, completion_id, int(time.time()), usage=usage,
**filter_none(
tool_calls=[ToolCallModel.model_construct(**tool_call) for tool_call in tool_calls]
) if tool_calls is not None else {}
)
if provider is not None:
chat_completion.provider = provider.name
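Both the sync and async paths now build usage the same way: from the chunk count when the provider reported nothing, otherwise from the provider's Usage dict. pydantic's model_construct skips validation, so partially filled counts pass through untouched; a small illustration (the field set is abbreviated, not the project's full UsageModel):

    from pydantic import BaseModel

    class UsageModel(BaseModel):
        prompt_tokens: int = 0
        completion_tokens: int = 0
        total_tokens: int = 0

    # model_construct bypasses validation; omitted fields keep their defaults
    usage = UsageModel.model_construct(completion_tokens=12, total_tokens=12)
    print(usage.prompt_tokens, usage.total_tokens)  # 0 12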


@@ -42,8 +42,8 @@ class UsageModel(BaseModel):
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_tokens_details=TokenDetails.model_construct(**prompt_tokens_details) if prompt_tokens_details else None,
completion_tokens_details=TokenDetails.model_construct(**completion_tokens_details) if completion_tokens_details else None,
prompt_tokens_details=TokenDetails.model_construct(**prompt_tokens_details if prompt_tokens_details else {}),
completion_tokens_details=TokenDetails.model_construct(**completion_tokens_details if completion_tokens_details else {}),
**kwargs
)
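Net effect: both *_tokens_details fields are now always TokenDetails instances, built from an empty dict when the provider omits the details, instead of sometimes being None, so downstream readers no longer need a None check before accessing sub-fields.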


@@ -32,6 +32,7 @@
--scrollbar: var(--colour-3);
--scrollbar-thumb: var(--blur-bg);
--button-hover: var(--colour-5);
--media-select: var(--colour-4);
--top: 50%;
--size: 70vw;
--blur: 35vw; /* Half of 70vw */
@@ -520,7 +521,7 @@ body:not(.white) a:visited{
cursor: pointer;
user-select: none;
color: var(--colour-1);
background: var(--colour-4);
background: var(--media-select);
border: 1px solid var(--colour-1);
transition: all 0.2s ease;
width: auto;
@@ -529,13 +530,14 @@ body:not(.white) a:visited{
}
.media-select label, .media-select button {
padding: 8px 12px;
padding: 20px 12px;
border-radius: var(--border-radius-1);
}
.media-select button.close {
order: 1000;
height: 32px;
padding: 8px 12px;
}
.count_total {
@@ -1357,6 +1359,7 @@ ul {
--scrollbar-thumb: #ccc;
--button-hover: var(--colour-4);
--background: transparent;
--media-select: var(--colour-3);
}
.white .message .assistant .fa-xmark {


@@ -844,7 +844,7 @@ function is_stopped() {
return false;
}
const requestWakeLock = async (onVisibilityChange = false) => {
const requestWakeLock = async () => {
try {
wakeLock = await navigator.wakeLock.request('screen');
}
@@ -890,7 +890,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
<i class="fa-solid fa-xmark"></i>
<i class="fa-regular fa-phone-arrow-down-left"></i>
</div>
<div class="content" id="gpt_${message_id}">
<div class="content">
<div class="provider" data-provider="${provider}"></div>
<div class="content_inner"><span class="cursor"></span></div>
<div class="count"></div>
@@ -908,7 +908,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
controller_storage[message_id] = new AbortController();
let content_el = document.getElementById(`gpt_${message_id}`)
let content_el = message_el.querySelector('.content');
let content_map = content_storage[message_id] = {
container: message_el,
content: content_el,
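Scoping the lookup to message_el matches the removal of the id="gpt_${message_id}" attribute from the template above and, presumably, avoids collisions when the same message id appears in the DOM more than once.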
@@ -1986,7 +1986,7 @@ async function on_api() {
console.log("pressed enter");
prompt_lock = true;
setTimeout(()=>prompt_lock=false, 3000);
await handle_ask();
await handle_ask(!do_enter);
} else {
messageInput.style.height = messageInput.scrollHeight + "px";
}
@@ -2777,7 +2777,9 @@ if (SpeechRecognition) {
buffer = "";
};
recognition.onend = function() {
if (buffer) {
messageInput.value = `${startValue ? startValue + "\n" : ""}${buffer}`;
}
if (microLabel.classList.contains("recognition")) {
recognition.start();
} else {


@@ -361,7 +361,7 @@ class ProviderModelMixin:
model = cls.model_aliases[model]
else:
if model not in cls.get_models(**kwargs) and cls.models:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}. Valid models: {cls.models}")
cls.last_model = model
debug.last_model = model
return model
@@ -373,9 +373,11 @@ class RaiseErrorMixin():
if "error_message" in data:
raise ResponseError(data["error_message"])
elif "error" in data:
if "code" in data["error"]:
if isinstance(data["error"], str):
raise ResponseError(data["error"])
elif "code" in data["error"]:
raise ResponseError("\n".join(
[e for e in [f'Error {data["error"]["code"]}: {data["error"]["message"]}', data["error"].get("failed_generation")] if e is not None]
))
elif "message" in data["error"]:
raise ResponseError(data["error"]["message"])
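The mixin now distinguishes a plain-string error from the dict shapes. Illustrative payloads for each branch, paired with the ResponseError text they should produce (values are made up):

    examples = [
        ({"error_message": "Upstream failure"}, "Upstream failure"),
        ({"error": "Rate limit reached"}, "Rate limit reached"),  # new: plain-string errors
        ({"error": {"code": 429, "message": "Slow down"}}, "Error 429: Slow down"),
        ({"error": {"message": "Invalid request"}}, "Invalid request"),
    ]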


@@ -23,9 +23,19 @@ def is_openai(text: str) -> bool:
async def raise_for_status_async(response: Union[StreamResponse, ClientResponse], message: str = None):
if response.ok:
return
text = await response.text()
text = (await response.text()).strip()
if message is None:
is_html = response.headers.get("content-type", "").startswith("text/html") or text.startswith("<!DOCTYPE")
content_type = response.headers.get("content-type", "")
if content_type.startswith("application/json"):
try:
data = await response.json()
message = data.get("error")
if isinstance(message, dict):
message = data.get("message")
except Exception:
pass
else:
is_html = content_type.startswith("text/html") or text.startswith("<!DOCTYPE")
message = "HTML content" if is_html else text
if message == "HTML content":
if response.status == 520:
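For a JSON body such as {"error": {"message": "Invalid API key"}}, the raised message is now the payload's own "Invalid API key" rather than the serialized JSON text.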

g4f/tools/langchain.py (new file, 41 lines)

@@ -0,0 +1,41 @@
from __future__ import annotations
from typing import Any, Dict
from langchain_community.chat_models import openai
from langchain_community.chat_models.openai import convert_message_to_dict
from pydantic import Field
from g4f.client import AsyncClient, Client
from g4f.client.stubs import ChatCompletionMessage
def new_convert_message_to_dict(message: openai.BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, ChatCompletionMessage):
message_dict = {"role": message.role, "content": message.content}
if message.tool_calls is not None:
message_dict["tool_calls"] = [{
"id": tool_call.id,
"type": tool_call.type,
"function": tool_call.function
} for tool_call in message.tool_calls]
if message_dict["content"] == "":
message_dict["content"] = None
else:
message_dict = convert_message_to_dict(message)
return message_dict
openai.convert_message_to_dict = new_convert_message_to_dict
class ChatAI(openai.ChatOpenAI):
model_name: str = Field(default="gpt-4o", alias="model")
@classmethod
def validate_environment(cls, values: dict) -> dict:
client_params = {
"api_key": values["g4f_api_key"] if "g4f_api_key" in values else None,
"provider": values["provider"] if "provider" in values else None,
}
values["client"] = Client(**client_params).chat.completions
values["async_client"] = AsyncClient(
**client_params
).chat.completions
return values
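A minimal usage sketch for the new integration; the model name and prompt are illustrative, and ChatAI accepts the usual ChatOpenAI keyword arguments plus the optional provider and g4f_api_key handled above:

    from g4f.tools.langchain import ChatAI

    llm = ChatAI(model="gpt-4o-mini")  # optionally: provider="..." or g4f_api_key="..."
    reply = llm.invoke("Say hello in one word.")
    print(reply.content)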


@@ -10,7 +10,7 @@ from typing import Optional, Callable, AsyncIterator
from ..typing import Messages
from ..providers.helper import filter_none
from ..providers.asyncio import to_async_iterator
from ..providers.response import Reasoning, FinishReason
from ..providers.response import Reasoning, FinishReason, Sources
from ..providers.types import ProviderType
from ..cookies import get_cookies_dir
from .web_search import do_search, get_search_message
@@ -208,6 +208,8 @@ def iter_run_tools(
sources = None
yield chunk
continue
elif isinstance(chunk, Sources):
sources = None
if not isinstance(chunk, str):
yield chunk
continue
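Presumably, when a provider yields its own Sources chunk, the locally collected sources are cleared here while the chunk still reaches the non-string yield above, so the same sources are not appended to the response a second time.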