mirror of https://github.com/xtekky/gpt4free.git
Support continue messages in Airforce
Add auth caching for OpenAI ChatGPT
Some provider improvements
This commit is contained in:
parent b0bc665621
commit 6e0bc147b5

17 changed files with 290 additions and 347 deletions
@@ -7,6 +7,7 @@ from typing import List

 from ..typing import AsyncResult, Messages
 from ..image import ImageResponse
+from ..providers.response import FinishReason, Usage
 from ..requests.raise_for_status import raise_for_status
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
@@ -232,17 +233,19 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
         data = {
             "messages": final_messages,
             "model": model,
-            "max_tokens": max_tokens,
             "temperature": temperature,
             "top_p": top_p,
             "stream": stream,
         }
+        if max_tokens != 512:
+            data["max_tokens"] = max_tokens

         async with ClientSession(headers=headers) as session:
             async with session.post(cls.api_endpoint_completions, json=data, proxy=proxy) as response:
                 await raise_for_status(response)

                 if stream:
+                    idx = 0
                     async for line in response.content:
                         line = line.decode('utf-8').strip()
                         if line.startswith('data: '):
@@ -255,11 +258,18 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
                                 chunk = cls._filter_response(delta['content'])
                                 if chunk:
                                     yield chunk
+                                    idx += 1
                         except json.JSONDecodeError:
                             continue
+                    if idx == 512:
+                        yield FinishReason("length")
                 else:
                     # Non-streaming response
                     result = await response.json()
+                    if "usage" in result:
+                        yield Usage(**result["usage"])
+                        if result["usage"]["completion_tokens"] == 512:
+                            yield FinishReason("length")
                     if 'choices' in result and result['choices']:
                         message = result['choices'][0].get('message', {})
                         content = message.get('content', '')
@@ -273,7 +283,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
         messages: Messages,
         prompt: str = None,
         proxy: str = None,
-        max_tokens: int = 4096,
+        max_tokens: int = 512,
         temperature: float = 1,
         top_p: float = 1,
         stream: bool = True,
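Note: the Airforce changes above cap max_tokens at 512 and emit FinishReason("length") when the cap is hit, so callers can detect truncation and send a follow-up "continue" message. A minimal consumer sketch, assuming a provider entry point that yields str chunks and response objects; the driver loop itself is illustrative, not part of this commit:

from g4f.providers.response import FinishReason

def complete_with_continue(create, model, messages, max_rounds=3):
    # `create` stands in for any provider entry point (hypothetical helper).
    parts = []
    for _ in range(max_rounds):
        truncated = False
        for chunk in create(model=model, messages=messages, stream=True):
            if isinstance(chunk, FinishReason):
                truncated = chunk.reason == "length"
            elif isinstance(chunk, str):
                parts.append(chunk)
        if not truncated:
            break
        # Feed the partial answer back and ask the model to continue.
        messages = messages + [
            {"role": "assistant", "content": "".join(parts)},
            {"role": "user", "content": "continue"},
        ]
    return "".join(parts)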
@@ -90,7 +90,7 @@ class Copilot(AbstractProvider, ProviderModelMixin):
                 cls._access_token, cls._cookies = asyncio.run(get_access_token_and_cookies(cls.url, proxy))
             else:
                 raise h
-        yield Parameters(**{"api_key": cls._access_token, "cookies": cls._cookies})
+        yield Parameters(**{"api_key": cls._access_token, "cookies": cls._cookies if isinstance(cls._cookies, dict) else {c.name: c.value for c in cls._cookies}})
         websocket_url = f"{websocket_url}&accessToken={quote(cls._access_token)}"
         headers = {"authorization": f"Bearer {cls._access_token}"}
@@ -191,6 +191,8 @@ class Copilot(AbstractProvider, ProviderModelMixin):
                 yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")})
             elif msg.get("event") == "done":
                 break
+            elif msg.get("event") == "replaceText":
+                yield msg.get("text")
             elif msg.get("event") == "error":
                 raise RuntimeError(f"Error: {msg}")
             elif msg.get("event") not in ["received", "startMessage", "citation", "partCompleted"]:
@@ -323,13 +323,13 @@ async def iter_filter_base64(chunks: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
     async for chunk in chunks:
         if is_started:
             if end_with in chunk:
-                yield chunk.split(end_with, 1, maxsplit=1).pop(0)
+                yield chunk.split(end_with, maxsplit=1).pop(0)
                 break
             else:
                 yield chunk
         elif search_for in chunk:
             is_started = True
-            yield chunk.split(search_for, 1, maxsplit=1).pop()
+            yield chunk.split(search_for, maxsplit=1).pop()
         else:
             raise ValueError(f"Response: {chunk}")
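Note: the two split() calls above were genuine bugs, not style fixes: bytes.split() already takes maxsplit as its second positional argument, so passing it both positionally and by keyword raises a TypeError at runtime. A quick illustration:

chunk = b"header;base64,QUJD;end"
chunk.split(b";base64,", maxsplit=1).pop()   # b'QUJD;end' - fixed form works
# chunk.split(b";base64,", 1, maxsplit=1)    # TypeError: maxsplit given
#                                            # by position and by name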
@@ -111,7 +111,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
             messages = [m for m in messages if m["role"] == "system"] + [messages[-1]]
             inputs = get_inputs(messages, model_data, model_type, do_continue)
             debug.log(f"New len: {len(inputs)}")
-        if model_type == "gpt2" and max_new_tokens >= 1024:
+        if model_type == "gpt2" and max_tokens >= 1024:
             params["max_new_tokens"] = 512
         payload = {"inputs": inputs, "parameters": params, "stream": stream}
@@ -17,14 +17,14 @@ try:
 except ImportError:
     has_nodriver = False

-from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from ..base_provider import AsyncAuthedProvider, ProviderModelMixin
 from ...typing import AsyncResult, Messages, Cookies, ImagesType
 from ...requests.raise_for_status import raise_for_status
 from ...requests import StreamSession
 from ...requests import get_nodriver
 from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format
 from ...errors import MissingAuthError, NoValidHarFileError
-from ...providers.response import JsonConversation, FinishReason, SynthesizeData
+from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult
 from ...providers.response import Sources, TitleGeneration, RequestLogin, Parameters
 from ..helper import format_cookies
 from ..openai.har_file import get_request_config
@@ -85,7 +85,7 @@ UPLOAD_HEADERS = {
     "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
 }

-class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
+class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
     """A class for creating and managing conversations with OpenAI chat service"""

     label = "OpenAI ChatGPT"
@@ -104,6 +104,20 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
     _cookies: Cookies = None
     _expires: int = None

+    @classmethod
+    async def on_auth_async(cls, **kwargs) -> AuthResult:
+        if cls.needs_auth:
+            async for _ in cls.login():
+                pass
+        return AuthResult(
+            api_key=cls._api_key,
+            cookies=cls._cookies or RequestConfig.cookies or {},
+            headers=cls._headers or RequestConfig.headers or cls.get_default_headers(),
+            expires=cls._expires,
+            proof_token=RequestConfig.proof_token,
+            turnstile_token=RequestConfig.turnstile_token
+        )
+
     @classmethod
     def get_models(cls, proxy: str = None, timeout: int = 180) -> List[str]:
         if not cls.models:
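Note: on_auth_async() is the new single entry point for the ChatGPT login flow; it runs the login once and packs everything create_authed() needs into an AuthResult. A sketch of calling it directly; the driver below is illustrative, and the model name and message are placeholders:

import asyncio
from g4f.Provider.needs_auth import OpenaiChat

async def main():
    auth = await OpenaiChat.on_auth_async()   # login once, capture tokens and cookies
    async for chunk in OpenaiChat.create_authed(
        "gpt-4o",
        [{"role": "user", "content": "Hello"}],
        auth_result=auth,
    ):
        print(chunk, end="")

asyncio.run(main())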
@@ -135,7 +149,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
     async def upload_images(
         cls,
         session: StreamSession,
-        headers: dict,
+        auth_result: AuthResult,
         images: ImagesType,
     ) -> ImageRequest:
         """
@@ -160,8 +174,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
             "use_case": "multimodal"
         }
         # Post the image data to the service and get the image data
-        async with session.post(f"{cls.url}/backend-api/files", json=data, headers=headers) as response:
-            cls._update_request_args(session)
+        async with session.post(f"{cls.url}/backend-api/files", json=data, headers=auth_result.headers) as response:
+            cls._update_request_args(auth_result, session)
             await raise_for_status(response, "Create file failed")
             image_data = {
                 **data,
@@ -189,9 +203,9 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
         async with session.post(
             f"{cls.url}/backend-api/files/{image_data['file_id']}/uploaded",
             json={},
-            headers=headers
+            headers=auth_result.headers
         ) as response:
-            cls._update_request_args(session)
+            cls._update_request_args(auth_result, session)
             await raise_for_status(response, "Get download url failed")
             image_data["download_url"] = (await response.json())["download_url"]
         return ImageRequest(image_data)
@@ -248,7 +262,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
         return messages

     @classmethod
-    async def get_generated_image(cls, session: StreamSession, headers: dict, element: dict, prompt: str = None) -> ImageResponse:
+    async def get_generated_image(cls, auth_result: AuthResult, session: StreamSession, element: dict, prompt: str = None) -> ImageResponse:
         try:
             prompt = element["metadata"]["dalle"]["prompt"]
             file_id = element["asset_pointer"].split("file-service://", 1)[1]
@@ -257,8 +271,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
         except Exception as e:
             raise RuntimeError(f"No Image: {e.__class__.__name__}: {e}")
         try:
-            async with session.get(f"{cls.url}/backend-api/files/{file_id}/download", headers=headers) as response:
-                cls._update_request_args(session)
+            async with session.get(f"{cls.url}/backend-api/files/{file_id}/download", headers=auth_result.headers) as response:
+                cls._update_request_args(auth_result, session)
                 await raise_for_status(response)
                 download_url = (await response.json())["download_url"]
                 return ImageResponse(download_url, prompt)
@@ -266,10 +280,11 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
             raise RuntimeError(f"Error in downloading image: {e}")

     @classmethod
-    async def create_async_generator(
+    async def create_authed(
         cls,
         model: str,
         messages: Messages,
+        auth_result: AuthResult,
         proxy: str = None,
         timeout: int = 180,
         auto_continue: bool = False,
@@ -279,7 +294,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
         conversation: Conversation = None,
         images: ImagesType = None,
         return_conversation: bool = False,
-        max_retries: int = 3,
+        max_retries: int = 0,
         web_search: bool = False,
         **kwargs
     ) -> AsyncResult:
@@ -306,9 +321,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
         Raises:
             RuntimeError: If an error occurs during processing.
         """
-        if cls.needs_auth:
-            async for message in cls.login(proxy, **kwargs):
-                yield message
         async with StreamSession(
             proxy=proxy,
             impersonate="chrome",
@@ -319,15 +331,18 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
-            if cls._headers is None:
-                cls._create_request_args(cls._cookies)
-                async with session.get(cls.url, headers=INIT_HEADERS) as response:
-                    cls._update_request_args(session)
-                    await raise_for_status(response)
-            else:
-                print(cls._headers)
-                async with session.get(cls.url, headers=cls._headers) as response:
-                    cls._update_request_args(session)
+            if cls._headers is None:
+                cls._create_request_args(auth_result.cookies, auth_result.headers)
+            if not cls._set_api_key(auth_result.api_key):
+                raise MissingAuthError("Access token is not valid")
+            async with session.get(cls.url, headers=auth_result.headers) as response:
+                cls._update_request_args(auth_result, session)
                 await raise_for_status(response)
             try:
-                image_requests = await cls.upload_images(session, cls._headers, images) if images else None
+                image_requests = await cls.upload_images(session, auth_result, images) if images else None
             except Exception as e:
                 debug.log("OpenaiChat: Upload image failed")
                 debug.log(f"{e.__class__.__name__}: {e}")
@@ -345,36 +360,36 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                 f"{cls.url}/backend-anon/sentinel/chat-requirements"
                 if cls._api_key is None else
                 f"{cls.url}/backend-api/sentinel/chat-requirements",
-                json={"p": get_requirements_token(RequestConfig.proof_token) if RequestConfig.proof_token else None},
+                json={"p": None if auth_result.proof_token is None else get_requirements_token(auth_result.proof_token)},
                 headers=cls._headers
             ) as response:
                 if response.status == 401:
                     cls._headers = cls._api_key = None
                 else:
-                    cls._update_request_args(session)
+                    cls._update_request_args(auth_result, session)
                 await raise_for_status(response)
                 chat_requirements = await response.json()
                 need_turnstile = chat_requirements.get("turnstile", {}).get("required", False)
                 need_arkose = chat_requirements.get("arkose", {}).get("required", False)
                 chat_token = chat_requirements.get("token")

-            if need_arkose and RequestConfig.arkose_token is None:
-                await get_request_config(proxy)
-                cls._create_request_args(RequestConfig.cookies, RequestConfig.headers)
-                cls._set_api_key(RequestConfig.access_token)
-                if RequestConfig.arkose_token is None:
-                    raise MissingAuthError("No arkose token found in .har file")
+            # if need_arkose and RequestConfig.arkose_token is None:
+            #     await get_request_config(proxy)
+            #     cls._create_request_args(auth_result.cookies, auth_result.headers)
+            #     cls._set_api_key(auth_result.access_token)
+            #     if auth_result.arkose_token is None:
+            #         raise MissingAuthError("No arkose token found in .har file")

             if "proofofwork" in chat_requirements:
-                if RequestConfig.proof_token is None:
-                    RequestConfig.proof_token = get_config(cls._headers.get("user-agent"))
+                if auth_result.proof_token is None:
+                    auth_result.proof_token = get_config(auth_result.headers.get("user-agent"))
                 proofofwork = generate_proof_token(
                     **chat_requirements["proofofwork"],
-                    user_agent=cls._headers.get("user-agent"),
-                    proof_token=RequestConfig.proof_token
+                    user_agent=auth_result.headers.get("user-agent"),
+                    proof_token=auth_result.proof_token
                 )
             [debug.log(text) for text in (
-                f"Arkose: {'False' if not need_arkose else RequestConfig.arkose_token[:12]+'...'}",
+                #f"Arkose: {'False' if not need_arkose else auth_result.arkose_token[:12]+'...'}",
                 f"Proofofwork: {'False' if proofofwork is None else proofofwork[:12]+'...'}",
                 f"AccessToken: {'False' if cls._api_key is None else cls._api_key[:12]+'...'}",
             )]
@@ -414,12 +429,12 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                 "content-type": "application/json",
                 "openai-sentinel-chat-requirements-token": chat_token,
             }
-            if RequestConfig.arkose_token:
-                headers["openai-sentinel-arkose-token"] = RequestConfig.arkose_token
+            #if RequestConfig.arkose_token:
+            #    headers["openai-sentinel-arkose-token"] = RequestConfig.arkose_token
             if proofofwork is not None:
                 headers["openai-sentinel-proof-token"] = proofofwork
-            if need_turnstile and RequestConfig.turnstile_token is not None:
-                headers['openai-sentinel-turnstile-token'] = RequestConfig.turnstile_token
+            if need_turnstile and auth_result.turnstile_token is not None:
+                headers['openai-sentinel-turnstile-token'] = auth_result.turnstile_token
             async with session.post(
                 f"{cls.url}/backend-anon/conversation"
                 if cls._api_key is None else
@@ -427,7 +442,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                 json=data,
                 headers=headers
             ) as response:
-                cls._update_request_args(session)
+                cls._update_request_args(auth_result, session)
                 if response.status in (403, 404) and max_retries > 0:
                     max_retries -= 1
                     debug.log(f"Retry: Error {response.status}: {await response.text()}")
@@ -462,7 +477,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                     yield sources
                 if return_conversation:
                     yield conversation
-                if not history_disabled and cls._api_key is not None:
+                if not history_disabled and auth_result.api_key is not None:
                     yield SynthesizeData(cls.__name__, {
                         "conversation_id": conversation.conversation_id,
                         "message_id": conversation.message_id,
@@ -587,7 +602,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
         try:
             await get_request_config(proxy)
             cls._create_request_args(RequestConfig.cookies, RequestConfig.headers)
-            print(RequestConfig.access_token)
             if RequestConfig.access_token is not None or cls.needs_auth:
                 if not cls._set_api_key(RequestConfig.access_token):
                     raise NoValidHarFileError(f"Access token is not valid: {RequestConfig.access_token}")
@@ -673,9 +687,9 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
         cls._update_cookie_header()

     @classmethod
-    def _update_request_args(cls, session: StreamSession):
+    def _update_request_args(cls, auth_result: AuthResult, session: StreamSession):
         for c in session.cookie_jar if hasattr(session, "cookie_jar") else session.cookies.jar:
-            cls._cookies[getattr(c, "key", getattr(c, "name", ""))] = c.value
+            auth_result.cookies[getattr(c, "key", getattr(c, "name", ""))] = c.value
         cls._update_cookie_header()

     @classmethod
@@ -11,7 +11,7 @@ from .typing import Messages, CreateResult, AsyncResult, ImageType
 from .errors import StreamNotSupportedError
 from .cookies import get_cookies, set_cookies
 from .providers.types import ProviderType
-from .providers.helper import concat_chunks
+from .providers.helper import concat_chunks, async_concat_chunks
 from .client.service import get_model_and_provider

 #Configure "g4f" logger
@@ -47,8 +47,7 @@ class ChatCompletion:
         if ignore_stream:
             kwargs["ignore_stream"] = True

-        create_method = provider.create_authed if hasattr(provider, "create_authed") else provider.create_completion
-        result = create_method(model, messages, stream=stream, **kwargs)
+        result = provider.get_create_function()(model, messages, stream=stream, **kwargs)

         return result if stream else concat_chunks(result)
@@ -72,11 +71,10 @@ class ChatCompletion:
         if ignore_stream:
             kwargs["ignore_stream"] = True

-        if stream:
-            if hasattr(provider, "create_async_authed_generator"):
-                return provider.create_async_authed_generator(model, messages, **kwargs)
-            elif hasattr(provider, "create_async_generator"):
-                return provider.create_async_generator(model, messages, **kwargs)
-            raise StreamNotSupportedError(f'{provider.__name__} does not support "stream" argument in "create_async"')
+        result = provider.get_async_create_function()(model, messages, stream=stream, **kwargs)

-        return provider.create_async(model, messages, **kwargs)
+        if not stream:
+            if hasattr(result, "__aiter__"):
+                result = async_concat_chunks(result)

+        return result
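Note: create() and create_async() now share one dispatch shape: ask the provider for its entry point via get_create_function()/get_async_create_function() and post-process the result, instead of probing for create_authed or create_async_generator with hasattr. A rough usage sketch of the assumed call path, not part of the diff:

import g4f

async def ask(prompt: str) -> str:
    # stream=False: the async generator returned by the provider is
    # concatenated into one string via async_concat_chunks().
    return await g4f.ChatCompletion.create_async(
        model=g4f.models.default,
        messages=[{"role": "user", "content": prompt}],
        stream=False,
    )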
@@ -15,14 +15,14 @@ from ..providers.types import ProviderType, BaseRetryProvider
 from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData, ToolCalls, Usage
 from ..errors import NoImageResponseError
 from ..providers.retry_provider import IterListProvider
-from ..providers.asyncio import to_sync_generator, async_generator_to_list
+from ..providers.asyncio import to_sync_generator
 from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
 from ..tools.run_tools import async_iter_run_tools, iter_run_tools
 from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
 from .image_models import ImageModels
 from .types import IterResponse, ImageProvider, Client as BaseClient
 from .service import get_model_and_provider, convert_to_provider
-from .helper import find_stop, filter_json, filter_none, safe_aclose, to_async_iterator
+from .helper import find_stop, filter_json, filter_none, safe_aclose
 from .. import debug

 ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
@@ -236,7 +236,7 @@ class Completions:
             kwargs["ignore_stream"] = True

         response = iter_run_tools(
-            provider.create_authed if hasattr(provider, "create_authed") else provider.create_completion,
+            provider.get_create_function(),
             model,
             messages,
             stream=stream,
@@ -248,12 +248,9 @@ class Completions:
             ),
             **kwargs
         )
-        if stream and hasattr(response, '__aiter__'):
-            # It's an async generator, wrap it into a sync iterator
-            response = to_sync_generator(response)
-        elif hasattr(response, '__aiter__'):
-            # If response is an async generator, collect it into a list
-            response = asyncio.run(async_generator_to_list(response))
+        if not hasattr(response, '__iter__'):
+            response = [response]

         response = iter_response(response, stream, response_format, max_tokens, stop)
         response = iter_append_model_and_provider(response, model, provider)
         if stream:
@@ -526,14 +523,8 @@ class AsyncCompletions:
                 kwargs["images"] = [(image, image_name)]
             if ignore_stream:
                 kwargs["ignore_stream"] = True
-            if hasattr(provider, "create_async_authed_generator"):
-                create_handler = provider.create_async_authed_generator
-            if hasattr(provider, "create_async_generator"):
-                create_handler = provider.create_async_generator
-            else:
-                create_handler = provider.create_completion
             response = async_iter_run_tools(
-                create_handler,
+                provider.get_async_create_function(),
                 model,
                 messages,
                 stream=stream,
@@ -545,8 +536,6 @@ class AsyncCompletions:
             ),
             **kwargs
         )
-        if not hasattr(response, '__aiter__'):
-            response = to_async_iterator(response)
         response = async_iter_response(response, stream, response_format, max_tokens, stop)
         response = async_iter_append_model_and_provider(response, model, provider)
         return response if stream else anext(response)
@@ -63,8 +63,3 @@ async def safe_aclose(generator: AsyncGenerator) -> None:
         await generator.aclose()
     except Exception as e:
         logging.warning(f"Error while closing generator: {e}")
-
-# Helper function to convert a synchronous iterator to an async iterator
-async def to_async_iterator(iterator: Iterator) -> AsyncIterator:
-    for item in iterator:
-        yield item
@@ -175,6 +175,7 @@
             }
         }
     </style>
+    <script src="https://cdn.jsdelivr.net/npm/markdown-it@13.0.1/dist/markdown-it.min.js"></script>
 </head>
 <body>
     <iframe id="background"></iframe>
@@ -206,12 +207,24 @@
             <p>Powered by the G4F framework</p>
         </div>

-        <iframe id="stream-widget" class="stream" data-src="/backend-api/v2/create?prompt=Create of overview of the news in plain text&stream=1&web_search=news in " class="" frameborder="0"></iframe>
+        <iframe id="stream-widget" class="stream" frameborder="0"></iframe>
     </div>
     <script>
-        const iframe = document.getElementById('stream-widget');
-        iframe.src = iframe.dataset.src + navigator.language;
+        const iframe = document.getElementById('stream-widget');
+        let search = (navigator.language == "de" ? "news in deutschland" : navigator.language == "en" ? "world news" : navigator.language);
+        if (Math.floor(Math.random() * 6) % 2 == 0) {
+            search = "xtekky/gpt4free releases";
+        }
+        const url = "/backend-api/v2/create?prompt=Create of overview of the news in plain text&stream=1&web_search=" + search;
+        iframe.src = url;
         setTimeout(()=>iframe.classList.add('show'), 3000);
+        iframe.onload = () => {
+            const iframeDocument = iframe.contentDocument || iframe.contentWindow.document;
+            const iframeBody = iframeDocument.querySelector("body");
+            const iframeContent = iframeDocument.querySelector("pre");
+            const markdown = window.markdownit();
+            iframeBody.innerHTML = markdown.render(iframeContent.innerHTML);
+        }

         (async () => {
             const prompt = `
@@ -564,7 +564,7 @@ body:not(.white) a:visited{
     height: 20px;
     width: 100px;
     transition: all 0.1s;
-    background: var(--colour-5);
+    background: var(--button-hover);
     margin-top: -30px;
     z-index: 1005;
     padding: 6px;
@@ -683,7 +683,7 @@ label[for="camera"] {
 #messages form {
     position: absolute;
     width: 100%;
-    background: var(--colour-5);
+    background: var(--button-hover);
     z-index: 2000;
 }
@@ -1354,7 +1354,7 @@ form .field.saved .fa-xmark {
 .settings .label, form .label, .settings label, form label {
     font-size: 15px;
     margin-left: var(--inner-gap);
-    min-width: 120px;
+    min-width: 200px;
 }

 .settings .label, form .label {
@@ -2111,8 +2111,10 @@ if (SpeechRecognition) {
         microLabel.classList.add("recognition");
         startValue = messageInput.value;
         lastDebounceTranscript = "";
+        messageInput.readOnly = true;
     };
     recognition.onend = function() {
+        messageInput.readOnly = false;
         messageInput.focus();
     };
     recognition.onresult = function(event) {
@@ -2138,7 +2140,7 @@ if (SpeechRecognition) {
         }
     };

-    microLabel.addEventListener("click", () => {
+    microLabel.addEventListener("click", (e) => {
         if (microLabel.classList.contains("recognition")) {
             recognition.stop();
             microLabel.classList.remove("recognition");
@@ -2,7 +2,7 @@ from __future__ import annotations

 import asyncio
 from asyncio import AbstractEventLoop, runners
-from typing import Optional, Callable, AsyncGenerator, Generator
+from typing import Optional, Callable, AsyncIterator, Iterator

 from ..errors import NestAsyncioError
@@ -37,10 +37,14 @@ def get_running_loop(check_nested: bool) -> Optional[AbstractEventLoop]:
 async def await_callback(callback: Callable):
     return await callback()

-async def async_generator_to_list(generator: AsyncGenerator) -> list:
+async def async_generator_to_list(generator: AsyncIterator) -> list:
     return [item async for item in generator]

-def to_sync_generator(generator: AsyncGenerator) -> Generator:
+def to_sync_generator(generator: AsyncIterator, stream: bool = True) -> Iterator:
+    if not stream:
+        yield from asyncio.run(async_generator_to_list(generator))
+        return
+
     loop = get_running_loop(check_nested=False)
     new_loop = False
     if loop is None:
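Note: to_sync_generator() grew a stream flag: with stream=False the async generator is drained in a single asyncio.run() call and replayed, instead of keeping an event loop alive while the caller iterates. A minimal sketch of the assumed behavior:

from g4f.providers.asyncio import to_sync_generator

async def numbers():
    for i in range(3):
        yield i

# stream=True (default): chunks are bridged one by one as they are produced.
for n in to_sync_generator(numbers(), stream=True):
    print(n)

# stream=False: the whole generator is collected eagerly, then replayed.
for n in to_sync_generator(numbers(), stream=False):
    print(n)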
@@ -63,3 +67,18 @@ def to_sync_generator(generator: AsyncGenerator) -> Generator:
     finally:
         asyncio.set_event_loop(None)
         loop.close()
+
+# Helper function to convert a synchronous iterator to an async iterator
+async def to_async_iterator(iterator: Iterator) -> AsyncIterator:
+    if isinstance(iterator, str):
+        yield iterator
+    elif hasattr(iterator, "__await__"):
+        yield await iterator
+    elif hasattr(iterator, "__aiter__"):
+        async for item in iterator:
+            yield item
+    elif hasattr(iterator, "__iter__"):
+        for item in iterator:
+            yield item
+    else:
+        yield iterator
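Note: the relocated to_async_iterator() is no longer a plain sync-to-async bridge; it normalizes strings, awaitables, async iterators and sync iterables into one async stream, which is what lets the call sites above stop checking __aiter__ themselves. An illustration of the assumed behavior:

import asyncio
from g4f.providers.asyncio import to_async_iterator

async def demo():
    async def agen():
        yield "from async generator"

    async def coro():
        return "from awaitable"

    # A bare string is yielded once, not iterated character by character.
    for source in ("plain string", coro(), agen(), ["list", "items"]):
        async for item in to_async_iterator(source):
            print(item)

asyncio.run(demo())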
@@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor
 from abc import abstractmethod
 import json
 from inspect import signature, Parameter
-from typing import Optional, Awaitable, _GenericAlias
+from typing import Optional, _GenericAlias
 from pathlib import Path
 try:
     from types import NoneType
@@ -16,11 +16,11 @@ except ImportError:

 from ..typing import CreateResult, AsyncResult, Messages
 from .types import BaseProvider
-from .asyncio import get_running_loop, to_sync_generator
+from .asyncio import get_running_loop, to_sync_generator, to_async_iterator
 from .response import BaseConversation, AuthResult
 from .helper import concat_chunks, async_concat_chunks
 from ..cookies import get_cookies_dir
-from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError
+from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError, NoValidHarFileError
 from .. import debug

 SAFE_PARAMETERS = [
@@ -31,7 +31,7 @@ SAFE_PARAMETERS = [
     "temperature", "top_k", "top_p",
     "frequency_penalty", "presence_penalty",
     "max_tokens", "max_new_tokens", "stop",
-    "api_key", "seed", "width", "height",
+    "api_key", "api_base", "seed", "width", "height",
     "proof_token", "max_retries"
 ]
@@ -63,9 +63,29 @@ PARAMETER_EXAMPLES = {
 }

 class AbstractProvider(BaseProvider):
     """
     Abstract class for providing asynchronous functionality to derived classes.
     """

+    @classmethod
+    @abstractmethod
+    def create_completion(
+        cls,
+        model: str,
+        messages: Messages,
+        stream: bool,
+        **kwargs
+    ) -> CreateResult:
+        """
+        Create a completion with the given parameters.
+
+        Args:
+            model (str): The model to use.
+            messages (Messages): The messages to process.
+            stream (bool): Whether to use streaming.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            CreateResult: The result of the creation process.
+        """
+        raise NotImplementedError()
+
     @classmethod
     async def create_async(
@@ -92,16 +112,24 @@ class AbstractProvider(BaseProvider):
         Returns:
             str: The created result as a string.
         """
-        loop = loop or asyncio.get_running_loop()
+        loop = asyncio.get_running_loop() if loop is None else loop

         def create_func() -> str:
-            return concat_chunks(cls.create_completion(model, messages, False, **kwargs))
+            return concat_chunks(cls.create_completion(model, messages, **kwargs))

         return await asyncio.wait_for(
             loop.run_in_executor(executor, create_func),
             timeout=timeout
         )

+    @classmethod
+    def get_create_function(cls) -> callable:
+        return cls.create_completion
+
+    @classmethod
+    def get_async_create_function(cls) -> callable:
+        return cls.create_async
+
     @classmethod
     def get_parameters(cls, as_json: bool = False) -> dict[str, Parameter]:
         params = {name: parameter for name, parameter in signature(
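Note: these two accessors are the core of the refactor: every provider class now names its own sync and async entry points, so callers dispatch uniformly. A hypothetical caller sketch:

def run_provider(provider, model, messages, **kwargs):
    # Works unchanged for AbstractProvider, AsyncGeneratorProvider and
    # AsyncAuthedProvider; each class returns the entry point that suits it.
    create = provider.get_create_function()
    for chunk in create(model, messages, stream=True, **kwargs):
        print(chunk, end="")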
@@ -149,7 +177,7 @@ class AbstractProvider(BaseProvider):
         ) for name, param in {
             **BASIC_PARAMETERS,
             **params,
-            **{"provider": cls.__name__, "stream": cls.supports_stream, "model": getattr(cls, "default_model", "")},
+            **{"provider": cls.__name__, "model": getattr(cls, "default_model", ""), "stream": cls.supports_stream},
         }.items()}
         return params
@@ -233,6 +261,14 @@ class AsyncProvider(AbstractProvider):
         """
         raise NotImplementedError()

+    @classmethod
+    def get_create_function(cls) -> callable:
+        return cls.create_completion
+
+    @classmethod
+    def get_async_create_function(cls) -> callable:
+        return cls.create_async
+
 class AsyncGeneratorProvider(AsyncProvider):
     """
     Provides asynchronous generator functionality for streaming results.
@@ -262,30 +298,10 @@ class AsyncGeneratorProvider(AsyncProvider):
             CreateResult: The result of the streaming completion creation.
         """
         return to_sync_generator(
-            cls.create_async_generator(model, messages, stream=stream, **kwargs)
+            cls.create_async_generator(model, messages, stream=stream, **kwargs),
+            stream=stream
         )

-    @classmethod
-    async def create_async(
-        cls,
-        model: str,
-        messages: Messages,
-        **kwargs
-    ) -> str:
-        """
-        Asynchronously creates a result from a generator.
-
-        Args:
-            cls (type): The class on which this method is called.
-            model (str): The model to use for creation.
-            messages (Messages): The messages to process.
-            **kwargs: Additional keyword arguments.
-
-        Returns:
-            str: The created result as a string.
-        """
-        return await async_concat_chunks(cls.create_async_generator(model, messages, stream=False, **kwargs))
-
     @staticmethod
     @abstractmethod
     async def create_async_generator(
@@ -311,11 +327,13 @@ class AsyncGeneratorProvider(AsyncProvider):
         """
         raise NotImplementedError()

-    create_authed = create_completion
+    @classmethod
+    def get_create_function(cls) -> callable:
+        return cls.create_completion

-    create_authed_async = create_async
-
-    create_async_authed = create_async_generator
+    @classmethod
+    def get_async_create_function(cls) -> callable:
+        return cls.create_async_generator

 class ProviderModelMixin:
     default_model: str = None
@@ -357,96 +375,75 @@ class RaiseErrorMixin():
         else:
             raise ResponseError(data["error"])

-class AuthedMixin():
+class AsyncAuthedProvider(AsyncGeneratorProvider):
+
     @classmethod
-    def on_auth(cls, **kwargs) -> Optional[AuthResult]:
+    async def on_auth_async(cls, **kwargs) -> AuthResult:
         if "api_key" not in kwargs:
             raise MissingAuthError(f"API key is required for {cls.__name__}")
-        return None
+        return AuthResult()

     @classmethod
-    def create_authed(
+    def on_auth(cls, **kwargs) -> AuthResult:
+        return asyncio.run(cls.on_auth_async(**kwargs))
+
+    @classmethod
+    def get_create_function(cls) -> callable:
+        return cls.create_completion
+
+    @classmethod
+    def get_async_create_function(cls) -> callable:
+        return cls.create_async_generator
+
+    @classmethod
+    def create_completion(
         cls,
         model: str,
         messages: Messages,
         **kwargs
     ) -> CreateResult:
-        auth_result = {}
-        cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
-        if cache_file.exists():
-            with cache_file.open("r") as f:
-                auth_result = json.load(f)
-            return cls.create_completion(model, messages, **kwargs, **auth_result)
-        auth_result = cls.on_auth(**kwargs)
         try:
-            return cls.create_completion(model, messages, **kwargs)
+            auth_result = AuthResult()
+            cache_file = Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
+            if cache_file.exists():
+                with cache_file.open("r") as f:
+                    auth_result = AuthResult(**json.load(f))
+            else:
+                auth_result = cls.on_auth(**kwargs)
+            return to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
+        except (MissingAuthError, NoValidHarFileError):
+            if cache_file.exists():
+                cache_file.unlink()
+            auth_result = cls.on_auth(**kwargs)
+            return to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
         finally:
             cache_file.parent.mkdir(parents=True, exist_ok=True)
             cache_file.write_text(json.dumps(auth_result.get_dict()))

-class AsyncAuthedMixin(AuthedMixin):
     @classmethod
-    async def create_async_authed(
+    async def create_async_generator(
         cls,
         model: str,
         messages: Messages,
         **kwargs
-    ) -> str:
-        cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
-        if cache_file.exists():
-            auth_result = {}
-            with cache_file.open("r") as f:
-                auth_result = json.load(f)
-            return cls.create_completion(model, messages, **kwargs, **auth_result)
-        auth_result = cls.on_auth(**kwargs)
-        try:
-            return await cls.create_async(model, messages, **kwargs)
-        finally:
-            if auth_result is not None:
-                cache_file.parent.mkdir(parents=True, exist_ok=True)
-                cache_file.write_text(json.dumps(auth_result.get_dict()))
-
-class AsyncAuthedGeneratorMixin(AsyncAuthedMixin):
-    @classmethod
-    async def create_async_authed(
-        cls,
-        model: str,
-        messages: Messages,
-        **kwargs
-    ) -> str:
-        cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
-        if cache_file.exists():
-            auth_result = {}
-            with cache_file.open("r") as f:
-                auth_result = json.load(f)
-            return cls.create_completion(model, messages, **kwargs, **auth_result)
-        auth_result = cls.on_auth(**kwargs)
-        try:
-            return await async_concat_chunks(cls.create_async_generator(model, messages, stream=False, **kwargs))
-        finally:
-            if auth_result is not None:
-                cache_file.parent.mkdir(parents=True, exist_ok=True)
-                cache_file.write_text(json.dumps(auth_result.get_dict()))
-
-    @classmethod
-    def create_async_authed_generator(
-        cls,
-        model: str,
-        messages: Messages,
-        stream: bool = True,
-        **kwargs
-    ) -> Awaitable[AsyncResult]:
-        cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
-        if cache_file.exists():
-            auth_result = {}
-            with cache_file.open("r") as f:
-                auth_result = json.load(f)
-            return cls.create_completion(model, messages, **kwargs, **auth_result)
-        auth_result = cls.on_auth(**kwargs)
-        try:
-            return cls.create_async_generator(model, messages, stream=stream, **kwargs)
+    ) -> AsyncResult:
+        try:
+            auth_result = AuthResult()
+            cache_file = Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
+            if cache_file.exists():
+                with cache_file.open("r") as f:
+                    auth_result = AuthResult(**json.load(f))
+            else:
+                auth_result = await cls.on_auth_async(**kwargs)
+            response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
+            async for chunk in response:
+                yield chunk
+        except (MissingAuthError, NoValidHarFileError):
+            if cache_file.exists():
+                cache_file.unlink()
+            auth_result = await cls.on_auth_async(**kwargs)
+            response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
+            async for chunk in response:
+                yield chunk
         finally:
             if auth_result is not None:
                 cache_file.parent.mkdir(parents=True, exist_ok=True)
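Note: AsyncAuthedProvider is the auth-caching core of this commit: the first call runs on_auth_async(), the resulting AuthResult is persisted as auth_<name>.json under the cookies dir, later calls reuse it, and a MissingAuthError or NoValidHarFileError invalidates the cache and re-authenticates once. A hypothetical subclass sketch showing the expected contract; the field values below are invented for illustration:

from g4f.errors import MissingAuthError
from g4f.providers.base_provider import AsyncAuthedProvider
from g4f.providers.response import AuthResult
from g4f.typing import AsyncResult, Messages

class ExampleAuthed(AsyncAuthedProvider):
    @classmethod
    async def on_auth_async(cls, api_key: str = None, **kwargs) -> AuthResult:
        if api_key is None:
            raise MissingAuthError(f"API key is required for {cls.__name__}")
        # Whatever is returned here is what gets cached to disk as JSON.
        return AuthResult(api_key=api_key, headers={"authorization": f"Bearer {api_key}"})

    @classmethod
    async def create_authed(
        cls, model: str, messages: Messages, auth_result: AuthResult, **kwargs
    ) -> AsyncResult:
        # auth_result comes either from the JSON cache or a fresh on_auth_async().
        yield f"would call the real API with {auth_result.headers}"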
@@ -1,6 +1,5 @@
 from __future__ import annotations

-import asyncio
 import random

 from ..typing import Type, List, CreateResult, Messages, AsyncResult
@@ -8,8 +7,6 @@ from .types import BaseProvider, BaseRetryProvider, ProviderType
 from .. import debug
 from ..errors import RetryProviderError, RetryNoProviderError

-DEFAULT_TIMEOUT = 60
-
 class IterListProvider(BaseRetryProvider):
     def __init__(
         self,
@@ -56,10 +53,15 @@ class IterListProvider(BaseRetryProvider):
             self.last_provider = provider
             debug.log(f"Using {provider.__name__} provider")
             try:
-                for chunk in provider.create_completion(model, messages, stream, **kwargs):
-                    if chunk:
-                        yield chunk
-                        started = True
+                response = provider.get_create_function()(model, messages, stream=stream, **kwargs)
+                if hasattr(response, "__iter__"):
+                    for chunk in response:
+                        if chunk:
+                            yield chunk
+                            started = True
+                elif response:
+                    yield response
+                    started = True
                 if started:
                     return
             except Exception as e:
@@ -70,41 +72,6 @@ class IterListProvider(BaseRetryProvider):

         raise_exceptions(exceptions)

-    async def create_async(
-        self,
-        model: str,
-        messages: Messages,
-        ignored: list[str] = [],
-        **kwargs,
-    ) -> str:
-        """
-        Asynchronously create a completion using available providers.
-        Args:
-            model (str): The model to be used for completion.
-            messages (Messages): The messages to be used for generating completion.
-        Returns:
-            str: The result of the asynchronous completion.
-        Raises:
-            Exception: Any exception encountered during the asynchronous completion process.
-        """
-        exceptions = {}
-
-        for provider in self.get_providers(False, ignored):
-            self.last_provider = provider
-            debug.log(f"Using {provider.__name__} provider")
-            try:
-                chunk = await asyncio.wait_for(
-                    provider.create_async(model, messages, **kwargs),
-                    timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
-                )
-                if chunk:
-                    return chunk
-            except Exception as e:
-                exceptions[provider.__name__] = e
-                debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")
-
-        raise_exceptions(exceptions)
-
     async def create_async_generator(
         self,
         model: str,
@@ -121,22 +88,16 @@ class IterListProvider(BaseRetryProvider):
             self.last_provider = provider
             debug.log(f"Using {provider.__name__} provider")
             try:
-                if not stream:
-                    chunk = await asyncio.wait_for(
-                        provider.create_async(model, messages, **kwargs),
-                        timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
-                    )
-                    if chunk:
-                        yield chunk
-                        started = True
-                elif hasattr(provider, "create_async_generator"):
-                    async for chunk in provider.create_async_generator(model, messages, stream=stream, **kwargs):
+                response = provider.get_async_create_function()(model, messages, stream=stream, **kwargs)
+                if hasattr(response, "__aiter__"):
+                    async for chunk in response:
                         if chunk:
                             yield chunk
                             started = True
-                else:
-                    for token in provider.create_completion(model, messages, stream, **kwargs):
-                        yield token
+                elif response:
+                    response = await response
+                    if response:
+                        yield response
+                        started = True
                 if started:
                     return
@@ -148,6 +109,12 @@ class IterListProvider(BaseRetryProvider):

         raise_exceptions(exceptions)

+    def get_create_function(self) -> callable:
+        return self.create_completion
+
+    def get_async_create_function(self) -> callable:
+        return self.create_async_generator
+
     def get_providers(self, stream: bool, ignored: list[str]) -> list[ProviderType]:
         providers = [p for p in self.providers if (p.supports_stream or not stream) and p.__name__ not in ignored]
         if self.shuffle:
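Note: with the accessors above, the retry loops no longer care what kind of provider they hold; they duck-type whatever the entry point returns. The assumed semantics in isolation (sketch only, mirroring the branching used in the loops above):

async def drain(response):
    # Async iterators are streamed, awaitables are awaited once,
    # any other truthy value is yielded as a single chunk.
    if hasattr(response, "__aiter__"):
        async for chunk in response:
            if chunk:
                yield chunk
    elif hasattr(response, "__await__"):
        result = await response
        if result:
            yield result
    elif response:
        yield response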
@@ -201,8 +168,14 @@ class RetryProvider(IterListProvider):
             try:
                 if debug.logging:
                     print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
-                for token in provider.create_completion(model, messages, stream, **kwargs):
-                    yield token
+                response = provider.get_create_function()(model, messages, stream=stream, **kwargs)
+                if hasattr(response, "__iter__"):
+                    for chunk in response:
+                        if chunk:
+                            yield chunk
+                            started = True
+                elif response:
+                    yield response
+                    started = True
                 if started:
                     return
@@ -216,43 +189,6 @@ class RetryProvider(IterListProvider):
         else:
             yield from super().create_completion(model, messages, stream, **kwargs)

-    async def create_async(
-        self,
-        model: str,
-        messages: Messages,
-        **kwargs,
-    ) -> str:
-        """
-        Asynchronously create a completion using available providers.
-        Args:
-            model (str): The model to be used for completion.
-            messages (Messages): The messages to be used for generating completion.
-        Returns:
-            str: The result of the asynchronous completion.
-        Raises:
-            Exception: Any exception encountered during the asynchronous completion process.
-        """
-        exceptions = {}
-
-        if self.single_provider_retry:
-            provider = self.providers[0]
-            self.last_provider = provider
-            for attempt in range(self.max_retries):
-                try:
-                    if debug.logging:
-                        print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
-                    return await asyncio.wait_for(
-                        provider.create_async(model, messages, **kwargs),
-                        timeout=kwargs.get("timeout", 60),
-                    )
-                except Exception as e:
-                    exceptions[provider.__name__] = e
-                    if debug.logging:
-                        print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
-            raise_exceptions(exceptions)
-        else:
-            return await super().create_async(model, messages, **kwargs)
-
     async def create_async_generator(
         self,
         model: str,
@@ -269,22 +205,16 @@ class RetryProvider(IterListProvider):
             for attempt in range(self.max_retries):
                 try:
                     debug.log(f"Using {provider.__name__} provider (attempt {attempt + 1})")
-                    if not stream:
-                        chunk = await asyncio.wait_for(
-                            provider.create_async(model, messages, **kwargs),
-                            timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
-                        )
-                        if chunk:
-                            yield chunk
-                            started = True
-                    elif hasattr(provider, "create_async_generator"):
-                        async for chunk in provider.create_async_generator(model, messages, stream=stream, **kwargs):
+                    response = provider.get_async_create_function()(model, messages, stream=stream, **kwargs)
+                    if hasattr(response, "__aiter__"):
+                        async for chunk in response:
                             if chunk:
                                 yield chunk
                                 started = True
-                    else:
-                        for token in provider.create_completion(model, messages, stream, **kwargs):
-                            yield token
+                    elif response:
+                        response = await response
+                        if response:
+                            yield response
+                            started = True
                     if started:
                         return
@@ -26,47 +26,23 @@ class BaseProvider(ABC):
     supports_system_message: bool = False
     params: str

     @classmethod
     @abstractmethod
-    def create_completion(
-        cls,
-        model: str,
-        messages: Messages,
-        stream: bool,
-        **kwargs
-    ) -> CreateResult:
+    def get_create_function() -> callable:
         """
-        Create a completion with the given parameters.
-
-        Args:
-            model (str): The model to use.
-            messages (Messages): The messages to process.
-            stream (bool): Whether to use streaming.
-            **kwargs: Additional keyword arguments.
+        Get the create function for the provider.

         Returns:
-            CreateResult: The result of the creation process.
+            callable: The create function.
         """
         raise NotImplementedError()

     @classmethod
     @abstractmethod
-    async def create_async(
-        cls,
-        model: str,
-        messages: Messages,
-        **kwargs
-    ) -> str:
+    def get_async_create_function() -> callable:
         """
-        Asynchronously create a completion with the given parameters.
-
-        Args:
-            model (str): The model to use.
-            messages (Messages): The messages to process.
-            **kwargs: Additional keyword arguments.
+        Get the async create function for the provider.

         Returns:
-            str: The result of the creation process.
+            callable: The create function.
         """
         raise NotImplementedError()
@@ -7,7 +7,7 @@ from typing import Optional, Callable, AsyncIterator

 from ..typing import Messages
 from ..providers.helper import filter_none
-from ..client.helper import to_async_iterator
+from ..providers.asyncio import to_async_iterator
 from .web_search import do_search, get_search_message
 from .files import read_bucket, get_bucket_dir
 from .. import debug
@@ -55,9 +55,7 @@ async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls:
     if has_bucket and isinstance(messages[-1]["content"], str):
         messages[-1]["content"] += BUCKET_INSTRUCTIONS

-    response = async_iter_callback(model=model, messages=messages, **kwargs)
-    if not hasattr(response, "__aiter__"):
-        response = to_async_iterator(response)
+    response = to_async_iterator(async_iter_callback(model=model, messages=messages, **kwargs))
     async for chunk in response:
         yield chunk
@@ -122,7 +122,7 @@ async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = None
                 with open(cache_file, "w") as f:
                     f.write(text)
             return text
-    except ClientError:
+    except (ClientError, asyncio.TimeoutError):
         return

 async def search(query: str, max_results: int = 5, max_words: int = 2500, backend: str = "auto", add_text: bool = True, timeout: int = 5, region: str = "wt-wt") -> SearchResults:
@@ -138,7 +138,7 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backend: str = "auto", add_text: bool = True, timeout: int = 5, region: str = "wt-wt") -> SearchResults:
             max_results=max_results,
             backend=backend,
         ):
-            if ".google.com" in result["href"]:
+            if ".google." in result["href"]:
                 continue
             results.append(SearchResultEntry(
                 result["title"],
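Note: the widened ".google." check also skips regional Google domains that the old ".google.com" test let through:

for href in ("https://www.google.com/a", "https://news.google.de/b", "https://example.com/c"):
    print(".google." in href)   # True, True, False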