mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
Support continue messages in Airforce
Add auth caching for OpenAI ChatGPT Some provider improvments
This commit is contained in:
parent
b0bc665621
commit
6e0bc147b5
17 changed files with 290 additions and 347 deletions
|
|
@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor
|
|||
from abc import abstractmethod
|
||||
import json
|
||||
from inspect import signature, Parameter
|
||||
from typing import Optional, Awaitable, _GenericAlias
|
||||
from typing import Optional, _GenericAlias
|
||||
from pathlib import Path
|
||||
try:
|
||||
from types import NoneType
|
||||
|
|
@ -16,11 +16,11 @@ except ImportError:
|
|||
|
||||
from ..typing import CreateResult, AsyncResult, Messages
|
||||
from .types import BaseProvider
|
||||
from .asyncio import get_running_loop, to_sync_generator
|
||||
from .asyncio import get_running_loop, to_sync_generator, to_async_iterator
|
||||
from .response import BaseConversation, AuthResult
|
||||
from .helper import concat_chunks, async_concat_chunks
|
||||
from ..cookies import get_cookies_dir
|
||||
from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError
|
||||
from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError, NoValidHarFileError
|
||||
from .. import debug
|
||||
|
||||
SAFE_PARAMETERS = [
|
||||
|
|
@ -31,7 +31,7 @@ SAFE_PARAMETERS = [
|
|||
"temperature", "top_k", "top_p",
|
||||
"frequency_penalty", "presence_penalty",
|
||||
"max_tokens", "max_new_tokens", "stop",
|
||||
"api_key", "seed", "width", "height",
|
||||
"api_key", "api_base", "seed", "width", "height",
|
||||
"proof_token", "max_retries"
|
||||
]
|
||||
|
||||
|
|
@ -63,9 +63,29 @@ PARAMETER_EXAMPLES = {
|
|||
}
|
||||
|
||||
class AbstractProvider(BaseProvider):
|
||||
"""
|
||||
Abstract class for providing asynchronous functionality to derived classes.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def create_completion(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
stream: bool,
|
||||
**kwargs
|
||||
) -> CreateResult:
|
||||
"""
|
||||
Create a completion with the given parameters.
|
||||
|
||||
Args:
|
||||
model (str): The model to use.
|
||||
messages (Messages): The messages to process.
|
||||
stream (bool): Whether to use streaming.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
CreateResult: The result of the creation process.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
async def create_async(
|
||||
|
|
@ -92,16 +112,24 @@ class AbstractProvider(BaseProvider):
|
|||
Returns:
|
||||
str: The created result as a string.
|
||||
"""
|
||||
loop = loop or asyncio.get_running_loop()
|
||||
loop = asyncio.get_running_loop() if loop is None else loop
|
||||
|
||||
def create_func() -> str:
|
||||
return concat_chunks(cls.create_completion(model, messages, False, **kwargs))
|
||||
return concat_chunks(cls.create_completion(model, messages, **kwargs))
|
||||
|
||||
return await asyncio.wait_for(
|
||||
loop.run_in_executor(executor, create_func),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_create_function(cls) -> callable:
|
||||
return cls.create_completion
|
||||
|
||||
@classmethod
|
||||
def get_async_create_function(cls) -> callable:
|
||||
return cls.create_async
|
||||
|
||||
@classmethod
|
||||
def get_parameters(cls, as_json: bool = False) -> dict[str, Parameter]:
|
||||
params = {name: parameter for name, parameter in signature(
|
||||
|
|
@ -149,7 +177,7 @@ class AbstractProvider(BaseProvider):
|
|||
) for name, param in {
|
||||
**BASIC_PARAMETERS,
|
||||
**params,
|
||||
**{"provider": cls.__name__, "stream": cls.supports_stream, "model": getattr(cls, "default_model", "")},
|
||||
**{"provider": cls.__name__, "model": getattr(cls, "default_model", ""), "stream": cls.supports_stream},
|
||||
}.items()}
|
||||
return params
|
||||
|
||||
|
|
@ -233,6 +261,14 @@ class AsyncProvider(AbstractProvider):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def get_create_function(cls) -> callable:
|
||||
return cls.create_completion
|
||||
|
||||
@classmethod
|
||||
def get_async_create_function(cls) -> callable:
|
||||
return cls.create_async
|
||||
|
||||
class AsyncGeneratorProvider(AsyncProvider):
|
||||
"""
|
||||
Provides asynchronous generator functionality for streaming results.
|
||||
|
|
@ -262,30 +298,10 @@ class AsyncGeneratorProvider(AsyncProvider):
|
|||
CreateResult: The result of the streaming completion creation.
|
||||
"""
|
||||
return to_sync_generator(
|
||||
cls.create_async_generator(model, messages, stream=stream, **kwargs)
|
||||
cls.create_async_generator(model, messages, stream=stream, **kwargs),
|
||||
stream=stream
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def create_async(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
Asynchronously creates a result from a generator.
|
||||
|
||||
Args:
|
||||
cls (type): The class on which this method is called.
|
||||
model (str): The model to use for creation.
|
||||
messages (Messages): The messages to process.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
str: The created result as a string.
|
||||
"""
|
||||
return await async_concat_chunks(cls.create_async_generator(model, messages, stream=False, **kwargs))
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
async def create_async_generator(
|
||||
|
|
@ -311,11 +327,13 @@ class AsyncGeneratorProvider(AsyncProvider):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
create_authed = create_completion
|
||||
@classmethod
|
||||
def get_create_function(cls) -> callable:
|
||||
return cls.create_completion
|
||||
|
||||
create_authed_async = create_async
|
||||
|
||||
create_async_authed = create_async_generator
|
||||
@classmethod
|
||||
def get_async_create_function(cls) -> callable:
|
||||
return cls.create_async_generator
|
||||
|
||||
class ProviderModelMixin:
|
||||
default_model: str = None
|
||||
|
|
@ -357,97 +375,76 @@ class RaiseErrorMixin():
|
|||
else:
|
||||
raise ResponseError(data["error"])
|
||||
|
||||
class AuthedMixin():
|
||||
class AsyncAuthedProvider(AsyncGeneratorProvider):
|
||||
|
||||
@classmethod
|
||||
def on_auth(cls, **kwargs) -> Optional[AuthResult]:
|
||||
async def on_auth_async(cls, **kwargs) -> AuthResult:
|
||||
if "api_key" not in kwargs:
|
||||
raise MissingAuthError(f"API key is required for {cls.__name__}")
|
||||
return None
|
||||
return AuthResult()
|
||||
|
||||
@classmethod
|
||||
def create_authed(
|
||||
def on_auth(cls, **kwargs) -> AuthResult:
|
||||
return asyncio.run(cls.on_auth_async(**kwargs))
|
||||
|
||||
@classmethod
|
||||
def get_create_function(cls) -> callable:
|
||||
return cls.create_completion
|
||||
|
||||
@classmethod
|
||||
def get_async_create_function(cls) -> callable:
|
||||
return cls.create_async_generator
|
||||
|
||||
@classmethod
|
||||
def create_completion(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
**kwargs
|
||||
) -> CreateResult:
|
||||
auth_result = {}
|
||||
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
|
||||
if cache_file.exists():
|
||||
with cache_file.open("r") as f:
|
||||
auth_result = json.load(f)
|
||||
return cls.create_completion(model, messages, **kwargs, **auth_result)
|
||||
auth_result = cls.on_auth(**kwargs)
|
||||
try:
|
||||
return cls.create_completion(model, messages, **kwargs)
|
||||
auth_result = AuthResult()
|
||||
cache_file = Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
|
||||
if cache_file.exists():
|
||||
with cache_file.open("r") as f:
|
||||
auth_result = AuthResult(**json.load(f))
|
||||
else:
|
||||
auth_result = cls.on_auth(**kwargs)
|
||||
return to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
|
||||
except (MissingAuthError, NoValidHarFileError):
|
||||
if cache_file.exists():
|
||||
cache_file.unlink()
|
||||
auth_result = cls.on_auth(**kwargs)
|
||||
return to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
|
||||
finally:
|
||||
cache_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
cache_file.write_text(json.dumps(auth_result.get_dict()))
|
||||
cache_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
cache_file.write_text(json.dumps(auth_result.get_dict()))
|
||||
|
||||
class AsyncAuthedMixin(AuthedMixin):
|
||||
@classmethod
|
||||
async def create_async_authed(
|
||||
async def create_async_generator(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
**kwargs
|
||||
) -> str:
|
||||
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
|
||||
if cache_file.exists():
|
||||
auth_result = {}
|
||||
with cache_file.open("r") as f:
|
||||
auth_result = json.load(f)
|
||||
return cls.create_completion(model, messages, **kwargs, **auth_result)
|
||||
auth_result = cls.on_auth(**kwargs)
|
||||
) -> AsyncResult:
|
||||
try:
|
||||
return await cls.create_async(model, messages, **kwargs)
|
||||
auth_result = AuthResult()
|
||||
cache_file = Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
|
||||
if cache_file.exists():
|
||||
with cache_file.open("r") as f:
|
||||
auth_result = AuthResult(**json.load(f))
|
||||
else:
|
||||
auth_result = await cls.on_auth_async(**kwargs)
|
||||
response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
|
||||
async for chunk in response:
|
||||
yield chunk
|
||||
except (MissingAuthError, NoValidHarFileError):
|
||||
if cache_file.exists():
|
||||
cache_file.unlink()
|
||||
auth_result = await cls.on_auth_async(**kwargs)
|
||||
response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
|
||||
async for chunk in response:
|
||||
yield chunk
|
||||
finally:
|
||||
if auth_result is not None:
|
||||
cache_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
cache_file.write_text(json.dumps(auth_result.get_dict()))
|
||||
|
||||
class AsyncAuthedGeneratorMixin(AsyncAuthedMixin):
|
||||
|
||||
@classmethod
|
||||
async def create_async_authed(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
**kwargs
|
||||
) -> str:
|
||||
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
|
||||
if cache_file.exists():
|
||||
auth_result = {}
|
||||
with cache_file.open("r") as f:
|
||||
auth_result = json.load(f)
|
||||
return cls.create_completion(model, messages, **kwargs, **auth_result)
|
||||
auth_result = cls.on_auth(**kwargs)
|
||||
try:
|
||||
return await async_concat_chunks(cls.create_async_generator(model, messages, stream=False, **kwargs))
|
||||
finally:
|
||||
if auth_result is not None:
|
||||
cache_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
cache_file.write_text(json.dumps(auth_result.get_dict()))
|
||||
|
||||
@classmethod
|
||||
def create_async_authed_generator(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
stream: bool = True,
|
||||
**kwargs
|
||||
) -> Awaitable[AsyncResult]:
|
||||
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
|
||||
if cache_file.exists():
|
||||
auth_result = {}
|
||||
with cache_file.open("r") as f:
|
||||
auth_result = json.load(f)
|
||||
return cls.create_completion(model, messages, **kwargs, **auth_result)
|
||||
auth_result = cls.on_auth(**kwargs)
|
||||
try:
|
||||
return cls.create_async_generator(model, messages, stream=stream, **kwargs)
|
||||
finally:
|
||||
if auth_result is not None:
|
||||
cache_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
cache_file.write_text(json.dumps(auth_result.get_dict()))
|
||||
cache_file.write_text(json.dumps(auth_result.get_dict()))
|
||||
Loading…
Add table
Add a link
Reference in a new issue