Support continue messages in Airforce

Add auth caching for OpenAI ChatGPT
Some provider improvments
This commit is contained in:
Heiner Lohaus 2025-01-03 20:35:46 +01:00
parent b0bc665621
commit 6e0bc147b5
17 changed files with 290 additions and 347 deletions

View file

@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor
from abc import abstractmethod
import json
from inspect import signature, Parameter
from typing import Optional, Awaitable, _GenericAlias
from typing import Optional, _GenericAlias
from pathlib import Path
try:
from types import NoneType
@ -16,11 +16,11 @@ except ImportError:
from ..typing import CreateResult, AsyncResult, Messages
from .types import BaseProvider
from .asyncio import get_running_loop, to_sync_generator
from .asyncio import get_running_loop, to_sync_generator, to_async_iterator
from .response import BaseConversation, AuthResult
from .helper import concat_chunks, async_concat_chunks
from ..cookies import get_cookies_dir
from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError
from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError, NoValidHarFileError
from .. import debug
SAFE_PARAMETERS = [
@ -31,7 +31,7 @@ SAFE_PARAMETERS = [
"temperature", "top_k", "top_p",
"frequency_penalty", "presence_penalty",
"max_tokens", "max_new_tokens", "stop",
"api_key", "seed", "width", "height",
"api_key", "api_base", "seed", "width", "height",
"proof_token", "max_retries"
]
@ -63,9 +63,29 @@ PARAMETER_EXAMPLES = {
}
class AbstractProvider(BaseProvider):
"""
Abstract class for providing asynchronous functionality to derived classes.
"""
@classmethod
@abstractmethod
def create_completion(
cls,
model: str,
messages: Messages,
stream: bool,
**kwargs
) -> CreateResult:
"""
Create a completion with the given parameters.
Args:
model (str): The model to use.
messages (Messages): The messages to process.
stream (bool): Whether to use streaming.
**kwargs: Additional keyword arguments.
Returns:
CreateResult: The result of the creation process.
"""
raise NotImplementedError()
@classmethod
async def create_async(
@ -92,16 +112,24 @@ class AbstractProvider(BaseProvider):
Returns:
str: The created result as a string.
"""
loop = loop or asyncio.get_running_loop()
loop = asyncio.get_running_loop() if loop is None else loop
def create_func() -> str:
return concat_chunks(cls.create_completion(model, messages, False, **kwargs))
return concat_chunks(cls.create_completion(model, messages, **kwargs))
return await asyncio.wait_for(
loop.run_in_executor(executor, create_func),
timeout=timeout
)
@classmethod
def get_create_function(cls) -> callable:
return cls.create_completion
@classmethod
def get_async_create_function(cls) -> callable:
return cls.create_async
@classmethod
def get_parameters(cls, as_json: bool = False) -> dict[str, Parameter]:
params = {name: parameter for name, parameter in signature(
@ -149,7 +177,7 @@ class AbstractProvider(BaseProvider):
) for name, param in {
**BASIC_PARAMETERS,
**params,
**{"provider": cls.__name__, "stream": cls.supports_stream, "model": getattr(cls, "default_model", "")},
**{"provider": cls.__name__, "model": getattr(cls, "default_model", ""), "stream": cls.supports_stream},
}.items()}
return params
@ -233,6 +261,14 @@ class AsyncProvider(AbstractProvider):
"""
raise NotImplementedError()
@classmethod
def get_create_function(cls) -> callable:
return cls.create_completion
@classmethod
def get_async_create_function(cls) -> callable:
return cls.create_async
class AsyncGeneratorProvider(AsyncProvider):
"""
Provides asynchronous generator functionality for streaming results.
@ -262,30 +298,10 @@ class AsyncGeneratorProvider(AsyncProvider):
CreateResult: The result of the streaming completion creation.
"""
return to_sync_generator(
cls.create_async_generator(model, messages, stream=stream, **kwargs)
cls.create_async_generator(model, messages, stream=stream, **kwargs),
stream=stream
)
@classmethod
async def create_async(
cls,
model: str,
messages: Messages,
**kwargs
) -> str:
"""
Asynchronously creates a result from a generator.
Args:
cls (type): The class on which this method is called.
model (str): The model to use for creation.
messages (Messages): The messages to process.
**kwargs: Additional keyword arguments.
Returns:
str: The created result as a string.
"""
return await async_concat_chunks(cls.create_async_generator(model, messages, stream=False, **kwargs))
@staticmethod
@abstractmethod
async def create_async_generator(
@ -311,11 +327,13 @@ class AsyncGeneratorProvider(AsyncProvider):
"""
raise NotImplementedError()
create_authed = create_completion
@classmethod
def get_create_function(cls) -> callable:
return cls.create_completion
create_authed_async = create_async
create_async_authed = create_async_generator
@classmethod
def get_async_create_function(cls) -> callable:
return cls.create_async_generator
class ProviderModelMixin:
default_model: str = None
@ -357,97 +375,76 @@ class RaiseErrorMixin():
else:
raise ResponseError(data["error"])
class AuthedMixin():
class AsyncAuthedProvider(AsyncGeneratorProvider):
@classmethod
def on_auth(cls, **kwargs) -> Optional[AuthResult]:
async def on_auth_async(cls, **kwargs) -> AuthResult:
if "api_key" not in kwargs:
raise MissingAuthError(f"API key is required for {cls.__name__}")
return None
return AuthResult()
@classmethod
def create_authed(
def on_auth(cls, **kwargs) -> AuthResult:
return asyncio.run(cls.on_auth_async(**kwargs))
@classmethod
def get_create_function(cls) -> callable:
return cls.create_completion
@classmethod
def get_async_create_function(cls) -> callable:
return cls.create_async_generator
@classmethod
def create_completion(
cls,
model: str,
messages: Messages,
**kwargs
) -> CreateResult:
auth_result = {}
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
if cache_file.exists():
with cache_file.open("r") as f:
auth_result = json.load(f)
return cls.create_completion(model, messages, **kwargs, **auth_result)
auth_result = cls.on_auth(**kwargs)
try:
return cls.create_completion(model, messages, **kwargs)
auth_result = AuthResult()
cache_file = Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
if cache_file.exists():
with cache_file.open("r") as f:
auth_result = AuthResult(**json.load(f))
else:
auth_result = cls.on_auth(**kwargs)
return to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
except (MissingAuthError, NoValidHarFileError):
if cache_file.exists():
cache_file.unlink()
auth_result = cls.on_auth(**kwargs)
return to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
finally:
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
class AsyncAuthedMixin(AuthedMixin):
@classmethod
async def create_async_authed(
async def create_async_generator(
cls,
model: str,
messages: Messages,
**kwargs
) -> str:
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
if cache_file.exists():
auth_result = {}
with cache_file.open("r") as f:
auth_result = json.load(f)
return cls.create_completion(model, messages, **kwargs, **auth_result)
auth_result = cls.on_auth(**kwargs)
) -> AsyncResult:
try:
return await cls.create_async(model, messages, **kwargs)
auth_result = AuthResult()
cache_file = Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
if cache_file.exists():
with cache_file.open("r") as f:
auth_result = AuthResult(**json.load(f))
else:
auth_result = await cls.on_auth_async(**kwargs)
response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
async for chunk in response:
yield chunk
except (MissingAuthError, NoValidHarFileError):
if cache_file.exists():
cache_file.unlink()
auth_result = await cls.on_auth_async(**kwargs)
response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
async for chunk in response:
yield chunk
finally:
if auth_result is not None:
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
class AsyncAuthedGeneratorMixin(AsyncAuthedMixin):
@classmethod
async def create_async_authed(
cls,
model: str,
messages: Messages,
**kwargs
) -> str:
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
if cache_file.exists():
auth_result = {}
with cache_file.open("r") as f:
auth_result = json.load(f)
return cls.create_completion(model, messages, **kwargs, **auth_result)
auth_result = cls.on_auth(**kwargs)
try:
return await async_concat_chunks(cls.create_async_generator(model, messages, stream=False, **kwargs))
finally:
if auth_result is not None:
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
@classmethod
def create_async_authed_generator(
cls,
model: str,
messages: Messages,
stream: bool = True,
**kwargs
) -> Awaitable[AsyncResult]:
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
if cache_file.exists():
auth_result = {}
with cache_file.open("r") as f:
auth_result = json.load(f)
return cls.create_completion(model, messages, **kwargs, **auth_result)
auth_result = cls.on_auth(**kwargs)
try:
return cls.create_async_generator(model, messages, stream=stream, **kwargs)
finally:
if auth_result is not None:
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
cache_file.write_text(json.dumps(auth_result.get_dict()))