mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
Refactor model handling and improve timeout functionality
- Removed empty string mapping from model_map in AnyModelProviderMixin. - Updated clean_name function to exclude 'chat' from version patterns. - Added stream_timeout parameter to AsyncGeneratorProvider for more flexible timeout handling. - Enhanced chunk yielding in AsyncAuthedProvider to support stream_timeout, allowing for better control over asynchronous responses.
This commit is contained in:
parent
e09c08969a
commit
1edd0fff17
11 changed files with 682 additions and 586 deletions
|
|
@ -3,6 +3,8 @@
|
||||||
|
|
||||||
G4F_API_KEY=
|
G4F_API_KEY=
|
||||||
G4F_PROXY=
|
G4F_PROXY=
|
||||||
|
G4F_TIMEOUT=
|
||||||
|
G4F_STREAM_TIMEOUT
|
||||||
|
|
||||||
HUGGINGFACE_API_KEY=
|
HUGGINGFACE_API_KEY=
|
||||||
POLLINATIONS_API_KEY=
|
POLLINATIONS_API_KEY=
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ class GLM(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
cls.api_key = response.json().get("token")
|
cls.api_key = response.json().get("token")
|
||||||
response = requests.get(f"{cls.url}/api/models", headers={"Authorization": f"Bearer {cls.api_key}"})
|
response = requests.get(f"{cls.url}/api/models", headers={"Authorization": f"Bearer {cls.api_key}"})
|
||||||
data = response.json().get("data", [])
|
data = response.json().get("data", [])
|
||||||
cls.model_aliases = {data.get("name"): data.get("id") for data in data}
|
cls.model_aliases = {data.get("name", "").replace("\u4efb\u52a1\u4e13\u7528", "ChatGLM"): data.get("id") for data in data}
|
||||||
cls.models = list(cls.model_aliases.keys())
|
cls.models = list(cls.model_aliases.keys())
|
||||||
return cls.models
|
return cls.models
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ class DuckDuckGo(AbstractProvider, ProviderModelMixin):
|
||||||
url = "https://duckduckgo.com/aichat"
|
url = "https://duckduckgo.com/aichat"
|
||||||
api_base = "https://duckduckgo.com/duckchat/v1/"
|
api_base = "https://duckduckgo.com/duckchat/v1/"
|
||||||
|
|
||||||
working = has_requirements
|
working = False
|
||||||
supports_stream = True
|
supports_stream = True
|
||||||
supports_system_message = True
|
supports_system_message = True
|
||||||
supports_message_history = True
|
supports_message_history = True
|
||||||
|
|
|
||||||
|
|
@ -92,6 +92,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_PORT = 1337
|
DEFAULT_PORT = 1337
|
||||||
DEFAULT_TIMEOUT = 600
|
DEFAULT_TIMEOUT = 600
|
||||||
|
DEFAULT_STREAM_TIMEOUT = 15
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
|
|
@ -99,6 +100,8 @@ async def lifespan(app: FastAPI):
|
||||||
if not AppConfig.ignore_cookie_files:
|
if not AppConfig.ignore_cookie_files:
|
||||||
read_cookie_files()
|
read_cookie_files()
|
||||||
AppConfig.g4f_api_key = os.environ.get("G4F_API_KEY", AppConfig.g4f_api_key)
|
AppConfig.g4f_api_key = os.environ.get("G4F_API_KEY", AppConfig.g4f_api_key)
|
||||||
|
AppConfig.timeout = os.environ.get("G4F_TIMEOUT", AppConfig.timeout)
|
||||||
|
AppConfig.stream_timeout = os.environ.get("G4F_STREAM_TIMEOUT", AppConfig.stream_timeout)
|
||||||
yield
|
yield
|
||||||
if has_nodriver:
|
if has_nodriver:
|
||||||
for browser in util.get_registered_instances():
|
for browser in util.get_registered_instances():
|
||||||
|
|
@ -133,7 +136,7 @@ def create_app():
|
||||||
if AppConfig.gui:
|
if AppConfig.gui:
|
||||||
if not has_a2wsgi:
|
if not has_a2wsgi:
|
||||||
raise MissingRequirementsError("a2wsgi is required for GUI. Install it with: pip install a2wsgi")
|
raise MissingRequirementsError("a2wsgi is required for GUI. Install it with: pip install a2wsgi")
|
||||||
gui_app = WSGIMiddleware(get_gui_app(AppConfig.demo, AppConfig.timeout))
|
gui_app = WSGIMiddleware(get_gui_app(AppConfig.demo, AppConfig.timeout, AppConfig.stream_timeout))
|
||||||
app.mount("/", gui_app)
|
app.mount("/", gui_app)
|
||||||
|
|
||||||
if AppConfig.ignored_providers:
|
if AppConfig.ignored_providers:
|
||||||
|
|
@ -185,6 +188,7 @@ class AppConfig:
|
||||||
gui: bool = False
|
gui: bool = False
|
||||||
demo: bool = False
|
demo: bool = False
|
||||||
timeout: int = DEFAULT_TIMEOUT
|
timeout: int = DEFAULT_TIMEOUT
|
||||||
|
stream_timeout: int = DEFAULT_STREAM_TIMEOUT
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_config(cls, **data):
|
def set_config(cls, **data):
|
||||||
|
|
@ -418,6 +422,8 @@ class Api:
|
||||||
config.conversation_id = conversation_id
|
config.conversation_id = conversation_id
|
||||||
if config.timeout is None:
|
if config.timeout is None:
|
||||||
config.timeout = AppConfig.timeout
|
config.timeout = AppConfig.timeout
|
||||||
|
if config.stream_timeout is None:
|
||||||
|
config.stream_timeout = AppConfig.stream_timeout
|
||||||
if credentials is not None and credentials.credentials != "secret":
|
if credentials is not None and credentials.credentials != "secret":
|
||||||
config.api_key = credentials.credentials
|
config.api_key = credentials.credentials
|
||||||
|
|
||||||
|
|
@ -451,7 +457,7 @@ class Api:
|
||||||
"model": AppConfig.model,
|
"model": AppConfig.model,
|
||||||
"provider": AppConfig.provider,
|
"provider": AppConfig.provider,
|
||||||
"proxy": AppConfig.proxy,
|
"proxy": AppConfig.proxy,
|
||||||
**config.dict(exclude_none=True),
|
**(config.model_dump(exclude_none=True) if hasattr(config, "model_dump") else config.dict(exclude_none=True)),
|
||||||
**{
|
**{
|
||||||
"conversation_id": None,
|
"conversation_id": None,
|
||||||
"conversation": conversation,
|
"conversation": conversation,
|
||||||
|
|
@ -474,7 +480,7 @@ class Api:
|
||||||
self.conversations[config.conversation_id] = {}
|
self.conversations[config.conversation_id] = {}
|
||||||
self.conversations[config.conversation_id][config.provider] = chunk
|
self.conversations[config.conversation_id][config.provider] = chunk
|
||||||
else:
|
else:
|
||||||
yield f"data: {chunk.json()}\n\n"
|
yield f"data: {chunk.model_dump_json() if hasattr(chunk, 'model_dump_json') else chunk.json()}\n\n"
|
||||||
except GeneratorExit:
|
except GeneratorExit:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ class RequestConfig(BaseModel):
|
||||||
proxy: Optional[str] = None
|
proxy: Optional[str] = None
|
||||||
conversation: Optional[dict] = None
|
conversation: Optional[dict] = None
|
||||||
timeout: Optional[int] = None
|
timeout: Optional[int] = None
|
||||||
|
stream_timeout: Optional[int] = None
|
||||||
tool_calls: list = Field(default=[], examples=[[
|
tool_calls: list = Field(default=[], examples=[[
|
||||||
{
|
{
|
||||||
"function": {
|
"function": {
|
||||||
|
|
|
||||||
|
|
@ -8,12 +8,13 @@ try:
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
import_error = e
|
import_error = e
|
||||||
|
|
||||||
def get_gui_app(demo: bool = False, timeout: int = None):
|
def get_gui_app(demo: bool = False, timeout: int = None, stream_timeout: int = None):
|
||||||
if import_error is not None:
|
if import_error is not None:
|
||||||
raise MissingRequirementsError(f'Install "gui" requirements | pip install -U g4f[gui]\n{import_error}')
|
raise MissingRequirementsError(f'Install "gui" requirements | pip install -U g4f[gui]\n{import_error}')
|
||||||
app = create_app()
|
app = create_app()
|
||||||
app.demo = demo
|
app.demo = demo
|
||||||
app.timeout = timeout
|
app.timeout = timeout
|
||||||
|
app.stream_timeout = stream_timeout
|
||||||
|
|
||||||
site = Website(app)
|
site = Website(app)
|
||||||
for route in site.routes:
|
for route in site.routes:
|
||||||
|
|
|
||||||
|
|
@ -196,6 +196,8 @@ class Backend_Api(Api):
|
||||||
json_data['media'] = media
|
json_data['media'] = media
|
||||||
if app.timeout:
|
if app.timeout:
|
||||||
json_data['timeout'] = app.timeout
|
json_data['timeout'] = app.timeout
|
||||||
|
if app.stream_timeout:
|
||||||
|
json_data['stream_timeout'] = app.stream_timeout
|
||||||
if app.demo and not json_data.get("provider"):
|
if app.demo and not json_data.get("provider"):
|
||||||
model = json_data.get("model")
|
model = json_data.get("model")
|
||||||
if model != "default" and model in models.demo_models:
|
if model != "default" and model in models.demo_models:
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ from .Provider import (
|
||||||
HuggingSpace,
|
HuggingSpace,
|
||||||
Grok,
|
Grok,
|
||||||
DeepseekAI_JanusPro7b,
|
DeepseekAI_JanusPro7b,
|
||||||
|
GLM,
|
||||||
Kimi,
|
Kimi,
|
||||||
LambdaChat,
|
LambdaChat,
|
||||||
Mintlify,
|
Mintlify,
|
||||||
|
|
@ -25,6 +26,7 @@ from .Provider import (
|
||||||
PerplexityLabs,
|
PerplexityLabs,
|
||||||
PollinationsAI,
|
PollinationsAI,
|
||||||
PollinationsImage,
|
PollinationsImage,
|
||||||
|
Qwen,
|
||||||
TeachAnything,
|
TeachAnything,
|
||||||
Together,
|
Together,
|
||||||
WeWordle,
|
WeWordle,
|
||||||
|
|
@ -159,8 +161,10 @@ default = Model(
|
||||||
DeepInfra,
|
DeepInfra,
|
||||||
OperaAria,
|
OperaAria,
|
||||||
Startnest,
|
Startnest,
|
||||||
LambdaChat,
|
GLM,
|
||||||
|
Kimi,
|
||||||
PollinationsAI,
|
PollinationsAI,
|
||||||
|
Qwen,
|
||||||
Together,
|
Together,
|
||||||
Chatai,
|
Chatai,
|
||||||
WeWordle,
|
WeWordle,
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
|
|
@ -220,6 +220,7 @@ class AnyModelProviderMixin(ProviderModelMixin):
|
||||||
|
|
||||||
cls.video_models.append("video")
|
cls.video_models.append("video")
|
||||||
cls.model_map["video"] = {"Video": "video"}
|
cls.model_map["video"] = {"Video": "video"}
|
||||||
|
del cls.model_map[""]
|
||||||
cls.audio_models = [*cls.audio_models]
|
cls.audio_models = [*cls.audio_models]
|
||||||
|
|
||||||
# Create a mapping of parent providers to their children
|
# Create a mapping of parent providers to their children
|
||||||
|
|
@ -415,7 +416,7 @@ def clean_name(name: str) -> str:
|
||||||
name = re.sub(r'-\d{2}-\d{2}', '', name)
|
name = re.sub(r'-\d{2}-\d{2}', '', name)
|
||||||
name = re.sub(r'-[0-9a-f]{8}$', '', name)
|
name = re.sub(r'-[0-9a-f]{8}$', '', name)
|
||||||
# Version patterns
|
# Version patterns
|
||||||
name = re.sub(r'-(instruct|chat|preview|experimental|v\d+|fp8|bf16|hf|free|tput)$', '', name)
|
name = re.sub(r'-(instruct|preview|experimental|v\d+|fp8|bf16|hf|free|tput)$', '', name)
|
||||||
# Other replacements
|
# Other replacements
|
||||||
name = name.replace("_", ".")
|
name = name.replace("_", ".")
|
||||||
name = name.replace("c4ai-", "")
|
name = name.replace("c4ai-", "")
|
||||||
|
|
|
||||||
|
|
@ -291,6 +291,7 @@ class AsyncGeneratorProvider(AbstractProvider):
|
||||||
model: str,
|
model: str,
|
||||||
messages: Messages,
|
messages: Messages,
|
||||||
timeout: int = None,
|
timeout: int = None,
|
||||||
|
stream_timeout: int = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> CreateResult:
|
) -> CreateResult:
|
||||||
"""
|
"""
|
||||||
|
|
@ -308,7 +309,7 @@ class AsyncGeneratorProvider(AbstractProvider):
|
||||||
"""
|
"""
|
||||||
return to_sync_generator(
|
return to_sync_generator(
|
||||||
cls.create_async_generator(model, messages, **kwargs),
|
cls.create_async_generator(model, messages, **kwargs),
|
||||||
timeout=timeout
|
timeout=timeout if stream_timeout is None else stream_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -482,7 +483,7 @@ class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin):
|
||||||
auth_result = chunk
|
auth_result = chunk
|
||||||
else:
|
else:
|
||||||
yield chunk
|
yield chunk
|
||||||
for chunk in to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs)):
|
for chunk in to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs), kwargs.get("stream_timeout", kwargs.get("timeout"))):
|
||||||
if cache_file is not None:
|
if cache_file is not None:
|
||||||
cls.write_cache_file(cache_file, auth_result)
|
cls.write_cache_file(cache_file, auth_result)
|
||||||
cache_file = None
|
cache_file = None
|
||||||
|
|
@ -500,8 +501,15 @@ class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin):
|
||||||
try:
|
try:
|
||||||
auth_result = cls.get_auth_result()
|
auth_result = cls.get_auth_result()
|
||||||
response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
|
response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
|
||||||
async for chunk in response:
|
if "stream_timeout" in kwargs:
|
||||||
yield chunk
|
while True:
|
||||||
|
try:
|
||||||
|
yield await asyncio.wait_for(response.__anext__(), timeout=kwargs["stream_timeout"])
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
async for chunk in response:
|
||||||
|
yield chunk
|
||||||
except (MissingAuthError, NoValidHarFileError, CloudflareError):
|
except (MissingAuthError, NoValidHarFileError, CloudflareError):
|
||||||
# if cache_file.exists():
|
# if cache_file.exists():
|
||||||
# cache_file.unlink()
|
# cache_file.unlink()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue