Fix LMArena provider

hlohaus 2025-11-10 09:30:53 +01:00
parent 5bacb669b2
commit 213e04bae7
5 changed files with 103 additions and 1054 deletions

View file

@@ -36,7 +36,7 @@ def clean_name(name: str) -> str:
 class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
     label = "Cloudflare AI"
     url = "https://playground.ai.cloudflare.com"
-    working = has_curl_cffi
+    working = False
     use_nodriver = True
     active_by_default = True
     api_endpoint = "https://playground.ai.cloudflare.com/api/inference"

View file

@@ -1,11 +1,12 @@
 from __future__ import annotations

-import uuid
 import json
 import asyncio
 import os
 import requests
 import json
+import time
+import secrets

 try:
     import curl_cffi
@@ -26,11 +27,26 @@ from ...requests import StreamSession, get_args_from_nodriver, raise_for_status,
 from ...errors import ModelNotFoundError, CloudflareError, MissingAuthError, MissingRequirementsError
 from ...providers.response import FinishReason, Usage, JsonConversation, ImageResponse, Reasoning, PlainTextResponse, JsonRequest
 from ...tools.media import merge_media
-from ...integration import uuid
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin
 from ..helper import get_last_user_message
 from ... import debug

+def uuid7():
+    """
+    Generate a UUIDv7 using Unix epoch (milliseconds since 1970-01-01)
+    matching the browser's implementation.
+    """
+    timestamp_ms = int(time.time() * 1000)
+    rand_a = secrets.randbits(12)
+    rand_b = secrets.randbits(62)
+    uuid_int = timestamp_ms << 80
+    uuid_int |= (0x7000 | rand_a) << 64
+    uuid_int |= (0x8000000000000000 | rand_b)
+    hex_str = f"{uuid_int:032x}"
+    return f"{hex_str[0:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
+
 models = [
     {'id': '812c93cc-5f88-4cff-b9ca-c11a26599b0e', 'publicName': 'qwen3-max-preview',
      'capabilities': {'inputCapabilities': {'text': True}, 'outputCapabilities': {'text': True}},
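The new uuid7() helper packs a 48-bit millisecond timestamp, the version nibble 7, and 74 random bits into the canonical UUID text layout, so IDs generated later sort after earlier ones. A quick sanity check of that layout (an illustrative sketch, not part of the commit):

import time

# Illustrative check of the uuid7() layout defined above.
first = uuid7()
assert first[14] == "7"        # 13th hex digit carries the UUID version
assert first[19] in "89ab"     # variant bits must be 0b10xx
time.sleep(0.002)              # force a later millisecond timestamp
second = uuid7()
assert first < second          # timestamp prefix keeps IDs lexicographically sortable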
@@ -485,7 +501,8 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
     label = "LMArena"
     url = "https://lmarena.ai"
     share_url = None
-    api_endpoint = "https://lmarena.ai/nextjs-api/stream/create-evaluation"
+    create_evaluation = "https://lmarena.ai/nextjs-api/stream/create-evaluation"
+    post_to_evaluation = "https://lmarena.ai/nextjs-api/stream/post-to-evaluation/{id}"
     working = True
     active_by_default = True
     use_stream_timeout = False
@@ -637,21 +654,23 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
         else:
             raise ModelNotFoundError(f"Model '{model}' is not supported by LMArena provider.")

-        evaluationSessionId = str(uuid.uuid7())
-        userMessageId = str(uuid.uuid7())
-        modelAMessageId = str(uuid.uuid7())
+        if conversation and getattr(conversation, "evaluationSessionId", None):
+            url = cls.post_to_evaluation.format(id=conversation.evaluationSessionId)
+            evaluationSessionId = conversation.evaluationSessionId
+        else:
+            url = cls.create_evaluation
+            evaluationSessionId = str(uuid7())
+        userMessageId = str(uuid7())
+        modelAMessageId = str(uuid7())
         data = {
             "id": evaluationSessionId,
             "mode": "direct",
             "modelAId": model_id,
             "userMessageId": userMessageId,
             "modelAMessageId": modelAMessageId,
-            "messages": [
-                {
-                    "id": userMessageId,
-                    "role": "user",
-                    "content": prompt,
-                    "experimental_attachments": [
+            "userMessage": {
+                "content": prompt,
+                "experimental_attachments": [
                     {
                         "name": name or os.path.basename(url),
                         "contentType": get_content_type(url),
@@ -660,33 +679,14 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
                     for url, name in list(merge_media(media, messages))
                     if isinstance(url, str) and url.startswith("https://")
                 ],
-                    "parentMessageIds": [] if conversation is None else conversation.message_ids,
-                    "participantPosition": "a",
-                    "modelId": None,
-                    "evaluationSessionId": evaluationSessionId,
-                    "status": "pending",
-                    "failureReason": None
-                },
-                {
-                    "id": modelAMessageId,
-                    "role": "assistant",
-                    "content": "",
-                    "experimental_attachments": [],
-                    "parentMessageIds": [userMessageId],
-                    "participantPosition": "a",
-                    "modelId": model,
-                    "evaluationSessionId": evaluationSessionId,
-                    "status": "pending",
-                    "failureReason": None
-                }
-            ],
-            "modality": "image" if is_image_model else "chat"
+            },
+            "modality": "image" if is_image_model else "chat",
         }
         yield JsonRequest.from_dict(data)
         try:
             async with StreamSession(**args, timeout=timeout) as session:
                 async with session.post(
-                    cls.api_endpoint,
+                    url,
                     json=data,
                     proxy=proxy
                 ) as response:
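After this restructuring the provider posts a single userMessage object per turn instead of the old two-entry messages list, and the POST goes to create_evaluation for the first turn or post_to_evaluation for follow-ups. An illustrative request body for a plain chat turn (fabricated IDs; a sketch rather than a captured request):

# Illustrative shape of the new request body (all IDs are made up).
data = {
    "id": "01932d6a-7c1e-7f3a-8d21-4b5e6f708192",        # evaluationSessionId
    "mode": "direct",
    "modelAId": "812c93cc-5f88-4cff-b9ca-c11a26599b0e",  # model UUID from the models table
    "userMessageId": "01932d6a-7c1f-7a10-9abc-def012345678",
    "modelAMessageId": "01932d6a-7c20-7b99-8123-456789abcdef",
    "userMessage": {
        "content": "Hello!",
        "experimental_attachments": [],                  # only https:// media URLs are attached
    },
    "modality": "chat",                                  # "image" for image models
}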
@@ -695,9 +695,7 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
                     async for chunk in response.iter_lines():
                         line = chunk.decode()
                         yield PlainTextResponse(line)
-                        if line.startswith("af:"):
-                            yield JsonConversation(message_ids=[modelAMessageId])
-                        elif line.startswith("a0:"):
+                        if line.startswith("a0:"):
                             chunk = json.loads(line[3:])
                             if chunk == "hasArenaError":
                                 raise ModelNotFoundError("LMArena Beta encountered an error: hasArenaError")
@@ -708,6 +706,7 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
                         elif line.startswith("a2:"):
                             yield ImageResponse([image.get("image") for image in json.loads(line[3:])], prompt)
                         elif line.startswith("ad:"):
+                            yield JsonConversation(evaluationSessionId=evaluationSessionId)
                             finish = json.loads(line[3:])
                             if "finishReason" in finish:
                                 yield FinishReason(finish["finishReason"])
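The response stream is a sequence of prefix-tagged lines: a0: carries JSON-encoded text deltas, a2: image results, and ad: finish metadata, which now also triggers the JsonConversation yield so later turns can reuse the evaluationSessionId. A minimal standalone parser for this framing (assuming only these prefixes; not part of the commit):

import json

def parse_lmarena_line(line: str):
    """Decode one prefix-tagged stream line into a (kind, payload) pair."""
    prefix, _, rest = line.partition(":")
    if prefix == "a0":      # text delta, a JSON-encoded string
        return "text", json.loads(rest)
    if prefix == "a2":      # list of generated images
        return "images", [image.get("image") for image in json.loads(rest)]
    if prefix == "ad":      # finish metadata, e.g. {"finishReason": "stop"}
        return "finish", json.loads(rest)
    return "other", line    # unknown frame: pass through untouched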

View file

@@ -9,8 +9,10 @@ import uuid
 import random
 from urllib.parse import unquote
 from copy import deepcopy

-from .crypt import decrypt, encrypt
+try:
+    from .crypt import decrypt, encrypt
+except ImportError:
+    pass
 from ...requests import StreamSession
 from ...cookies import get_cookies_dir
 from ...errors import NoValidHarFileError

View file

@@ -4,16 +4,18 @@ from typing import Optional
 from functools import partial
 from dataclasses import dataclass, field

-from pydantic_ai.models import Model, KnownModelName, infer_model
-from pydantic_ai.models.openai import OpenAIModel, OpenAISystemPromptRole
+from pydantic_ai import ModelResponsePart, ThinkingPart, ToolCallPart
+from pydantic_ai.models import Model, ModelResponse, KnownModelName, infer_model
+from pydantic_ai.models.openai import OpenAIChatModel, UnexpectedModelBehavior
+from pydantic_ai.models.openai import OpenAISystemPromptRole, _CHAT_FINISH_REASON_MAP, _map_usage, _now_utc, number_to_datetime, split_content_into_text_and_thinking, replace
 import pydantic_ai.models.openai
 pydantic_ai.models.openai.NOT_GIVEN = None

-from ..client import AsyncClient
+from ..client import AsyncClient, ChatCompletion

 @dataclass(init=False)
-class AIModel(OpenAIModel):
+class AIModel(OpenAIChatModel):
     """A model that uses the G4F API."""

     client: AsyncClient = field(repr=False)
@ -53,6 +55,61 @@ class AIModel(OpenAIModel):
if self._provider: if self._provider:
return f'g4f:{self._provider}:{self._model_name}' return f'g4f:{self._provider}:{self._model_name}'
return f'g4f:{self._model_name}' return f'g4f:{self._model_name}'
def _process_response(self, response: ChatCompletion | str) -> ModelResponse:
"""Process a non-streamed response, and prepare a message to return."""
# Although the OpenAI SDK claims to return a Pydantic model (`ChatCompletion`) from the chat completions function:
# * it hasn't actually performed validation (presumably they're creating the model with `model_construct` or something?!)
# * if the endpoint returns plain text, the return type is a string
# Thus we validate it fully here.
if not isinstance(response, ChatCompletion):
raise UnexpectedModelBehavior('Invalid response from OpenAI chat completions endpoint, expected JSON data')
if response.created:
timestamp = number_to_datetime(response.created)
else:
timestamp = _now_utc()
response.created = int(timestamp.timestamp())
# Workaround for local Ollama which sometimes returns a `None` finish reason.
if response.choices and (choice := response.choices[0]) and choice.finish_reason is None: # pyright: ignore[reportUnnecessaryComparison]
choice.finish_reason = 'stop'
choice = response.choices[0]
items: list[ModelResponsePart] = []
# The `reasoning` field is only present in gpt-oss via Ollama and OpenRouter.
# - https://cookbook.openai.com/articles/gpt-oss/handle-raw-cot#chat-completions-api
# - https://openrouter.ai/docs/use-cases/reasoning-tokens#basic-usage-with-reasoning-tokens
if reasoning := getattr(choice.message, 'reasoning', None):
items.append(ThinkingPart(id='reasoning', content=reasoning, provider_name=self.system))
# NOTE: We don't currently handle OpenRouter `reasoning_details`:
# - https://openrouter.ai/docs/use-cases/reasoning-tokens#preserving-reasoning-blocks
# If you need this, please file an issue.
if choice.message.content:
items.extend(
(replace(part, id='content', provider_name=self.system) if isinstance(part, ThinkingPart) else part)
for part in split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags)
)
if choice.message.tool_calls is not None:
for c in choice.message.tool_calls:
items.append(ToolCallPart(c.get("function").get("name"), c.get("function").get("arguments"), tool_call_id=c.get("id")))
raw_finish_reason = choice.finish_reason
finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason)
return ModelResponse(
parts=items,
usage=_map_usage(response, self._provider, "", self._model_name),
model_name=response.model,
timestamp=timestamp,
provider_details=None,
provider_response_id=response.id,
provider_name=self._provider,
finish_reason=finish_reason,
)
def new_infer_model(model: Model | KnownModelName, api_key: str = None) -> Model: def new_infer_model(model: Model | KnownModelName, api_key: str = None) -> Model:
if isinstance(model, Model): if isinstance(model, Model):
@@ -69,4 +126,4 @@ def patch_infer_model(api_key: str | None = None):
     import pydantic_ai.models
     pydantic_ai.models.infer_model = partial(new_infer_model, api_key=api_key)
-    pydantic_ai.models.AIModel = AIModel
+    pydantic_ai.models.OpenAIChatModel = AIModel
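With OpenAIChatModel patched, pydantic_ai agents resolve g4f:* model names through AIModel. A usage sketch (assuming the module lives at g4f.integration.pydantic_ai and using an illustrative model name; not part of the commit):

from pydantic_ai import Agent

from g4f.integration.pydantic_ai import patch_infer_model  # assumed module path

patch_infer_model()                 # route infer_model() through the G4F-backed AIModel
agent = Agent("g4f:gpt-4o-mini")    # illustrative g4f model name
result = agent.run_sync("Reply with one word.")
print(result.output)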

File diff suppressed because it is too large