mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-15 14:51:19 -08:00
fix: improve session handling, message formatting, and content saving
- Added `timeout` parameter support to `LMArenaBeta._create_async_generator` and passed it to `StreamSession` - Ensured fallback to `default_model` in `LMArenaBeta` if `model` is not provided - Modified `OpenaiChat._create_completion` to rebuild messages excluding assistant roles if conversation ID exists - Corrected OpenaiChat `nodriver_auth` to await `browser.get` and replaced page access with reload - Improved `save_content` in `client.py` with robust content extraction, null checks, and logging for missing content - Removed premature `input_text.strip()` in `stream_response` and relocated it to `run_client_args` - Simplified and centralized markdown filtering call in `save_content` - Replaced raw `print` logging in `__init__.py` with `debug.log` for `nodriver` URL opening message
This commit is contained in:
parent
c1c00eee04
commit
5e4b9d9866
4 changed files with 25 additions and 12 deletions
|
|
@ -92,6 +92,7 @@ class LMArenaBeta(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
|
|||
messages: Messages,
|
||||
conversation: JsonConversation = None,
|
||||
proxy: str = None,
|
||||
timeout: int = None,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
cache_file = cls.get_cache_file()
|
||||
|
|
@ -114,6 +115,8 @@ class LMArenaBeta(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
|
|||
|
||||
# Build the JSON payload
|
||||
is_image_model = model in image_models
|
||||
if not model:
|
||||
model = cls.default_model
|
||||
if model in image_models:
|
||||
model = image_models[model]
|
||||
elif model in text_models:
|
||||
|
|
@ -158,7 +161,7 @@ class LMArenaBeta(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
|
|||
],
|
||||
"modality": "image" if is_image_model else "chat"
|
||||
}
|
||||
async with StreamSession(**args) as session:
|
||||
async with StreamSession(**args, timeout=timeout) as session:
|
||||
async with session.post(
|
||||
cls.api_endpoint,
|
||||
json=data,
|
||||
|
|
|
|||
|
|
@ -431,8 +431,14 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
if action != "continue":
|
||||
data["parent_message_id"] = getattr(conversation, "parent_message_id", conversation.message_id)
|
||||
conversation.parent_message_id = None
|
||||
messages = messages if conversation.conversation_id is None else [{"role": "user", "content": prompt}]
|
||||
data["messages"] = cls.create_messages(messages, image_requests, ["search"] if web_search else None)
|
||||
new_messages = messages
|
||||
if conversation.conversation_id is not None:
|
||||
for message in messages:
|
||||
if message.get("role") == "assistant":
|
||||
new_messages = []
|
||||
else:
|
||||
new_messages.append(message)
|
||||
data["messages"] = cls.create_messages(new_messages, image_requests, ["search"] if web_search else None)
|
||||
headers = {
|
||||
**cls._headers,
|
||||
"accept": "text/event-stream",
|
||||
|
|
@ -655,7 +661,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
async def nodriver_auth(cls, proxy: str = None):
|
||||
browser, stop_browser = await get_nodriver(proxy=proxy)
|
||||
try:
|
||||
page = browser.main_tab
|
||||
page = await browser.get(cls.url)
|
||||
def on_request(event: nodriver.cdp.network.RequestWillBeSent, page=None):
|
||||
if event.request.url == start_url or event.request.url.startswith(conversation_url):
|
||||
if cls.request_config.headers is None:
|
||||
|
|
@ -681,7 +687,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
)
|
||||
await page.send(nodriver.cdp.network.enable())
|
||||
page.add_handler(nodriver.cdp.network.RequestWillBeSent, on_request)
|
||||
page = await browser.get(cls.url)
|
||||
await page.reload()
|
||||
user_agent = await page.evaluate("window.navigator.userAgent", return_by_value=True)
|
||||
textarea = None
|
||||
while not textarea:
|
||||
|
|
|
|||
|
|
@ -102,7 +102,6 @@ async def stream_response(
|
|||
image = None
|
||||
if isinstance(input_text, tuple):
|
||||
image, input_text = input_text
|
||||
input_text = input_text.strip()
|
||||
|
||||
if instructions:
|
||||
# Add system instructions to conversation if provided
|
||||
|
|
@ -143,7 +142,7 @@ async def stream_response(
|
|||
if output_file:
|
||||
if save_content(response_content, output_file):
|
||||
print(f"\nResponse saved to {output_file}")
|
||||
|
||||
|
||||
if response_content:
|
||||
# Add assistant message to conversation
|
||||
conversation.add_message("assistant", str(response_content))
|
||||
|
|
@ -152,9 +151,12 @@ async def stream_response(
|
|||
|
||||
def save_content(content, filepath: str, allowed_types = None):
|
||||
if hasattr(content, "urls"):
|
||||
content = content.urls[0] if isinstance(content.urls, list) else content.urls
|
||||
content = next(iter(content.urls), None) if isinstance(content.urls, list) else content.urls
|
||||
elif hasattr(content, "data"):
|
||||
content = content.data
|
||||
if not content:
|
||||
print("\nNo content to save.", file=sys.stderr)
|
||||
return False
|
||||
if content.startswith("/media/"):
|
||||
os.rename(content.replace("/media", get_media_dir()).split("?")[0], filepath)
|
||||
return True
|
||||
|
|
@ -169,11 +171,14 @@ def save_content(content, filepath: str, allowed_types = None):
|
|||
with open(filepath, "wb") as f:
|
||||
f.write(response.content)
|
||||
return True
|
||||
content = filter_markdown(content, allowed_types, content)
|
||||
content = filter_markdown(content, allowed_types)
|
||||
if content:
|
||||
with open(filepath, "w") as f:
|
||||
f.write(content)
|
||||
return True
|
||||
else:
|
||||
print("\nNo valid content to save.", file=sys.stderr)
|
||||
return False
|
||||
|
||||
def get_parser():
|
||||
"""Parse command line arguments."""
|
||||
|
|
@ -278,7 +283,7 @@ def run_client_args(args):
|
|||
input_text = " ".join(args.input[1:]) + "\n"
|
||||
input_text += f"```{os.path.basename(args.input[0])}\n" + file_content + "\n```"
|
||||
elif args.input:
|
||||
input_text = " ".join(args.input)
|
||||
input_text = (" ".join(args.input)).strip()
|
||||
if not input_text:
|
||||
input_text = sys.stdin.read().strip()
|
||||
if not input_text:
|
||||
|
|
|
|||
|
|
@ -100,8 +100,7 @@ async def get_args_from_nodriver(
|
|||
def stop_browser():
|
||||
...
|
||||
try:
|
||||
if debug.logging:
|
||||
print(f"Open nodriver with url: {url}")
|
||||
debug.log(f"Open nodriver with url: {url}")
|
||||
domain = urlparse(url).netloc
|
||||
if cookies is None:
|
||||
cookies = {}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue