Add get_models to GeminiPro provider

This commit is contained in:
Heiner Lohaus 2024-12-16 01:59:30 +01:00
parent 68c7a92ee2
commit ec9df59828
6 changed files with 44 additions and 24 deletions

View file

@ -2,7 +2,6 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import uuid
from ..typing import AsyncResult, Messages, Cookies from ..typing import AsyncResult, Messages, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop
@ -37,18 +36,16 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
if not cls.models: if not cls.models:
if cls._args is None: if cls._args is None:
get_running_loop(check_nested=True) get_running_loop(check_nested=True)
args = get_args_from_nodriver(cls.url, cookies={ args = get_args_from_nodriver(cls.url)
'__cf_bm': uuid.uuid4().hex,
})
cls._args = asyncio.run(args) cls._args = asyncio.run(args)
with Session(**cls._args) as session: with Session(**cls._args) as session:
response = session.get(cls.models_url) response = session.get(cls.models_url)
cls._args["cookies"] = merge_cookies(cls._args["cookies"] , response) cls._args["cookies"] = merge_cookies(cls._args["cookies"] , response)
try: try:
raise_for_status(response) raise_for_status(response)
except ResponseStatusError as e: except ResponseStatusError:
cls._args = None cls._args = None
raise e raise
json_data = response.json() json_data = response.json()
cls.models = [model.get("name") for model in json_data.get("models")] cls.models = [model.get("name") for model in json_data.get("models")]
return cls.models return cls.models
@ -64,9 +61,9 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
timeout: int = 300, timeout: int = 300,
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
model = cls.get_model(model)
if cls._args is None: if cls._args is None:
cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies) cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
model = cls.get_model(model)
data = { data = {
"messages": messages, "messages": messages,
"lora": None, "lora": None,

View file

@ -40,7 +40,7 @@ class PollinationsAI(OpenaiAPI):
} }
@classmethod @classmethod
def get_models(cls): def get_models(cls, **kwargs):
if not hasattr(cls, 'image_models'): if not hasattr(cls, 'image_models'):
cls.image_models = [] cls.image_models = []
if not cls.image_models: if not cls.image_models:

View file

@ -14,7 +14,7 @@ class DeepInfra(OpenaiAPI):
default_model = "meta-llama/Meta-Llama-3.1-70B-Instruct" default_model = "meta-llama/Meta-Llama-3.1-70B-Instruct"
@classmethod @classmethod
def get_models(cls): def get_models(cls, **kwargs):
if not cls.models: if not cls.models:
url = 'https://api.deepinfra.com/models/featured' url = 'https://api.deepinfra.com/models/featured'
models = requests.get(url).json() models = requests.get(url).json()

View file

@ -2,30 +2,52 @@ from __future__ import annotations
import base64 import base64
import json import json
import requests
from aiohttp import ClientSession, BaseConnector from aiohttp import ClientSession, BaseConnector
from ...typing import AsyncResult, Messages, ImagesType from ...typing import AsyncResult, Messages, ImagesType
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ...image import to_bytes, is_accepted_format from ...image import to_bytes, is_accepted_format
from ...errors import MissingAuthError from ...errors import MissingAuthError
from ...requests.raise_for_status import raise_for_status
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import get_connector from ..helper import get_connector
from ... import debug
class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
label = "Google Gemini API" label = "Google Gemini API"
url = "https://ai.google.dev" url = "https://ai.google.dev"
api_base = "https://generativelanguage.googleapis.com/v1beta"
working = True working = True
supports_message_history = True supports_message_history = True
needs_auth = True needs_auth = True
default_model = "gemini-1.5-pro" default_model = "gemini-1.5-pro"
default_vision_model = default_model default_vision_model = default_model
models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"] fallback_models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
model_aliases = { model_aliases = {
"gemini-flash": "gemini-1.5-flash", "gemini-flash": "gemini-1.5-flash",
"gemini-flash": "gemini-1.5-flash-8b", "gemini-flash": "gemini-1.5-flash-8b",
} }
@classmethod
def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
if not cls.models:
try:
response = requests.get(f"{api_base}/models?key={api_key}")
raise_for_status(response)
data = response.json()
cls.models = [
model.get("name").split("/").pop()
for model in data.get("models")
if "generateContent" in model.get("supportedGenerationMethods")
]
cls.models.sort()
except Exception as e:
debug.log(e)
cls.models = cls.fallback_models
return cls.models
@classmethod @classmethod
async def create_async_generator( async def create_async_generator(
cls, cls,
@ -34,17 +56,17 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
stream: bool = False, stream: bool = False,
proxy: str = None, proxy: str = None,
api_key: str = None, api_key: str = None,
api_base: str = "https://generativelanguage.googleapis.com/v1beta", api_base: str = api_base,
use_auth_header: bool = False, use_auth_header: bool = False,
images: ImagesType = None, images: ImagesType = None,
connector: BaseConnector = None, connector: BaseConnector = None,
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
model = cls.get_model(model)
if not api_key: if not api_key:
raise MissingAuthError('Add a "api_key"') raise MissingAuthError('Add a "api_key"')
model = cls.get_model(model, api_key=api_key, api_base=api_base)
headers = params = None headers = params = None
if use_auth_header: if use_auth_header:
headers = {"Authorization": f"Bearer {api_key}"} headers = {"Authorization": f"Bearer {api_key}"}

View file

@ -23,13 +23,13 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
fallback_models = [] fallback_models = []
@classmethod @classmethod
def get_models(cls, api_key: str = None): def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
if not cls.models: if not cls.models:
try: try:
headers = {} headers = {}
if api_key is not None: if api_key is not None:
headers["authorization"] = f"Bearer {api_key}" headers["authorization"] = f"Bearer {api_key}"
response = requests.get(f"{cls.api_base}/models", headers=headers) response = requests.get(f"{api_base}/models", headers=headers)
raise_for_status(response) raise_for_status(response)
data = response.json() data = response.json()
cls.models = [model.get("id") for model in data.get("data")] cls.models = [model.get("id") for model in data.get("data")]
@ -82,7 +82,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
) as session: ) as session:
data = filter_none( data = filter_none(
messages=messages, messages=messages,
model=cls.get_model(model), model=cls.get_model(model, api_key=api_key, api_base=api_base),
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
top_p=top_p, top_p=top_p,
@ -147,4 +147,4 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
if api_key is not None else {} if api_key is not None else {}
), ),
**({} if headers is None else headers) **({} if headers is None else headers)
} }

View file

@ -243,19 +243,20 @@ class ProviderModelMixin:
last_model: str = None last_model: str = None
@classmethod @classmethod
def get_models(cls) -> list[str]: def get_models(cls, **kwargs) -> list[str]:
if not cls.models and cls.default_model is not None: if not cls.models and cls.default_model is not None:
return [cls.default_model] return [cls.default_model]
return cls.models return cls.models
@classmethod @classmethod
def get_model(cls, model: str) -> str: def get_model(cls, model: str, **kwargs) -> str:
if not model and cls.default_model is not None: if not model and cls.default_model is not None:
model = cls.default_model model = cls.default_model
elif model in cls.model_aliases: elif model in cls.model_aliases:
model = cls.model_aliases[model] model = cls.model_aliases[model]
elif model not in cls.get_models() and cls.models: else:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}") if model not in cls.get_models(**kwargs) and cls.models:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
cls.last_model = model cls.last_model = model
debug.last_model = model debug.last_model = model
return model return model