fix: improve session handling, message formatting, and content saving

- Added `timeout` parameter support to `LMArenaBeta._create_async_generator` and passed it to `StreamSession`
- Ensured fallback to `default_model` in `LMArenaBeta` if `model` is not provided
- Modified `OpenaiChat._create_completion` to rebuild messages excluding assistant roles if conversation ID exists
- Corrected OpenaiChat `nodriver_auth` to await `browser.get` and replaced page access with reload
- Improved `save_content` in `client.py` with robust content extraction, null checks, and logging for missing content
- Removed premature `input_text.strip()` in `stream_response` and relocated it to `run_client_args`
- Simplified and centralized markdown filtering call in `save_content`
- Replaced raw `print` logging in `__init__.py` with `debug.log` for `nodriver` URL opening message
This commit is contained in:
hlohaus 2025-06-29 21:59:22 +02:00
parent c1c00eee04
commit 5e4b9d9866
4 changed files with 25 additions and 12 deletions

View file

@ -92,6 +92,7 @@ class LMArenaBeta(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
messages: Messages,
conversation: JsonConversation = None,
proxy: str = None,
timeout: int = None,
**kwargs
) -> AsyncResult:
cache_file = cls.get_cache_file()
@ -114,6 +115,8 @@ class LMArenaBeta(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
# Build the JSON payload
is_image_model = model in image_models
if not model:
model = cls.default_model
if model in image_models:
model = image_models[model]
elif model in text_models:
@ -158,7 +161,7 @@ class LMArenaBeta(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
],
"modality": "image" if is_image_model else "chat"
}
async with StreamSession(**args) as session:
async with StreamSession(**args, timeout=timeout) as session:
async with session.post(
cls.api_endpoint,
json=data,

View file

@ -431,8 +431,14 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
if action != "continue":
data["parent_message_id"] = getattr(conversation, "parent_message_id", conversation.message_id)
conversation.parent_message_id = None
messages = messages if conversation.conversation_id is None else [{"role": "user", "content": prompt}]
data["messages"] = cls.create_messages(messages, image_requests, ["search"] if web_search else None)
new_messages = messages
if conversation.conversation_id is not None:
for message in messages:
if message.get("role") == "assistant":
new_messages = []
else:
new_messages.append(message)
data["messages"] = cls.create_messages(new_messages, image_requests, ["search"] if web_search else None)
headers = {
**cls._headers,
"accept": "text/event-stream",
@ -655,7 +661,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
async def nodriver_auth(cls, proxy: str = None):
browser, stop_browser = await get_nodriver(proxy=proxy)
try:
page = browser.main_tab
page = await browser.get(cls.url)
def on_request(event: nodriver.cdp.network.RequestWillBeSent, page=None):
if event.request.url == start_url or event.request.url.startswith(conversation_url):
if cls.request_config.headers is None:
@ -681,7 +687,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
)
await page.send(nodriver.cdp.network.enable())
page.add_handler(nodriver.cdp.network.RequestWillBeSent, on_request)
page = await browser.get(cls.url)
await page.reload()
user_agent = await page.evaluate("window.navigator.userAgent", return_by_value=True)
textarea = None
while not textarea:

View file

@ -102,7 +102,6 @@ async def stream_response(
image = None
if isinstance(input_text, tuple):
image, input_text = input_text
input_text = input_text.strip()
if instructions:
# Add system instructions to conversation if provided
@ -143,7 +142,7 @@ async def stream_response(
if output_file:
if save_content(response_content, output_file):
print(f"\nResponse saved to {output_file}")
if response_content:
# Add assistant message to conversation
conversation.add_message("assistant", str(response_content))
@ -152,9 +151,12 @@ async def stream_response(
def save_content(content, filepath: str, allowed_types = None):
if hasattr(content, "urls"):
content = content.urls[0] if isinstance(content.urls, list) else content.urls
content = next(iter(content.urls), None) if isinstance(content.urls, list) else content.urls
elif hasattr(content, "data"):
content = content.data
if not content:
print("\nNo content to save.", file=sys.stderr)
return False
if content.startswith("/media/"):
os.rename(content.replace("/media", get_media_dir()).split("?")[0], filepath)
return True
@ -169,11 +171,14 @@ def save_content(content, filepath: str, allowed_types = None):
with open(filepath, "wb") as f:
f.write(response.content)
return True
content = filter_markdown(content, allowed_types, content)
content = filter_markdown(content, allowed_types)
if content:
with open(filepath, "w") as f:
f.write(content)
return True
else:
print("\nNo valid content to save.", file=sys.stderr)
return False
def get_parser():
"""Parse command line arguments."""
@ -278,7 +283,7 @@ def run_client_args(args):
input_text = " ".join(args.input[1:]) + "\n"
input_text += f"```{os.path.basename(args.input[0])}\n" + file_content + "\n```"
elif args.input:
input_text = " ".join(args.input)
input_text = (" ".join(args.input)).strip()
if not input_text:
input_text = sys.stdin.read().strip()
if not input_text:

View file

@ -100,8 +100,7 @@ async def get_args_from_nodriver(
def stop_browser():
...
try:
if debug.logging:
print(f"Open nodriver with url: {url}")
debug.log(f"Open nodriver with url: {url}")
domain = urlparse(url).netloc
if cookies is None:
cookies = {}