Refactor model handling and improve timeout functionality

- Removed empty string mapping from model_map in AnyModelProviderMixin.
- Updated clean_name function to exclude 'chat' from version patterns.
- Added stream_timeout parameter to AsyncGeneratorProvider for more flexible timeout handling.
- Enhanced chunk yielding in AsyncAuthedProvider to support stream_timeout, allowing for better control over asynchronous responses.
This commit is contained in:
hlohaus 2025-09-04 18:11:05 +02:00
parent e09c08969a
commit 1edd0fff17
11 changed files with 682 additions and 586 deletions

View file

@@ -3,6 +3,8 @@
G4F_API_KEY= G4F_API_KEY=
G4F_PROXY= G4F_PROXY=
G4F_TIMEOUT=
G4F_STREAM_TIMEOUT=
HUGGINGFACE_API_KEY= HUGGINGFACE_API_KEY=
POLLINATIONS_API_KEY= POLLINATIONS_API_KEY=

View file

@@ -23,7 +23,7 @@ class GLM(AsyncGeneratorProvider, ProviderModelMixin):
cls.api_key = response.json().get("token") cls.api_key = response.json().get("token")
response = requests.get(f"{cls.url}/api/models", headers={"Authorization": f"Bearer {cls.api_key}"}) response = requests.get(f"{cls.url}/api/models", headers={"Authorization": f"Bearer {cls.api_key}"})
data = response.json().get("data", []) data = response.json().get("data", [])
cls.model_aliases = {data.get("name"): data.get("id") for data in data} cls.model_aliases = {data.get("name", "").replace("\u4efb\u52a1\u4e13\u7528", "ChatGLM"): data.get("id") for data in data}
cls.models = list(cls.model_aliases.keys()) cls.models = list(cls.model_aliases.keys())
return cls.models return cls.models

View file

@@ -17,7 +17,7 @@ class DuckDuckGo(AbstractProvider, ProviderModelMixin):
url = "https://duckduckgo.com/aichat" url = "https://duckduckgo.com/aichat"
api_base = "https://duckduckgo.com/duckchat/v1/" api_base = "https://duckduckgo.com/duckchat/v1/"
working = has_requirements working = False
supports_stream = True supports_stream = True
supports_system_message = True supports_system_message = True
supports_message_history = True supports_message_history = True

View file

@@ -92,6 +92,7 @@ logger = logging.getLogger(__name__)
DEFAULT_PORT = 1337 DEFAULT_PORT = 1337
DEFAULT_TIMEOUT = 600 DEFAULT_TIMEOUT = 600
DEFAULT_STREAM_TIMEOUT = 15
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
@@ -99,6 +100,8 @@ async def lifespan(app: FastAPI):
if not AppConfig.ignore_cookie_files: if not AppConfig.ignore_cookie_files:
read_cookie_files() read_cookie_files()
AppConfig.g4f_api_key = os.environ.get("G4F_API_KEY", AppConfig.g4f_api_key) AppConfig.g4f_api_key = os.environ.get("G4F_API_KEY", AppConfig.g4f_api_key)
AppConfig.timeout = os.environ.get("G4F_TIMEOUT", AppConfig.timeout)
AppConfig.stream_timeout = os.environ.get("G4F_STREAM_TIMEOUT", AppConfig.stream_timeout)
yield yield
if has_nodriver: if has_nodriver:
for browser in util.get_registered_instances(): for browser in util.get_registered_instances():
@@ -133,7 +136,7 @@ def create_app():
if AppConfig.gui: if AppConfig.gui:
if not has_a2wsgi: if not has_a2wsgi:
raise MissingRequirementsError("a2wsgi is required for GUI. Install it with: pip install a2wsgi") raise MissingRequirementsError("a2wsgi is required for GUI. Install it with: pip install a2wsgi")
gui_app = WSGIMiddleware(get_gui_app(AppConfig.demo, AppConfig.timeout)) gui_app = WSGIMiddleware(get_gui_app(AppConfig.demo, AppConfig.timeout, AppConfig.stream_timeout))
app.mount("/", gui_app) app.mount("/", gui_app)
if AppConfig.ignored_providers: if AppConfig.ignored_providers:
@@ -185,6 +188,7 @@ class AppConfig:
gui: bool = False gui: bool = False
demo: bool = False demo: bool = False
timeout: int = DEFAULT_TIMEOUT timeout: int = DEFAULT_TIMEOUT
stream_timeout: int = DEFAULT_STREAM_TIMEOUT
@classmethod @classmethod
def set_config(cls, **data): def set_config(cls, **data):
@@ -418,6 +422,8 @@ class Api:
config.conversation_id = conversation_id config.conversation_id = conversation_id
if config.timeout is None: if config.timeout is None:
config.timeout = AppConfig.timeout config.timeout = AppConfig.timeout
if config.stream_timeout is None:
config.stream_timeout = AppConfig.stream_timeout
if credentials is not None and credentials.credentials != "secret": if credentials is not None and credentials.credentials != "secret":
config.api_key = credentials.credentials config.api_key = credentials.credentials
@@ -451,7 +457,7 @@ class Api:
"model": AppConfig.model, "model": AppConfig.model,
"provider": AppConfig.provider, "provider": AppConfig.provider,
"proxy": AppConfig.proxy, "proxy": AppConfig.proxy,
**config.dict(exclude_none=True), **(config.model_dump(exclude_none=True) if hasattr(config, "model_dump") else config.dict(exclude_none=True)),
**{ **{
"conversation_id": None, "conversation_id": None,
"conversation": conversation, "conversation": conversation,
@@ -474,7 +480,7 @@ class Api:
self.conversations[config.conversation_id] = {} self.conversations[config.conversation_id] = {}
self.conversations[config.conversation_id][config.provider] = chunk self.conversations[config.conversation_id][config.provider] = chunk
else: else:
yield f"data: {chunk.json()}\n\n" yield f"data: {chunk.model_dump_json() if hasattr(chunk, 'model_dump_json') else chunk.json()}\n\n"
except GeneratorExit: except GeneratorExit:
pass pass
except Exception as e: except Exception as e:

View file

@@ -22,6 +22,7 @@ class RequestConfig(BaseModel):
proxy: Optional[str] = None proxy: Optional[str] = None
conversation: Optional[dict] = None conversation: Optional[dict] = None
timeout: Optional[int] = None timeout: Optional[int] = None
stream_timeout: Optional[int] = None
tool_calls: list = Field(default=[], examples=[[ tool_calls: list = Field(default=[], examples=[[
{ {
"function": { "function": {

View file

@@ -8,12 +8,13 @@ try:
except ImportError as e: except ImportError as e:
import_error = e import_error = e
def get_gui_app(demo: bool = False, timeout: int = None): def get_gui_app(demo: bool = False, timeout: int = None, stream_timeout: int = None):
if import_error is not None: if import_error is not None:
raise MissingRequirementsError(f'Install "gui" requirements | pip install -U g4f[gui]\n{import_error}') raise MissingRequirementsError(f'Install "gui" requirements | pip install -U g4f[gui]\n{import_error}')
app = create_app() app = create_app()
app.demo = demo app.demo = demo
app.timeout = timeout app.timeout = timeout
app.stream_timeout = stream_timeout
site = Website(app) site = Website(app)
for route in site.routes: for route in site.routes:

View file

@@ -196,6 +196,8 @@ class Backend_Api(Api):
json_data['media'] = media json_data['media'] = media
if app.timeout: if app.timeout:
json_data['timeout'] = app.timeout json_data['timeout'] = app.timeout
if app.stream_timeout:
json_data['stream_timeout'] = app.stream_timeout
if app.demo and not json_data.get("provider"): if app.demo and not json_data.get("provider"):
model = json_data.get("model") model = json_data.get("model")
if model != "default" and model in models.demo_models: if model != "default" and model in models.demo_models:

View file

@@ -14,6 +14,7 @@ from .Provider import (
HuggingSpace, HuggingSpace,
Grok, Grok,
DeepseekAI_JanusPro7b, DeepseekAI_JanusPro7b,
GLM,
Kimi, Kimi,
LambdaChat, LambdaChat,
Mintlify, Mintlify,
@@ -25,6 +26,7 @@ from .Provider import (
PerplexityLabs, PerplexityLabs,
PollinationsAI, PollinationsAI,
PollinationsImage, PollinationsImage,
Qwen,
TeachAnything, TeachAnything,
Together, Together,
WeWordle, WeWordle,
@@ -159,8 +161,10 @@ default = Model(
DeepInfra, DeepInfra,
OperaAria, OperaAria,
Startnest, Startnest,
LambdaChat, GLM,
Kimi,
PollinationsAI, PollinationsAI,
Qwen,
Together, Together,
Chatai, Chatai,
WeWordle, WeWordle,

File diff suppressed because one or more lines are too long

View file

@@ -220,6 +220,7 @@ class AnyModelProviderMixin(ProviderModelMixin):
cls.video_models.append("video") cls.video_models.append("video")
cls.model_map["video"] = {"Video": "video"} cls.model_map["video"] = {"Video": "video"}
del cls.model_map[""]
cls.audio_models = [*cls.audio_models] cls.audio_models = [*cls.audio_models]
# Create a mapping of parent providers to their children # Create a mapping of parent providers to their children
@@ -415,7 +416,7 @@ def clean_name(name: str) -> str:
name = re.sub(r'-\d{2}-\d{2}', '', name) name = re.sub(r'-\d{2}-\d{2}', '', name)
name = re.sub(r'-[0-9a-f]{8}$', '', name) name = re.sub(r'-[0-9a-f]{8}$', '', name)
# Version patterns # Version patterns
name = re.sub(r'-(instruct|chat|preview|experimental|v\d+|fp8|bf16|hf|free|tput)$', '', name) name = re.sub(r'-(instruct|preview|experimental|v\d+|fp8|bf16|hf|free|tput)$', '', name)
# Other replacements # Other replacements
name = name.replace("_", ".") name = name.replace("_", ".")
name = name.replace("c4ai-", "") name = name.replace("c4ai-", "")

View file

@@ -291,6 +291,7 @@ class AsyncGeneratorProvider(AbstractProvider):
model: str, model: str,
messages: Messages, messages: Messages,
timeout: int = None, timeout: int = None,
stream_timeout: int = None,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
""" """
@@ -308,7 +309,7 @@
""" """
return to_sync_generator( return to_sync_generator(
cls.create_async_generator(model, messages, **kwargs), cls.create_async_generator(model, messages, **kwargs),
timeout=timeout timeout=timeout if stream_timeout is None else stream_timeout,
) )
@staticmethod @staticmethod
@@ -482,7 +483,7 @@ class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin):
auth_result = chunk auth_result = chunk
else: else:
yield chunk yield chunk
for chunk in to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs)): for chunk in to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs), kwargs.get("stream_timeout", kwargs.get("timeout"))):
if cache_file is not None: if cache_file is not None:
cls.write_cache_file(cache_file, auth_result) cls.write_cache_file(cache_file, auth_result)
cache_file = None cache_file = None
@@ -500,6 +501,13 @@ class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin):
try: try:
auth_result = cls.get_auth_result() auth_result = cls.get_auth_result()
response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result)) response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
if "stream_timeout" in kwargs:
while True:
try:
yield await asyncio.wait_for(response.__anext__(), timeout=kwargs["stream_timeout"])
except StopAsyncIteration:
break
else:
async for chunk in response: async for chunk in response:
yield chunk yield chunk
except (MissingAuthError, NoValidHarFileError, CloudflareError): except (MissingAuthError, NoValidHarFileError, CloudflareError):