Support continue messages in Airforce

Add auth caching for OpenAI ChatGPT
Some provider improvments
This commit is contained in:
Heiner Lohaus 2025-01-03 20:35:46 +01:00
parent b0bc665621
commit 6e0bc147b5
17 changed files with 290 additions and 347 deletions

View file

@ -7,6 +7,7 @@ from typing import List
from ..typing import AsyncResult, Messages
from ..image import ImageResponse
from ..providers.response import FinishReason, Usage
from ..requests.raise_for_status import raise_for_status
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
@ -232,17 +233,19 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
data = {
"messages": final_messages,
"model": model,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"stream": stream,
}
if max_tokens != 512:
data["max_tokens"] = max_tokens
async with ClientSession(headers=headers) as session:
async with session.post(cls.api_endpoint_completions, json=data, proxy=proxy) as response:
await raise_for_status(response)
if stream:
idx = 0
async for line in response.content:
line = line.decode('utf-8').strip()
if line.startswith('data: '):
@ -255,11 +258,18 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
chunk = cls._filter_response(delta['content'])
if chunk:
yield chunk
idx += 1
except json.JSONDecodeError:
continue
if idx == 512:
yield FinishReason("length")
else:
# Non-streaming response
result = await response.json()
if "usage" in result:
yield Usage(**result["usage"])
if result["usage"]["completion_tokens"] == 512:
yield FinishReason("length")
if 'choices' in result and result['choices']:
message = result['choices'][0].get('message', {})
content = message.get('content', '')
@ -273,7 +283,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
messages: Messages,
prompt: str = None,
proxy: str = None,
max_tokens: int = 4096,
max_tokens: int = 512,
temperature: float = 1,
top_p: float = 1,
stream: bool = True,

View file

@ -90,7 +90,7 @@ class Copilot(AbstractProvider, ProviderModelMixin):
cls._access_token, cls._cookies = asyncio.run(get_access_token_and_cookies(cls.url, proxy))
else:
raise h
yield Parameters(**{"api_key": cls._access_token, "cookies": cls._cookies})
yield Parameters(**{"api_key": cls._access_token, "cookies": cls._cookies if isinstance(cls._cookies, dict) else {c.name: c.value for c in cls._cookies}})
websocket_url = f"{websocket_url}&accessToken={quote(cls._access_token)}"
headers = {"authorization": f"Bearer {cls._access_token}"}
@ -191,6 +191,8 @@ class Copilot(AbstractProvider, ProviderModelMixin):
yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")})
elif msg.get("event") == "done":
break
elif msg.get("event") == "replaceText":
yield msg.get("text")
elif msg.get("event") == "error":
raise RuntimeError(f"Error: {msg}")
elif msg.get("event") not in ["received", "startMessage", "citation", "partCompleted"]:

View file

@ -323,13 +323,13 @@ async def iter_filter_base64(chunks: AsyncIterator[bytes]) -> AsyncIterator[byte
async for chunk in chunks:
if is_started:
if end_with in chunk:
yield chunk.split(end_with, 1, maxsplit=1).pop(0)
yield chunk.split(end_with, maxsplit=1).pop(0)
break
else:
yield chunk
elif search_for in chunk:
is_started = True
yield chunk.split(search_for, 1, maxsplit=1).pop()
yield chunk.split(search_for, maxsplit=1).pop()
else:
raise ValueError(f"Response: {chunk}")

View file

@ -111,7 +111,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
messages = [m for m in messages if m["role"] == "system"] + [messages[-1]]
inputs = get_inputs(messages, model_data, model_type, do_continue)
debug.log(f"New len: {len(inputs)}")
if model_type == "gpt2" and max_new_tokens >= 1024:
if model_type == "gpt2" and max_tokens >= 1024:
params["max_new_tokens"] = 512
payload = {"inputs": inputs, "parameters": params, "stream": stream}

View file

@ -17,14 +17,14 @@ try:
except ImportError:
has_nodriver = False
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..base_provider import AsyncAuthedProvider, ProviderModelMixin
from ...typing import AsyncResult, Messages, Cookies, ImagesType
from ...requests.raise_for_status import raise_for_status
from ...requests import StreamSession
from ...requests import get_nodriver
from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format
from ...errors import MissingAuthError, NoValidHarFileError
from ...providers.response import JsonConversation, FinishReason, SynthesizeData
from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult
from ...providers.response import Sources, TitleGeneration, RequestLogin, Parameters
from ..helper import format_cookies
from ..openai.har_file import get_request_config
@ -85,7 +85,7 @@ UPLOAD_HEADERS = {
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
}
class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
"""A class for creating and managing conversations with OpenAI chat service"""
label = "OpenAI ChatGPT"
@ -104,6 +104,20 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
_cookies: Cookies = None
_expires: int = None
@classmethod
async def on_auth_async(cls, **kwargs) -> AuthResult:
if cls.needs_auth:
async for _ in cls.login():
pass
return AuthResult(
api_key=cls._api_key,
cookies=cls._cookies or RequestConfig.cookies or {},
headers=cls._headers or RequestConfig.headers or cls.get_default_headers(),
expires=cls._expires,
proof_token=RequestConfig.proof_token,
turnstile_token=RequestConfig.turnstile_token
)
@classmethod
def get_models(cls, proxy: str = None, timeout: int = 180) -> List[str]:
if not cls.models:
@ -135,7 +149,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
async def upload_images(
cls,
session: StreamSession,
headers: dict,
auth_result: AuthResult,
images: ImagesType,
) -> ImageRequest:
"""
@ -160,8 +174,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
"use_case": "multimodal"
}
# Post the image data to the service and get the image data
async with session.post(f"{cls.url}/backend-api/files", json=data, headers=headers) as response:
cls._update_request_args(session)
async with session.post(f"{cls.url}/backend-api/files", json=data, headers=auth_result.headers) as response:
cls._update_request_args(auth_result, session)
await raise_for_status(response, "Create file failed")
image_data = {
**data,
@ -189,9 +203,9 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
async with session.post(
f"{cls.url}/backend-api/files/{image_data['file_id']}/uploaded",
json={},
headers=headers
headers=auth_result.headers
) as response:
cls._update_request_args(session)
cls._update_request_args(auth_result, session)
await raise_for_status(response, "Get download url failed")
image_data["download_url"] = (await response.json())["download_url"]
return ImageRequest(image_data)
@ -248,7 +262,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
return messages
@classmethod
async def get_generated_image(cls, session: StreamSession, headers: dict, element: dict, prompt: str = None) -> ImageResponse:
async def get_generated_image(cls, auth_result: AuthResult, session: StreamSession, element: dict, prompt: str = None) -> ImageResponse:
try:
prompt = element["metadata"]["dalle"]["prompt"]
file_id = element["asset_pointer"].split("file-service://", 1)[1]
@ -257,8 +271,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
except Exception as e:
raise RuntimeError(f"No Image: {e.__class__.__name__}: {e}")
try:
async with session.get(f"{cls.url}/backend-api/files/{file_id}/download", headers=headers) as response:
cls._update_request_args(session)
async with session.get(f"{cls.url}/backend-api/files/{file_id}/download", headers=auth_result.headers) as response:
cls._update_request_args(auth_result, session)
await raise_for_status(response)
download_url = (await response.json())["download_url"]
return ImageResponse(download_url, prompt)
@ -266,10 +280,11 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
raise RuntimeError(f"Error in downloading image: {e}")
@classmethod
async def create_async_generator(
async def create_authed(
cls,
model: str,
messages: Messages,
auth_result: AuthResult,
proxy: str = None,
timeout: int = 180,
auto_continue: bool = False,
@ -279,7 +294,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
conversation: Conversation = None,
images: ImagesType = None,
return_conversation: bool = False,
max_retries: int = 3,
max_retries: int = 0,
web_search: bool = False,
**kwargs
) -> AsyncResult:
@ -306,9 +321,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
Raises:
RuntimeError: If an error occurs during processing.
"""
if cls.needs_auth:
async for message in cls.login(proxy, **kwargs):
yield message
async with StreamSession(
proxy=proxy,
impersonate="chrome",
@ -319,15 +331,18 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
if cls._headers is None:
cls._create_request_args(cls._cookies)
async with session.get(cls.url, headers=INIT_HEADERS) as response:
cls._update_request_args(session)
cls._update_request_args(auth_result, session)
await raise_for_status(response)
else:
print(cls._headers)
async with session.get(cls.url, headers=cls._headers) as response:
cls._update_request_args(session)
if cls._headers is None:
cls._create_request_args(auth_result.cookies, auth_result.headers)
if not cls._set_api_key(auth_result.api_key):
raise MissingAuthError("Access token is not valid")
async with session.get(cls.url, headers=auth_result.headers) as response:
cls._update_request_args(auth_result, session)
await raise_for_status(response)
try:
image_requests = await cls.upload_images(session, cls._headers, images) if images else None
image_requests = await cls.upload_images(session, auth_result, images) if images else None
except Exception as e:
debug.log("OpenaiChat: Upload image failed")
debug.log(f"{e.__class__.__name__}: {e}")
@ -345,36 +360,36 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
f"{cls.url}/backend-anon/sentinel/chat-requirements"
if cls._api_key is None else
f"{cls.url}/backend-api/sentinel/chat-requirements",
json={"p": get_requirements_token(RequestConfig.proof_token) if RequestConfig.proof_token else None},
json={"p": None if auth_result.proof_token is None else get_requirements_token(auth_result.proof_token)},
headers=cls._headers
) as response:
if response.status == 401:
cls._headers = cls._api_key = None
else:
cls._update_request_args(session)
cls._update_request_args(auth_result, session)
await raise_for_status(response)
chat_requirements = await response.json()
need_turnstile = chat_requirements.get("turnstile", {}).get("required", False)
need_arkose = chat_requirements.get("arkose", {}).get("required", False)
chat_token = chat_requirements.get("token")
if need_arkose and RequestConfig.arkose_token is None:
await get_request_config(proxy)
cls._create_request_args(RequestConfig.cookies, RequestConfig.headers)
cls._set_api_key(RequestConfig.access_token)
if RequestConfig.arkose_token is None:
raise MissingAuthError("No arkose token found in .har file")
# if need_arkose and RequestConfig.arkose_token is None:
# await get_request_config(proxy)
# cls._create_request_args(auth_result.cookies, auth_result.headers)
# cls._set_api_key(auth_result.access_token)
# if auth_result.arkose_token is None:
# raise MissingAuthError("No arkose token found in .har file")
if "proofofwork" in chat_requirements:
if RequestConfig.proof_token is None:
RequestConfig.proof_token = get_config(cls._headers.get("user-agent"))
if auth_result.proof_token is None:
auth_result.proof_token = get_config(auth_result.headers.get("user-agent"))
proofofwork = generate_proof_token(
**chat_requirements["proofofwork"],
user_agent=cls._headers.get("user-agent"),
proof_token=RequestConfig.proof_token
user_agent=auth_result.headers.get("user-agent"),
proof_token=auth_result.proof_token
)
[debug.log(text) for text in (
f"Arkose: {'False' if not need_arkose else RequestConfig.arkose_token[:12]+'...'}",
#f"Arkose: {'False' if not need_arkose else auth_result.arkose_token[:12]+'...'}",
f"Proofofwork: {'False' if proofofwork is None else proofofwork[:12]+'...'}",
f"AccessToken: {'False' if cls._api_key is None else cls._api_key[:12]+'...'}",
)]
@ -414,12 +429,12 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
"content-type": "application/json",
"openai-sentinel-chat-requirements-token": chat_token,
}
if RequestConfig.arkose_token:
headers["openai-sentinel-arkose-token"] = RequestConfig.arkose_token
#if RequestConfig.arkose_token:
# headers["openai-sentinel-arkose-token"] = RequestConfig.arkose_token
if proofofwork is not None:
headers["openai-sentinel-proof-token"] = proofofwork
if need_turnstile and RequestConfig.turnstile_token is not None:
headers['openai-sentinel-turnstile-token'] = RequestConfig.turnstile_token
if need_turnstile and auth_result.turnstile_token is not None:
headers['openai-sentinel-turnstile-token'] = auth_result.turnstile_token
async with session.post(
f"{cls.url}/backend-anon/conversation"
if cls._api_key is None else
@ -427,7 +442,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
json=data,
headers=headers
) as response:
cls._update_request_args(session)
cls._update_request_args(auth_result, session)
if response.status in (403, 404) and max_retries > 0:
max_retries -= 1
debug.log(f"Retry: Error {response.status}: {await response.text()}")
@ -462,7 +477,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
yield sources
if return_conversation:
yield conversation
if not history_disabled and cls._api_key is not None:
if not history_disabled and auth_result.api_key is not None:
yield SynthesizeData(cls.__name__, {
"conversation_id": conversation.conversation_id,
"message_id": conversation.message_id,
@ -587,7 +602,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
try:
await get_request_config(proxy)
cls._create_request_args(RequestConfig.cookies, RequestConfig.headers)
print(RequestConfig.access_token)
if RequestConfig.access_token is not None or cls.needs_auth:
if not cls._set_api_key(RequestConfig.access_token):
raise NoValidHarFileError(f"Access token is not valid: {RequestConfig.access_token}")
@ -673,9 +687,9 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
cls._update_cookie_header()
@classmethod
def _update_request_args(cls, session: StreamSession):
def _update_request_args(cls, auth_result: AuthResult, session: StreamSession):
for c in session.cookie_jar if hasattr(session, "cookie_jar") else session.cookies.jar:
cls._cookies[getattr(c, "key", getattr(c, "name", ""))] = c.value
auth_result.cookies[getattr(c, "key", getattr(c, "name", ""))] = c.value
cls._update_cookie_header()
@classmethod

View file

@ -11,7 +11,7 @@ from .typing import Messages, CreateResult, AsyncResult, ImageType
from .errors import StreamNotSupportedError
from .cookies import get_cookies, set_cookies
from .providers.types import ProviderType
from .providers.helper import concat_chunks
from .providers.helper import concat_chunks, async_concat_chunks
from .client.service import get_model_and_provider
#Configure "g4f" logger
@ -47,8 +47,7 @@ class ChatCompletion:
if ignore_stream:
kwargs["ignore_stream"] = True
create_method = provider.create_authed if hasattr(provider, "create_authed") else provider.create_completion
result = create_method(model, messages, stream=stream, **kwargs)
result = provider.get_create_function()(model, messages, stream=stream, **kwargs)
return result if stream else concat_chunks(result)
@ -72,11 +71,10 @@ class ChatCompletion:
if ignore_stream:
kwargs["ignore_stream"] = True
if stream:
if hasattr(provider, "create_async_authed_generator"):
return provider.create_async_authed_generator(model, messages, **kwargs)
elif hasattr(provider, "create_async_generator"):
return provider.create_async_generator(model, messages, **kwargs)
raise StreamNotSupportedError(f'{provider.__name__} does not support "stream" argument in "create_async"')
result = provider.get_async_create_function()(model, messages, stream=stream, **kwargs)
return provider.create_async(model, messages, **kwargs)
if not stream:
if hasattr(result, "__aiter__"):
result = async_concat_chunks(result)
return result

View file

@ -15,14 +15,14 @@ from ..providers.types import ProviderType, BaseRetryProvider
from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData, ToolCalls, Usage
from ..errors import NoImageResponseError
from ..providers.retry_provider import IterListProvider
from ..providers.asyncio import to_sync_generator, async_generator_to_list
from ..providers.asyncio import to_sync_generator
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
from ..tools.run_tools import async_iter_run_tools, iter_run_tools
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .image_models import ImageModels
from .types import IterResponse, ImageProvider, Client as BaseClient
from .service import get_model_and_provider, convert_to_provider
from .helper import find_stop, filter_json, filter_none, safe_aclose, to_async_iterator
from .helper import find_stop, filter_json, filter_none, safe_aclose
from .. import debug
ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
@ -236,7 +236,7 @@ class Completions:
kwargs["ignore_stream"] = True
response = iter_run_tools(
provider.create_authed if hasattr(provider, "create_authed") else provider.create_completion,
provider.get_create_function(),
model,
messages,
stream=stream,
@ -248,12 +248,9 @@ class Completions:
),
**kwargs
)
if stream and hasattr(response, '__aiter__'):
# It's an async generator, wrap it into a sync iterator
response = to_sync_generator(response)
elif hasattr(response, '__aiter__'):
# If response is an async generator, collect it into a list
response = asyncio.run(async_generator_to_list(response))
if not hasattr(response, '__iter__'):
response = [response]
response = iter_response(response, stream, response_format, max_tokens, stop)
response = iter_append_model_and_provider(response, model, provider)
if stream:
@ -526,14 +523,8 @@ class AsyncCompletions:
kwargs["images"] = [(image, image_name)]
if ignore_stream:
kwargs["ignore_stream"] = True
if hasattr(provider, "create_async_authed_generator"):
create_handler = provider.create_async_authed_generator
if hasattr(provider, "create_async_generator"):
create_handler = provider.create_async_generator
else:
create_handler = provider.create_completion
response = async_iter_run_tools(
create_handler,
provider.get_async_create_function(),
model,
messages,
stream=stream,
@ -545,8 +536,6 @@ class AsyncCompletions:
),
**kwargs
)
if not hasattr(response, '__aiter__'):
response = to_async_iterator(response)
response = async_iter_response(response, stream, response_format, max_tokens, stop)
response = async_iter_append_model_and_provider(response, model, provider)
return response if stream else anext(response)

View file

@ -63,8 +63,3 @@ async def safe_aclose(generator: AsyncGenerator) -> None:
await generator.aclose()
except Exception as e:
logging.warning(f"Error while closing generator: {e}")
# Helper function to convert a synchronous iterator to an async iterator
async def to_async_iterator(iterator: Iterator) -> AsyncIterator:
for item in iterator:
yield item

View file

@ -175,6 +175,7 @@
}
}
</style>
<script src="https://cdn.jsdelivr.net/npm/markdown-it@13.0.1/dist/markdown-it.min.js"></script>
</head>
<body>
<iframe id="background"></iframe>
@ -206,12 +207,24 @@
<p>Powered by the G4F framework</p>
</div>
<iframe id="stream-widget" class="stream" data-src="/backend-api/v2/create?prompt=Create of overview of the news in plain text&stream=1&web_search=news in " class="" frameborder="0"></iframe>
<iframe id="stream-widget" class="stream" frameborder="0"></iframe>
</div>
<script>
const iframe = document.getElementById('stream-widget');
iframe.src = iframe.dataset.src + navigator.language;
const iframe = document.getElementById('stream-widget');""
let search = (navigator.language == "de" ? "news in deutschland" : navigator.language == "en" ? "world news" : navigator.language);
if (Math.floor(Math.random() * 6) % 2 == 0) {
search = "xtekky/gpt4free releases";
}
const url = "/backend-api/v2/create?prompt=Create of overview of the news in plain text&stream=1&web_search=" + search;
iframe.src = url;
setTimeout(()=>iframe.classList.add('show'), 3000);
iframe.onload = () => {
const iframeDocument = iframe.contentDocument || iframe.contentWindow.document;
const iframeBody = iframeDocument.querySelector("body");
const iframeContent = iframeDocument.querySelector("pre");
const markdown = window.markdownit();
iframeBody.innerHTML = markdown.render(iframeContent.innerHTML);
}
(async () => {
const prompt = `

View file

@ -564,7 +564,7 @@ body:not(.white) a:visited{
height: 20px;
width: 100px;
transition: all 0.1s;
background: var(--colour-5);
background: var(--button-hover);
margin-top: -30px;
z-index: 1005;
padding: 6px;
@ -683,7 +683,7 @@ label[for="camera"] {
#messages form {
position: absolute;
width: 100%;
background: var(--colour-5);
background: var(--button-hover);
z-index: 2000;
}
@ -1354,7 +1354,7 @@ form .field.saved .fa-xmark {
.settings .label, form .label, .settings label, form label {
font-size: 15px;
margin-left: var(--inner-gap);
min-width: 120px;
min-width: 200px;
}
.settings .label, form .label {

View file

@ -2111,8 +2111,10 @@ if (SpeechRecognition) {
microLabel.classList.add("recognition");
startValue = messageInput.value;
lastDebounceTranscript = "";
messageInput.readOnly = true;
};
recognition.onend = function() {
messageInput.readOnly = false;
messageInput.focus();
};
recognition.onresult = function(event) {
@ -2138,7 +2140,7 @@ if (SpeechRecognition) {
}
};
microLabel.addEventListener("click", () => {
microLabel.addEventListener("click", (e) => {
if (microLabel.classList.contains("recognition")) {
recognition.stop();
microLabel.classList.remove("recognition");

View file

@ -2,7 +2,7 @@ from __future__ import annotations
import asyncio
from asyncio import AbstractEventLoop, runners
from typing import Optional, Callable, AsyncGenerator, Generator
from typing import Optional, Callable, AsyncIterator, Iterator
from ..errors import NestAsyncioError
@ -37,10 +37,14 @@ def get_running_loop(check_nested: bool) -> Optional[AbstractEventLoop]:
async def await_callback(callback: Callable):
return await callback()
async def async_generator_to_list(generator: AsyncGenerator) -> list:
async def async_generator_to_list(generator: AsyncIterator) -> list:
return [item async for item in generator]
def to_sync_generator(generator: AsyncGenerator) -> Generator:
def to_sync_generator(generator: AsyncIterator, stream: bool = True) -> Iterator:
if not stream:
yield from asyncio.run(async_generator_to_list(generator))
return
loop = get_running_loop(check_nested=False)
new_loop = False
if loop is None:
@ -63,3 +67,18 @@ def to_sync_generator(generator: AsyncGenerator) -> Generator:
finally:
asyncio.set_event_loop(None)
loop.close()
# Helper function to convert a synchronous iterator to an async iterator
async def to_async_iterator(iterator: Iterator) -> AsyncIterator:
if isinstance(iterator, str):
yield iterator
elif hasattr(iterator, "__await__"):
yield await iterator
elif hasattr(iterator, "__aiter__"):
async for item in iterator:
yield item
elif hasattr(iterator, "__iter__"):
for item in iterator:
yield item
else:
yield iterator

View file

@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor
from abc import abstractmethod
import json
from inspect import signature, Parameter
from typing import Optional, Awaitable, _GenericAlias
from typing import Optional, _GenericAlias
from pathlib import Path
try:
from types import NoneType
@ -16,11 +16,11 @@ except ImportError:
from ..typing import CreateResult, AsyncResult, Messages
from .types import BaseProvider
from .asyncio import get_running_loop, to_sync_generator
from .asyncio import get_running_loop, to_sync_generator, to_async_iterator
from .response import BaseConversation, AuthResult
from .helper import concat_chunks, async_concat_chunks
from ..cookies import get_cookies_dir
from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError
from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError, NoValidHarFileError
from .. import debug
SAFE_PARAMETERS = [
@ -31,7 +31,7 @@ SAFE_PARAMETERS = [
"temperature", "top_k", "top_p",
"frequency_penalty", "presence_penalty",
"max_tokens", "max_new_tokens", "stop",
"api_key", "seed", "width", "height",
"api_key", "api_base", "seed", "width", "height",
"proof_token", "max_retries"
]
@ -63,9 +63,29 @@ PARAMETER_EXAMPLES = {
}
class AbstractProvider(BaseProvider):
"""
Abstract class for providing asynchronous functionality to derived classes.
"""
@classmethod
@abstractmethod
def create_completion(
cls,
model: str,
messages: Messages,
stream: bool,
**kwargs
) -> CreateResult:
"""
Create a completion with the given parameters.
Args:
model (str): The model to use.
messages (Messages): The messages to process.
stream (bool): Whether to use streaming.
**kwargs: Additional keyword arguments.
Returns:
CreateResult: The result of the creation process.
"""
raise NotImplementedError()
@classmethod
async def create_async(
@ -92,16 +112,24 @@ class AbstractProvider(BaseProvider):
Returns:
str: The created result as a string.
"""
loop = loop or asyncio.get_running_loop()
loop = asyncio.get_running_loop() if loop is None else loop
def create_func() -> str:
return concat_chunks(cls.create_completion(model, messages, False, **kwargs))
return concat_chunks(cls.create_completion(model, messages, **kwargs))
return await asyncio.wait_for(
loop.run_in_executor(executor, create_func),
timeout=timeout
)
@classmethod
def get_create_function(cls) -> callable:
return cls.create_completion
@classmethod
def get_async_create_function(cls) -> callable:
return cls.create_async
@classmethod
def get_parameters(cls, as_json: bool = False) -> dict[str, Parameter]:
params = {name: parameter for name, parameter in signature(
@ -149,7 +177,7 @@ class AbstractProvider(BaseProvider):
) for name, param in {
**BASIC_PARAMETERS,
**params,
**{"provider": cls.__name__, "stream": cls.supports_stream, "model": getattr(cls, "default_model", "")},
**{"provider": cls.__name__, "model": getattr(cls, "default_model", ""), "stream": cls.supports_stream},
}.items()}
return params
@ -233,6 +261,14 @@ class AsyncProvider(AbstractProvider):
"""
raise NotImplementedError()
@classmethod
def get_create_function(cls) -> callable:
return cls.create_completion
@classmethod
def get_async_create_function(cls) -> callable:
return cls.create_async
class AsyncGeneratorProvider(AsyncProvider):
"""
Provides asynchronous generator functionality for streaming results.
@ -262,30 +298,10 @@ class AsyncGeneratorProvider(AsyncProvider):
CreateResult: The result of the streaming completion creation.
"""
return to_sync_generator(
cls.create_async_generator(model, messages, stream=stream, **kwargs)
cls.create_async_generator(model, messages, stream=stream, **kwargs),
stream=stream
)
@classmethod
async def create_async(
cls,
model: str,
messages: Messages,
**kwargs
) -> str:
"""
Asynchronously creates a result from a generator.
Args:
cls (type): The class on which this method is called.
model (str): The model to use for creation.
messages (Messages): The messages to process.
**kwargs: Additional keyword arguments.
Returns:
str: The created result as a string.
"""
return await async_concat_chunks(cls.create_async_generator(model, messages, stream=False, **kwargs))
@staticmethod
@abstractmethod
async def create_async_generator(
@ -311,11 +327,13 @@ class AsyncGeneratorProvider(AsyncProvider):
"""
raise NotImplementedError()
create_authed = create_completion
@classmethod
def get_create_function(cls) -> callable:
return cls.create_completion
create_authed_async = create_async
create_async_authed = create_async_generator
@classmethod
def get_async_create_function(cls) -> callable:
return cls.create_async_generator
class ProviderModelMixin:
default_model: str = None
@ -357,96 +375,75 @@ class RaiseErrorMixin():
else:
raise ResponseError(data["error"])
class AuthedMixin():
class AsyncAuthedProvider(AsyncGeneratorProvider):
@classmethod
def on_auth(cls, **kwargs) -> Optional[AuthResult]:
async def on_auth_async(cls, **kwargs) -> AuthResult:
if "api_key" not in kwargs:
raise MissingAuthError(f"API key is required for {cls.__name__}")
return None
return AuthResult()
@classmethod
def create_authed(
def on_auth(cls, **kwargs) -> AuthResult:
return asyncio.run(cls.on_auth_async(**kwargs))
@classmethod
def get_create_function(cls) -> callable:
return cls.create_completion
@classmethod
def get_async_create_function(cls) -> callable:
return cls.create_async_generator
@classmethod
def create_completion(
cls,
model: str,
messages: Messages,
**kwargs
) -> CreateResult:
auth_result = {}
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
if cache_file.exists():
with cache_file.open("r") as f:
auth_result = json.load(f)
return cls.create_completion(model, messages, **kwargs, **auth_result)
auth_result = cls.on_auth(**kwargs)
try:
return cls.create_completion(model, messages, **kwargs)
auth_result = AuthResult()
cache_file = Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
if cache_file.exists():
with cache_file.open("r") as f:
auth_result = AuthResult(**json.load(f))
else:
auth_result = cls.on_auth(**kwargs)
return to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
except (MissingAuthError, NoValidHarFileError):
if cache_file.exists():
cache_file.unlink()
auth_result = cls.on_auth(**kwargs)
return to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
finally:
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
class AsyncAuthedMixin(AuthedMixin):
@classmethod
async def create_async_authed(
async def create_async_generator(
cls,
model: str,
messages: Messages,
**kwargs
) -> str:
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
if cache_file.exists():
auth_result = {}
with cache_file.open("r") as f:
auth_result = json.load(f)
return cls.create_completion(model, messages, **kwargs, **auth_result)
auth_result = cls.on_auth(**kwargs)
) -> AsyncResult:
try:
return await cls.create_async(model, messages, **kwargs)
finally:
if auth_result is not None:
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
class AsyncAuthedGeneratorMixin(AsyncAuthedMixin):
@classmethod
async def create_async_authed(
cls,
model: str,
messages: Messages,
**kwargs
) -> str:
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
if cache_file.exists():
auth_result = {}
with cache_file.open("r") as f:
auth_result = json.load(f)
return cls.create_completion(model, messages, **kwargs, **auth_result)
auth_result = cls.on_auth(**kwargs)
try:
return await async_concat_chunks(cls.create_async_generator(model, messages, stream=False, **kwargs))
finally:
if auth_result is not None:
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
@classmethod
def create_async_authed_generator(
cls,
model: str,
messages: Messages,
stream: bool = True,
**kwargs
) -> Awaitable[AsyncResult]:
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
if cache_file.exists():
auth_result = {}
with cache_file.open("r") as f:
auth_result = json.load(f)
return cls.create_completion(model, messages, **kwargs, **auth_result)
auth_result = cls.on_auth(**kwargs)
try:
return cls.create_async_generator(model, messages, stream=stream, **kwargs)
auth_result = AuthResult()
cache_file = Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
if cache_file.exists():
with cache_file.open("r") as f:
auth_result = AuthResult(**json.load(f))
else:
auth_result = await cls.on_auth_async(**kwargs)
response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
async for chunk in response:
yield chunk
except (MissingAuthError, NoValidHarFileError):
if cache_file.exists():
cache_file.unlink()
auth_result = await cls.on_auth_async(**kwargs)
response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
async for chunk in response:
yield chunk
finally:
if auth_result is not None:
cache_file.parent.mkdir(parents=True, exist_ok=True)

View file

@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
import random
from ..typing import Type, List, CreateResult, Messages, AsyncResult
@ -8,8 +7,6 @@ from .types import BaseProvider, BaseRetryProvider, ProviderType
from .. import debug
from ..errors import RetryProviderError, RetryNoProviderError
DEFAULT_TIMEOUT = 60
class IterListProvider(BaseRetryProvider):
def __init__(
self,
@ -56,10 +53,15 @@ class IterListProvider(BaseRetryProvider):
self.last_provider = provider
debug.log(f"Using {provider.__name__} provider")
try:
for chunk in provider.create_completion(model, messages, stream, **kwargs):
if chunk:
yield chunk
started = True
response = provider.get_create_function()(model, messages, stream=stream, **kwargs)
if hasattr(response, "__iter__"):
for chunk in response:
if chunk:
yield chunk
started = True
elif response:
yield response
started = True
if started:
return
except Exception as e:
@ -70,41 +72,6 @@ class IterListProvider(BaseRetryProvider):
raise_exceptions(exceptions)
async def create_async(
self,
model: str,
messages: Messages,
ignored: list[str] = [],
**kwargs,
) -> str:
"""
Asynchronously create a completion using available providers.
Args:
model (str): The model to be used for completion.
messages (Messages): The messages to be used for generating completion.
Returns:
str: The result of the asynchronous completion.
Raises:
Exception: Any exception encountered during the asynchronous completion process.
"""
exceptions = {}
for provider in self.get_providers(False, ignored):
self.last_provider = provider
debug.log(f"Using {provider.__name__} provider")
try:
chunk = await asyncio.wait_for(
provider.create_async(model, messages, **kwargs),
timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
)
if chunk:
return chunk
except Exception as e:
exceptions[provider.__name__] = e
debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")
raise_exceptions(exceptions)
async def create_async_generator(
self,
model: str,
@ -121,22 +88,16 @@ class IterListProvider(BaseRetryProvider):
self.last_provider = provider
debug.log(f"Using {provider.__name__} provider")
try:
if not stream:
chunk = await asyncio.wait_for(
provider.create_async(model, messages, **kwargs),
timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
)
if chunk:
yield chunk
started = True
elif hasattr(provider, "create_async_generator"):
async for chunk in provider.create_async_generator(model, messages, stream=stream, **kwargs):
response = provider.get_async_create_function()(model, messages, stream=stream, **kwargs)
if hasattr(response, "__aiter__"):
async for chunk in response:
if chunk:
yield chunk
started = True
else:
for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
elif response:
response = await response
if response:
yield response
started = True
if started:
return
@ -148,6 +109,12 @@ class IterListProvider(BaseRetryProvider):
raise_exceptions(exceptions)
def get_create_function(self) -> callable:
return self.create_completion
def get_async_create_function(self) -> callable:
return self.create_async_generator
def get_providers(self, stream: bool, ignored: list[str]) -> list[ProviderType]:
providers = [p for p in self.providers if (p.supports_stream or not stream) and p.__name__ not in ignored]
if self.shuffle:
@ -201,8 +168,14 @@ class RetryProvider(IterListProvider):
try:
if debug.logging:
print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
response = provider.get_create_function()(model, messages, stream=stream, **kwargs)
if hasattr(response, "__iter__"):
for chunk in response:
if chunk:
yield chunk
started = True
elif response:
yield response
started = True
if started:
return
@ -216,43 +189,6 @@ class RetryProvider(IterListProvider):
else:
yield from super().create_completion(model, messages, stream, **kwargs)
async def create_async(
self,
model: str,
messages: Messages,
**kwargs,
) -> str:
"""
Asynchronously create a completion using available providers.
Args:
model (str): The model to be used for completion.
messages (Messages): The messages to be used for generating completion.
Returns:
str: The result of the asynchronous completion.
Raises:
Exception: Any exception encountered during the asynchronous completion process.
"""
exceptions = {}
if self.single_provider_retry:
provider = self.providers[0]
self.last_provider = provider
for attempt in range(self.max_retries):
try:
if debug.logging:
print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
return await asyncio.wait_for(
provider.create_async(model, messages, **kwargs),
timeout=kwargs.get("timeout", 60),
)
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
raise_exceptions(exceptions)
else:
return await super().create_async(model, messages, **kwargs)
async def create_async_generator(
self,
model: str,
@ -269,22 +205,16 @@ class RetryProvider(IterListProvider):
for attempt in range(self.max_retries):
try:
debug.log(f"Using {provider.__name__} provider (attempt {attempt + 1})")
if not stream:
chunk = await asyncio.wait_for(
provider.create_async(model, messages, **kwargs),
timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
)
if chunk:
yield chunk
started = True
elif hasattr(provider, "create_async_generator"):
async for chunk in provider.create_async_generator(model, messages, stream=stream, **kwargs):
response = provider.get_async_create_function()(model, messages, stream=stream, **kwargs)
if hasattr(response, "__aiter__"):
async for chunk in response:
if chunk:
yield chunk
started = True
else:
for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
response = await response
if response:
yield response
started = True
if started:
return

View file

@ -26,47 +26,23 @@ class BaseProvider(ABC):
supports_system_message: bool = False
params: str
@classmethod
@abstractmethod
def create_completion(
cls,
model: str,
messages: Messages,
stream: bool,
**kwargs
) -> CreateResult:
def get_create_function() -> callable:
"""
Create a completion with the given parameters.
Args:
model (str): The model to use.
messages (Messages): The messages to process.
stream (bool): Whether to use streaming.
**kwargs: Additional keyword arguments.
Get the create function for the provider.
Returns:
CreateResult: The result of the creation process.
callable: The create function.
"""
raise NotImplementedError()
@classmethod
@abstractmethod
async def create_async(
cls,
model: str,
messages: Messages,
**kwargs
) -> str:
def get_async_create_function() -> callable:
"""
Asynchronously create a completion with the given parameters.
Args:
model (str): The model to use.
messages (Messages): The messages to process.
**kwargs: Additional keyword arguments.
Get the async create function for the provider.
Returns:
str: The result of the creation process.
callable: The create function.
"""
raise NotImplementedError()

View file

@ -7,7 +7,7 @@ from typing import Optional, Callable, AsyncIterator
from ..typing import Messages
from ..providers.helper import filter_none
from ..client.helper import to_async_iterator
from ..providers.asyncio import to_async_iterator
from .web_search import do_search, get_search_message
from .files import read_bucket, get_bucket_dir
from .. import debug
@ -55,9 +55,7 @@ async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls:
if has_bucket and isinstance(messages[-1]["content"], str):
messages[-1]["content"] += BUCKET_INSTRUCTIONS
response = async_iter_callback(model=model, messages=messages, **kwargs)
if not hasattr(response, "__aiter__"):
response = to_async_iterator(response)
response = to_async_iterator(async_iter_callback(model=model, messages=messages, **kwargs))
async for chunk in response:
yield chunk

View file

@ -122,7 +122,7 @@ async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = No
with open(cache_file, "w") as f:
f.write(text)
return text
except ClientError:
except (ClientError, asyncio.TimeoutError):
return
async def search(query: str, max_results: int = 5, max_words: int = 2500, backend: str = "auto", add_text: bool = True, timeout: int = 5, region: str = "wt-wt") -> SearchResults:
@ -138,7 +138,7 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backen
max_results=max_results,
backend=backend,
):
if ".google.com" in result["href"]:
if ".google." in result["href"]:
continue
results.append(SearchResultEntry(
result["title"],