Mirror of https://github.com/xtekky/gpt4free.git (synced 2025-12-06 02:30:41 -08:00)

Add get_models to GeminiPro provider

commit ec9df59828 (parent 68c7a92ee2)
6 changed files with 44 additions and 24 deletions
Cloudflare.py

@@ -2,7 +2,6 @@ from __future__ import annotations
 
 import asyncio
 import json
-import uuid
 
 from ..typing import AsyncResult, Messages, Cookies
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop
@@ -37,18 +36,16 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
         if not cls.models:
             if cls._args is None:
                 get_running_loop(check_nested=True)
-                args = get_args_from_nodriver(cls.url, cookies={
-                    '__cf_bm': uuid.uuid4().hex,
-                })
+                args = get_args_from_nodriver(cls.url)
                 cls._args = asyncio.run(args)
             with Session(**cls._args) as session:
                 response = session.get(cls.models_url)
                 cls._args["cookies"] = merge_cookies(cls._args["cookies"], response)
                 try:
                     raise_for_status(response)
-                except ResponseStatusError as e:
+                except ResponseStatusError:
                     cls._args = None
-                    raise e
+                    raise
                 json_data = response.json()
                 cls.models = [model.get("name") for model in json_data.get("models")]
         return cls.models
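Note: besides dropping the spoofed __cf_bm cookie (and with it the uuid import), this hunk tightens the error path. Catching ResponseStatusError without binding it and ending with a bare raise re-raises the original exception with its traceback intact, after the cached session arguments have been invalidated. A minimal sketch of that pattern; CachedClient and fetch are illustrative names, not g4f code:

    class CachedClient:
        _args = None  # cached session arguments

        @classmethod
        def fetch(cls, ok: bool) -> str:
            cls._args = cls._args or {"cookies": {}}
            try:
                if not ok:
                    raise ValueError("response status was not OK")
            except ValueError:
                cls._args = None  # force the next call to rebuild the session args
                raise             # bare raise keeps the original traceback
            return "models"

    print(CachedClient.fetch(True))  # -> models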
@@ -64,9 +61,9 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
         timeout: int = 300,
         **kwargs
     ) -> AsyncResult:
-        model = cls.get_model(model)
         if cls._args is None:
             cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
+        model = cls.get_model(model)
         data = {
             "messages": messages,
             "lora": None,
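Note: moving model = cls.get_model(model) below the cls._args bootstrap is likely deliberate. With this commit, get_model() can fall through to get_models(), whose synchronous path calls asyncio.run() when cls._args is still None, and asyncio.run() cannot be nested inside a running event loop. A runnable demonstration of that failure mode (my reading of the change, not repository code):

    import asyncio

    async def main():
        try:
            asyncio.run(asyncio.sleep(0))  # what the sync fetch path would attempt
        except RuntimeError as e:
            print(e)  # "asyncio.run() cannot be called from a running event loop"

    asyncio.run(main())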
PollinationsAI.py

@@ -40,7 +40,7 @@ class PollinationsAI(OpenaiAPI):
     }
 
     @classmethod
-    def get_models(cls):
+    def get_models(cls, **kwargs):
         if not hasattr(cls, 'image_models'):
             cls.image_models = []
         if not cls.image_models:
DeepInfra.py

@@ -14,7 +14,7 @@ class DeepInfra(OpenaiAPI):
     default_model = "meta-llama/Meta-Llama-3.1-70B-Instruct"
 
     @classmethod
-    def get_models(cls):
+    def get_models(cls, **kwargs):
         if not cls.models:
             url = 'https://api.deepinfra.com/models/featured'
             models = requests.get(url).json()
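Note: this hunk and the PollinationsAI one above are the same mechanical fix: every get_models override gains **kwargs so the base class can forward keyword arguments such as api_key or api_base without breaking providers that ignore them. A toy sketch of the pattern, with invented class names:

    class Base:
        models: list = []

        @classmethod
        def get_models(cls, **kwargs) -> list:
            return cls.models

    class KeyedProvider(Base):
        @classmethod
        def get_models(cls, api_key: str = None, **kwargs) -> list:
            return ["model-a"] if api_key else ["fallback"]

    class PlainProvider(Base):
        @classmethod
        def get_models(cls, **kwargs) -> list:  # extra kwargs are simply ignored
            return ["model-b"]

    print(KeyedProvider.get_models(api_key="k"))  # -> ['model-a']
    print(PlainProvider.get_models(api_key="k"))  # -> ['model-b']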
GeminiPro.py

@@ -2,30 +2,52 @@ from __future__ import annotations
 
 import base64
 import json
+import requests
 from aiohttp import ClientSession, BaseConnector
 
 from ...typing import AsyncResult, Messages, ImagesType
-from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ...image import to_bytes, is_accepted_format
 from ...errors import MissingAuthError
+from ...requests.raise_for_status import raise_for_status
+from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..helper import get_connector
+from ... import debug
 
 class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
     label = "Google Gemini API"
     url = "https://ai.google.dev"
+    api_base = "https://generativelanguage.googleapis.com/v1beta"
 
     working = True
     supports_message_history = True
     needs_auth = True
 
     default_model = "gemini-1.5-pro"
     default_vision_model = default_model
-    models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
+    fallback_models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
     model_aliases = {
-        "gemini-flash": "gemini-1.5-flash",
+        "gemini-flash": "gemini-1.5-flash-8b",
     }
 
     @classmethod
+    def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
+        if not cls.models:
+            try:
+                response = requests.get(f"{api_base}/models?key={api_key}")
+                raise_for_status(response)
+                data = response.json()
+                cls.models = [
+                    model.get("name").split("/").pop()
+                    for model in data.get("models")
+                    if "generateContent" in model.get("supportedGenerationMethods")
+                ]
+                cls.models.sort()
+            except Exception as e:
+                debug.log(e)
+                cls.models = cls.fallback_models
+        return cls.models
+
+    @classmethod
     async def create_async_generator(
         cls,
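Note: this is the core of the commit. GeminiPro.get_models now queries Google's models endpoint, keeps only entries that support generateContent, caches the sorted result on the class, and falls back to the hard-coded fallback_models list on any failure. A standalone sketch of the same lookup, assuming a valid Google AI Studio API key; the endpoint and response fields ("models", "name", "supportedGenerationMethods") are taken from the diff above:

    import requests

    def list_gemini_models(
        api_key: str,
        api_base: str = "https://generativelanguage.googleapis.com/v1beta",
    ) -> list:
        response = requests.get(f"{api_base}/models", params={"key": api_key})
        response.raise_for_status()
        return sorted(
            model["name"].split("/").pop()  # "models/gemini-1.5-pro" -> "gemini-1.5-pro"
            for model in response.json().get("models", [])
            if "generateContent" in model.get("supportedGenerationMethods", [])
        )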
@@ -34,17 +56,17 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
         stream: bool = False,
         proxy: str = None,
         api_key: str = None,
-        api_base: str = "https://generativelanguage.googleapis.com/v1beta",
+        api_base: str = api_base,
         use_auth_header: bool = False,
         images: ImagesType = None,
         connector: BaseConnector = None,
         **kwargs
     ) -> AsyncResult:
-        model = cls.get_model(model)
-
         if not api_key:
             raise MissingAuthError('Add a "api_key"')
 
+        model = cls.get_model(model, api_key=api_key, api_base=api_base)
+
         headers = params = None
         if use_auth_header:
             headers = {"Authorization": f"Bearer {api_key}"}
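Note: threading api_key and api_base into get_model means the requested model is validated against the live list before any generation request is sent. A hedged usage example; the import path is assumed from this repository's layout at this commit, not confirmed by the diff:

    from g4f.Provider import GeminiPro  # assumed re-export path

    models = GeminiPro.get_models(api_key="YOUR_API_KEY")
    print(models)  # e.g. ['gemini-1.5-flash', 'gemini-1.5-pro', ...] or the fallback list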
OpenaiAPI.py

@@ -23,13 +23,13 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
     fallback_models = []
 
     @classmethod
-    def get_models(cls, api_key: str = None):
+    def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
         if not cls.models:
             try:
                 headers = {}
                 if api_key is not None:
                     headers["authorization"] = f"Bearer {api_key}"
-                response = requests.get(f"{cls.api_base}/models", headers=headers)
+                response = requests.get(f"{api_base}/models", headers=headers)
                 raise_for_status(response)
                 data = response.json()
                 cls.models = [model.get("id") for model in data.get("data")]
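Note: with api_base now a parameter (defaulting to the class attribute), subclasses pointing at any OpenAI-compatible server can reuse this one implementation instead of hard-coding cls.api_base. A minimal sketch of the request it issues; the host in the usage line is a placeholder, not a real endpoint:

    import requests

    def list_openai_models(api_base: str, api_key: str = None) -> list:
        headers = {"authorization": f"Bearer {api_key}"} if api_key else {}
        response = requests.get(f"{api_base}/models", headers=headers)
        response.raise_for_status()
        return [model.get("id") for model in response.json().get("data", [])]

    # e.g. list_openai_models("https://api.example.com/v1", api_key="sk-...")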
@@ -82,7 +82,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
         ) as session:
             data = filter_none(
                 messages=messages,
-                model=cls.get_model(model),
+                model=cls.get_model(model, api_key=api_key, api_base=api_base),
                 temperature=temperature,
                 max_tokens=max_tokens,
                 top_p=top_p,
@@ -147,4 +147,4 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
                 if api_key is not None else {}
             ),
             **({} if headers is None else headers)
-        }
\ No newline at end of file
+        }
base_provider.py

@@ -243,19 +243,20 @@ class ProviderModelMixin:
     last_model: str = None
 
     @classmethod
-    def get_models(cls) -> list[str]:
+    def get_models(cls, **kwargs) -> list[str]:
         if not cls.models and cls.default_model is not None:
             return [cls.default_model]
         return cls.models
 
     @classmethod
-    def get_model(cls, model: str) -> str:
+    def get_model(cls, model: str, **kwargs) -> str:
         if not model and cls.default_model is not None:
             model = cls.default_model
         elif model in cls.model_aliases:
             model = cls.model_aliases[model]
-        elif model not in cls.get_models() and cls.models:
-            raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
+        else:
+            if model not in cls.get_models(**kwargs) and cls.models:
+                raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
         cls.last_model = model
         debug.last_model = model
         return model
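Note: the mixin change ties the commit together: get_model forwards its keyword arguments to get_models, which is why every override above had to tolerate **kwargs. Folding the membership check into an else branch also means a model resolved via default_model or model_aliases is never re-validated against the fetched list. A condensed, self-contained restatement of the new control flow; Mixin and the model names are illustrative:

    class Mixin:
        models: list = []
        default_model = "default"
        model_aliases = {"alias": "real-model"}

        @classmethod
        def get_models(cls, **kwargs) -> list:
            return cls.models or [cls.default_model]

        @classmethod
        def get_model(cls, model: str, **kwargs) -> str:
            if not model and cls.default_model is not None:
                model = cls.default_model
            elif model in cls.model_aliases:
                model = cls.model_aliases[model]
            else:
                # validation may now depend on kwargs such as api_key/api_base
                if model not in cls.get_models(**kwargs) and cls.models:
                    raise ValueError(f"Model is not supported: {model}")
            return model

    print(Mixin.get_model(""))       # -> default
    print(Mixin.get_model("alias"))  # -> real-model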