Fix LMArena provider
commit 213e04bae7 (parent 5bacb669b2)
5 changed files with 103 additions and 1054 deletions
Cloudflare provider:
@@ -36,7 +36,7 @@ def clean_name(name: str) -> str:
 class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
     label = "Cloudflare AI"
     url = "https://playground.ai.cloudflare.com"
-    working = has_curl_cffi
+    working = False
     use_nodriver = True
     active_by_default = True
     api_endpoint = "https://playground.ai.cloudflare.com/api/inference"
LMArena provider:
@@ -1,11 +1,12 @@
 from __future__ import annotations

-import uuid
 import json
 import asyncio
 import os
 import requests
 import json
+import time
+import secrets

 try:
     import curl_cffi
@@ -26,11 +27,26 @@ from ...requests import StreamSession, get_args_from_nodriver, raise_for_status,
 from ...errors import ModelNotFoundError, CloudflareError, MissingAuthError, MissingRequirementsError
 from ...providers.response import FinishReason, Usage, JsonConversation, ImageResponse, Reasoning, PlainTextResponse, JsonRequest
 from ...tools.media import merge_media
-from ...integration import uuid
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin
 from ..helper import get_last_user_message
 from ... import debug

+def uuid7():
+    """
+    Generate a UUIDv7 using Unix epoch (milliseconds since 1970-01-01)
+    matching the browser's implementation.
+    """
+    timestamp_ms = int(time.time() * 1000)
+    rand_a = secrets.randbits(12)
+    rand_b = secrets.randbits(62)
+
+    uuid_int = timestamp_ms << 80
+    uuid_int |= (0x7000 | rand_a) << 64
+    uuid_int |= (0x8000000000000000 | rand_b)
+
+    hex_str = f"{uuid_int:032x}"
+    return f"{hex_str[0:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
+
 models = [
     {'id': '812c93cc-5f88-4cff-b9ca-c11a26599b0e', 'publicName': 'qwen3-max-preview',
      'capabilities': {'inputCapabilities': {'text': True}, 'outputCapabilities': {'text': True}},
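Note: the helper above packs the RFC 9562 UUIDv7 layout by hand (48-bit millisecond timestamp, version nibble from 0x7000, variant bits from 0x8000...). A minimal sanity check of that layout, assuming the uuid7() helper from this hunk is in scope:

    import time
    import uuid as uuid_stdlib

    parsed = uuid_stdlib.UUID(uuid7())
    assert parsed.version == 7                      # (0x7000 | rand_a) sets the version nibble
    assert parsed.variant == uuid_stdlib.RFC_4122   # 0x8000... sets the variant bits
    assert (parsed.int >> 80) <= int(time.time() * 1000)  # top 48 bits are the ms timestamp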
@@ -485,7 +501,8 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
     label = "LMArena"
     url = "https://lmarena.ai"
     share_url = None
-    api_endpoint = "https://lmarena.ai/nextjs-api/stream/create-evaluation"
+    create_evaluation = "https://lmarena.ai/nextjs-api/stream/create-evaluation"
+    post_to_evaluation = "https://lmarena.ai/nextjs-api/stream/post-to-evaluation/{id}"
     working = True
     active_by_default = True
     use_stream_timeout = False
@@ -637,19 +654,21 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
         else:
             raise ModelNotFoundError(f"Model '{model}' is not supported by LMArena provider.")

-        evaluationSessionId = str(uuid.uuid7())
-        userMessageId = str(uuid.uuid7())
-        modelAMessageId = str(uuid.uuid7())
+        if conversation and getattr(conversation, "evaluationSessionId", None):
+            url = cls.post_to_evaluation.format(id=conversation.evaluationSessionId)
+            evaluationSessionId = conversation.evaluationSessionId
+        else:
+            url = cls.create_evaluation
+            evaluationSessionId = str(uuid7())
+        userMessageId = str(uuid7())
+        modelAMessageId = str(uuid7())
         data = {
             "id": evaluationSessionId,
             "mode": "direct",
             "modelAId": model_id,
             "userMessageId": userMessageId,
             "modelAMessageId": modelAMessageId,
-            "messages": [
-                {
-                    "id": userMessageId,
-                    "role": "user",
+            "userMessage": {
                 "content": prompt,
                 "experimental_attachments": [
                     {
@@ -660,33 +679,14 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
                     for url, name in list(merge_media(media, messages))
                     if isinstance(url, str) and url.startswith("https://")
                 ],
-                "parentMessageIds": [] if conversation is None else conversation.message_ids,
-                "participantPosition": "a",
-                "modelId": None,
-                "evaluationSessionId": evaluationSessionId,
-                "status": "pending",
-                "failureReason": None
             },
-            {
-                "id": modelAMessageId,
-                "role": "assistant",
-                "content": "",
-                "experimental_attachments": [],
-                "parentMessageIds": [userMessageId],
-                "participantPosition": "a",
-                "modelId": model,
-                "evaluationSessionId": evaluationSessionId,
-                "status": "pending",
-                "failureReason": None
-            }
-            ],
-            "modality": "image" if is_image_model else "chat"
+            "modality": "image" if is_image_model else "chat",
         }
         yield JsonRequest.from_dict(data)
         try:
             async with StreamSession(**args, timeout=timeout) as session:
                 async with session.post(
-                    cls.api_endpoint,
+                    url,
                     json=data,
                     proxy=proxy
                 ) as response:
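The branch above picks the endpoint per turn: a fresh conversation POSTs to create-evaluation with newly minted uuid7() ids, while a conversation object that already carries an evaluationSessionId is routed to post-to-evaluation/{id}. A hypothetical two-turn sketch (the model name and message shapes are illustrative; the call follows g4f's async generator provider interface):

    import asyncio

    async def two_turns():
        conversation = None
        async for item in LMArena.create_async_generator(
                "gpt-4o", [{"role": "user", "content": "Hi"}]):
            if isinstance(item, JsonConversation):
                conversation = item  # carries evaluationSessionId (yielded on "ad:" lines below)
        # Passing the conversation back routes the second turn to post-to-evaluation/{id}.
        async for item in LMArena.create_async_generator(
                "gpt-4o", [{"role": "user", "content": "And in French?"}],
                conversation=conversation):
            pass

    asyncio.run(two_turns())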
@@ -695,9 +695,7 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
                     async for chunk in response.iter_lines():
                         line = chunk.decode()
                         yield PlainTextResponse(line)
-                        if line.startswith("af:"):
-                            yield JsonConversation(message_ids=[modelAMessageId])
-                        elif line.startswith("a0:"):
+                        if line.startswith("a0:"):
                             chunk = json.loads(line[3:])
                             if chunk == "hasArenaError":
                                 raise ModelNotFoundError("LMArena Beta encountered an error: hasArenaError")
@@ -708,6 +706,7 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
                         elif line.startswith("a2:"):
                             yield ImageResponse([image.get("image") for image in json.loads(line[3:])], prompt)
                         elif line.startswith("ad:"):
+                            yield JsonConversation(evaluationSessionId=evaluationSessionId)
                             finish = json.loads(line[3:])
                             if "finishReason" in finish:
                                 yield FinishReason(finish["finishReason"])
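The handlers above dispatch on the line prefixes of LMArena's streamed response: "a0:" carries JSON-encoded text deltas, "a2:" carries data parts such as generated images, and "ad:" carries finish metadata. A standalone sketch of the same framing, assuming every payload after the prefix is JSON as the handlers imply:

    import json

    def parse_arena_lines(lines):
        """Yield (kind, payload) tuples from "aX:"-prefixed stream lines."""
        for line in lines:
            if line.startswith("a0:"):    # text delta
                yield "text", json.loads(line[3:])
            elif line.startswith("a2:"):  # image parts
                yield "images", [part.get("image") for part in json.loads(line[3:])]
            elif line.startswith("ad:"):  # finish annotation
                yield "finish", json.loads(line[3:]).get("finishReason")

    demo = ['a0:"Hel"', 'a0:"lo"', 'ad:{"finishReason":"stop"}']
    assert list(parse_arena_lines(demo)) == [("text", "Hel"), ("text", "lo"), ("finish", "stop")]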
HAR file helper:
@@ -9,8 +9,10 @@ import uuid
 import random
 from urllib.parse import unquote
 from copy import deepcopy
-from .crypt import decrypt, encrypt
+try:
+    from .crypt import decrypt, encrypt
+except ImportError:
+    pass
 from ...requests import StreamSession
 from ...cookies import get_cookies_dir
 from ...errors import NoValidHarFileError
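Note: wrapping the crypt import in try/except ImportError makes the encrypt/decrypt helpers optional, so this module can still be imported when the sibling crypt module (or one of its dependencies) is unavailable; code paths that reach those names anyway will raise a NameError at that point.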
pydantic_ai integration:
@@ -4,16 +4,18 @@ from typing import Optional
 from functools import partial
 from dataclasses import dataclass, field

-from pydantic_ai.models import Model, KnownModelName, infer_model
-from pydantic_ai.models.openai import OpenAIModel, OpenAISystemPromptRole
+from pydantic_ai import ModelResponsePart, ThinkingPart, ToolCallPart
+from pydantic_ai.models import Model, ModelResponse, KnownModelName, infer_model
+from pydantic_ai.models.openai import OpenAIChatModel, UnexpectedModelBehavior
+from pydantic_ai.models.openai import OpenAISystemPromptRole, _CHAT_FINISH_REASON_MAP, _map_usage, _now_utc, number_to_datetime, split_content_into_text_and_thinking, replace

 import pydantic_ai.models.openai
 pydantic_ai.models.openai.NOT_GIVEN = None

-from ..client import AsyncClient
+from ..client import AsyncClient, ChatCompletion

 @dataclass(init=False)
-class AIModel(OpenAIModel):
+class AIModel(OpenAIChatModel):
     """A model that uses the G4F API."""

     client: AsyncClient = field(repr=False)
@@ -54,6 +56,61 @@ class AIModel(OpenAIChatModel):
             return f'g4f:{self._provider}:{self._model_name}'
         return f'g4f:{self._model_name}'

+    def _process_response(self, response: ChatCompletion | str) -> ModelResponse:
+        """Process a non-streamed response, and prepare a message to return."""
+        # Although the OpenAI SDK claims to return a Pydantic model (`ChatCompletion`) from the chat completions function:
+        # * it hasn't actually performed validation (presumably they're creating the model with `model_construct` or something?!)
+        # * if the endpoint returns plain text, the return type is a string
+        # Thus we validate it fully here.
+        if not isinstance(response, ChatCompletion):
+            raise UnexpectedModelBehavior('Invalid response from OpenAI chat completions endpoint, expected JSON data')
+
+        if response.created:
+            timestamp = number_to_datetime(response.created)
+        else:
+            timestamp = _now_utc()
+            response.created = int(timestamp.timestamp())
+
+        # Workaround for local Ollama which sometimes returns a `None` finish reason.
+        if response.choices and (choice := response.choices[0]) and choice.finish_reason is None:  # pyright: ignore[reportUnnecessaryComparison]
+            choice.finish_reason = 'stop'
+
+        choice = response.choices[0]
+        items: list[ModelResponsePart] = []
+
+        # The `reasoning` field is only present in gpt-oss via Ollama and OpenRouter.
+        # - https://cookbook.openai.com/articles/gpt-oss/handle-raw-cot#chat-completions-api
+        # - https://openrouter.ai/docs/use-cases/reasoning-tokens#basic-usage-with-reasoning-tokens
+        if reasoning := getattr(choice.message, 'reasoning', None):
+            items.append(ThinkingPart(id='reasoning', content=reasoning, provider_name=self.system))
+
+        # NOTE: We don't currently handle OpenRouter `reasoning_details`:
+        # - https://openrouter.ai/docs/use-cases/reasoning-tokens#preserving-reasoning-blocks
+        # If you need this, please file an issue.
+
+        if choice.message.content:
+            items.extend(
+                (replace(part, id='content', provider_name=self.system) if isinstance(part, ThinkingPart) else part)
+                for part in split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags)
+            )
+        if choice.message.tool_calls is not None:
+            for c in choice.message.tool_calls:
+                items.append(ToolCallPart(c.get("function").get("name"), c.get("function").get("arguments"), tool_call_id=c.get("id")))
+
+        raw_finish_reason = choice.finish_reason
+        finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason)
+
+        return ModelResponse(
+            parts=items,
+            usage=_map_usage(response, self._provider, "", self._model_name),
+            model_name=response.model,
+            timestamp=timestamp,
+            provider_details=None,
+            provider_response_id=response.id,
+            provider_name=self._provider,
+            finish_reason=finish_reason,
+        )
+
 def new_infer_model(model: Model | KnownModelName, api_key: str = None) -> Model:
     if isinstance(model, Model):
         return model
@@ -69,4 +126,4 @@ def patch_infer_model(api_key: str | None = None):
     import pydantic_ai.models

     pydantic_ai.models.infer_model = partial(new_infer_model, api_key=api_key)
-    pydantic_ai.models.AIModel = AIModel
+    pydantic_ai.models.OpenAIChatModel = AIModel
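For orientation, the patched class is normally reached through patch_infer_model() and a "g4f:"-prefixed model string, the f'g4f:{provider}:{model}' format produced by the name property shown in the hunk above. A minimal usage sketch, assuming the module lives at g4f.integration.pydantic_ai and with an illustrative provider/model pair:

    from pydantic_ai import Agent
    from g4f.integration.pydantic_ai import patch_infer_model

    patch_infer_model()                  # reroutes pydantic_ai.models.infer_model to AIModel
    agent = Agent("g4f:LMArena:gpt-4o")  # hypothetical "g4f:{provider}:{model}" string
    result = agent.run_sync("Say hello")
    print(result.output)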
File diff suppressed because it is too large.