Refactor PollinationsAI model selection logic; update default system in AIModel; clean up unused provider imports

This commit is contained in:
hlohaus 2025-11-10 13:58:07 +01:00
parent 213e04bae7
commit 1dac52a191
4 changed files with 16 additions and 45 deletions

View file

@ -206,7 +206,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
if is_data_an_audio(media_data, filename): if is_data_an_audio(media_data, filename):
has_audio = True has_audio = True
break break
model = cls.default_audio_model if has_audio else model model = cls.default_audio_model if has_audio else cls.default_model
elif cls._models_loaded or cls.get_models(): elif cls._models_loaded or cls.get_models():
if model in cls.model_aliases: if model in cls.model_aliases:
model = cls.model_aliases[model] model = cls.model_aliases[model]

View file

@ -6,8 +6,9 @@ from dataclasses import dataclass, field
from pydantic_ai import ModelResponsePart, ThinkingPart, ToolCallPart from pydantic_ai import ModelResponsePart, ThinkingPart, ToolCallPart
from pydantic_ai.models import Model, ModelResponse, KnownModelName, infer_model from pydantic_ai.models import Model, ModelResponse, KnownModelName, infer_model
from pydantic_ai.models.openai import OpenAIChatModel, UnexpectedModelBehavior from pydantic_ai.usage import RequestUsage
from pydantic_ai.models.openai import OpenAISystemPromptRole, _CHAT_FINISH_REASON_MAP, _map_usage, _now_utc, number_to_datetime, split_content_into_text_and_thinking, replace from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.models.openai import OpenAISystemPromptRole, _now_utc, split_content_into_text_and_thinking, replace
import pydantic_ai.models.openai import pydantic_ai.models.openai
pydantic_ai.models.openai.NOT_GIVEN = None pydantic_ai.models.openai.NOT_GIVEN = None
@ -31,7 +32,7 @@ class AIModel(OpenAIChatModel):
provider: str | None = None, provider: str | None = None,
*, *,
system_prompt_role: OpenAISystemPromptRole | None = None, system_prompt_role: OpenAISystemPromptRole | None = None,
system: str | None = 'openai', system: str | None = 'g4f',
**kwargs **kwargs
): ):
"""Initialize an AI model. """Initialize an AI model.
@ -46,7 +47,7 @@ class AIModel(OpenAIChatModel):
customize the `base_url` and `api_key` to use a different provider. customize the `base_url` and `api_key` to use a different provider.
""" """
self._model_name = model_name self._model_name = model_name
self._provider = provider self._provider = getattr(provider, '__name__', provider)
self.client = AsyncClient(provider=provider, **kwargs) self.client = AsyncClient(provider=provider, **kwargs)
self.system_prompt_role = system_prompt_role self.system_prompt_role = system_prompt_role
self._system = system self._system = system
@ -58,36 +59,12 @@ class AIModel(OpenAIChatModel):
def _process_response(self, response: ChatCompletion | str) -> ModelResponse: def _process_response(self, response: ChatCompletion | str) -> ModelResponse:
"""Process a non-streamed response, and prepare a message to return.""" """Process a non-streamed response, and prepare a message to return."""
# Although the OpenAI SDK claims to return a Pydantic model (`ChatCompletion`) from the chat completions function:
# * it hasn't actually performed validation (presumably they're creating the model with `model_construct` or something?!)
# * if the endpoint returns plain text, the return type is a string
# Thus we validate it fully here.
if not isinstance(response, ChatCompletion):
raise UnexpectedModelBehavior('Invalid response from OpenAI chat completions endpoint, expected JSON data')
if response.created:
timestamp = number_to_datetime(response.created)
else:
timestamp = _now_utc()
response.created = int(timestamp.timestamp())
# Workaround for local Ollama which sometimes returns a `None` finish reason.
if response.choices and (choice := response.choices[0]) and choice.finish_reason is None: # pyright: ignore[reportUnnecessaryComparison]
choice.finish_reason = 'stop'
choice = response.choices[0] choice = response.choices[0]
items: list[ModelResponsePart] = [] items: list[ModelResponsePart] = []
# The `reasoning` field is only present in gpt-oss via Ollama and OpenRouter.
# - https://cookbook.openai.com/articles/gpt-oss/handle-raw-cot#chat-completions-api
# - https://openrouter.ai/docs/use-cases/reasoning-tokens#basic-usage-with-reasoning-tokens
if reasoning := getattr(choice.message, 'reasoning', None): if reasoning := getattr(choice.message, 'reasoning', None):
items.append(ThinkingPart(id='reasoning', content=reasoning, provider_name=self.system)) items.append(ThinkingPart(id='reasoning', content=reasoning, provider_name=self.system))
# NOTE: We don't currently handle OpenRouter `reasoning_details`:
# - https://openrouter.ai/docs/use-cases/reasoning-tokens#preserving-reasoning-blocks
# If you need this, please file an issue.
if choice.message.content: if choice.message.content:
items.extend( items.extend(
(replace(part, id='content', provider_name=self.system) if isinstance(part, ThinkingPart) else part) (replace(part, id='content', provider_name=self.system) if isinstance(part, ThinkingPart) else part)
@ -95,20 +72,21 @@ class AIModel(OpenAIChatModel):
) )
if choice.message.tool_calls is not None: if choice.message.tool_calls is not None:
for c in choice.message.tool_calls: for c in choice.message.tool_calls:
items.append(ToolCallPart(c.get("function").get("name"), c.get("function").get("arguments"), tool_call_id=c.get("id"))) items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
usage = RequestUsage(
raw_finish_reason = choice.finish_reason input_tokens=response.usage.prompt_tokens,
finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason) output_tokens=response.usage.completion_tokens,
)
return ModelResponse( return ModelResponse(
parts=items, parts=items,
usage=_map_usage(response, self._provider, "", self._model_name), usage=usage,
model_name=response.model, model_name=response.model,
timestamp=timestamp, timestamp=_now_utc(),
provider_details=None, provider_details=None,
provider_response_id=response.id, provider_response_id=response.id,
provider_name=self._provider, provider_name=self._provider,
finish_reason=finish_reason, finish_reason=choice.finish_reason,
) )
def new_infer_model(model: Model | KnownModelName, api_key: str = None) -> Model: def new_infer_model(model: Model | KnownModelName, api_key: str = None) -> Model:
@ -125,5 +103,4 @@ def new_infer_model(model: Model | KnownModelName, api_key: str = None) -> Model
def patch_infer_model(api_key: str | None = None): def patch_infer_model(api_key: str | None = None):
import pydantic_ai.models import pydantic_ai.models
pydantic_ai.models.infer_model = partial(new_infer_model, api_key=api_key) pydantic_ai.models.infer_model = partial(new_infer_model, api_key=api_key)
pydantic_ai.models.OpenAIChatModel = AIModel

View file

@ -6,7 +6,6 @@ from typing import Dict, List, Optional
from .Provider import IterListProvider, ProviderType from .Provider import IterListProvider, ProviderType
from .Provider import ( from .Provider import (
### No Auth Required ### ### No Auth Required ###
Blackbox,
Chatai, Chatai,
Cloudflare, Cloudflare,
Copilot, Copilot,
@ -17,7 +16,6 @@ from .Provider import (
GLM, GLM,
Kimi, Kimi,
LambdaChat, LambdaChat,
Mintlify,
OIVSCodeSer2, OIVSCodeSer2,
OIVSCodeSer0501, OIVSCodeSer0501,
OperaAria, OperaAria,
@ -27,7 +25,6 @@ from .Provider import (
PollinationsAI, PollinationsAI,
PollinationsImage, PollinationsImage,
Qwen, Qwen,
StringableInference,
TeachAnything, TeachAnything,
Together, Together,
WeWordle, WeWordle,
@ -155,7 +152,6 @@ default = Model(
name = "", name = "",
base_provider = "", base_provider = "",
best_provider = IterListProvider([ best_provider = IterListProvider([
StringableInference,
OIVSCodeSer0501, OIVSCodeSer0501,
OIVSCodeSer2, OIVSCodeSer2,
Copilot, Copilot,
@ -168,7 +164,6 @@ default = Model(
Together, Together,
Chatai, Chatai,
WeWordle, WeWordle,
Mintlify,
TeachAnything, TeachAnything,
OpenaiChat, OpenaiChat,
Cloudflare, Cloudflare,
@ -179,7 +174,6 @@ default_vision = VisionModel(
name = "", name = "",
base_provider = "", base_provider = "",
best_provider = IterListProvider([ best_provider = IterListProvider([
StringableInference,
DeepInfra, DeepInfra,
OIVSCodeSer0501, OIVSCodeSer0501,
OIVSCodeSer2, OIVSCodeSer2,

View file

@ -348,7 +348,7 @@ class AnyProvider(AsyncGeneratorProvider, AnyModelProviderMixin):
has_audio = True has_audio = True
break break
has_image = True has_image = True
if "tools" in kwargs: if kwargs.get("tools", None):
providers = [PollinationsAI] providers = [PollinationsAI]
elif "audio" in kwargs or "audio" in kwargs.get("modalities", []): elif "audio" in kwargs or "audio" in kwargs.get("modalities", []):
if kwargs.get("audio", {}).get("language") is None: if kwargs.get("audio", {}).get("language") is None: