Mirror of https://github.com/xtekky/gpt4free.git (synced 2025-12-06 02:30:41 -08:00)

Add get_models to GeminiPro provider

commit ec9df59828, parent 68c7a92ee2
6 changed files with 44 additions and 24 deletions
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 import asyncio
 import json
-import uuid
 
 from ..typing import AsyncResult, Messages, Cookies
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop
@@ -37,18 +36,16 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
         if not cls.models:
             if cls._args is None:
                 get_running_loop(check_nested=True)
-                args = get_args_from_nodriver(cls.url, cookies={
-                    '__cf_bm': uuid.uuid4().hex,
-                })
+                args = get_args_from_nodriver(cls.url)
                 cls._args = asyncio.run(args)
             with Session(**cls._args) as session:
                 response = session.get(cls.models_url)
             cls._args["cookies"] = merge_cookies(cls._args["cookies"], response)
             try:
                 raise_for_status(response)
-            except ResponseStatusError as e:
+            except ResponseStatusError:
                 cls._args = None
-                raise e
+                raise
             json_data = response.json()
             cls.models = [model.get("name") for model in json_data.get("models")]
         return cls.models
@@ -64,9 +61,9 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
         timeout: int = 300,
         **kwargs
     ) -> AsyncResult:
-        model = cls.get_model(model)
         if cls._args is None:
             cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
+        model = cls.get_model(model)
         data = {
             "messages": messages,
             "lora": None,
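Two things change in Cloudflare: the hand-rolled __cf_bm cookie is dropped (get_args_from_nodriver now supplies real session state), and cls.get_model(model) moves below the nodriver bootstrap so the live model list can be fetched with valid arguments first. A minimal sketch of this fetch-once, invalidate-on-error caching pattern, with a hypothetical MODELS_URL standing in for the provider's endpoint:

import requests

MODELS_URL = "https://example.com/api/models"  # hypothetical stand-in

class ModelCache:
    _args = None  # cached session arguments (cookies, headers)
    models = []   # cached model names

    @classmethod
    def get_models(cls):
        if not cls.models:
            if cls._args is None:
                # real provider: cls._args = asyncio.run(get_args_from_nodriver(cls.url))
                cls._args = {"cookies": {}, "headers": {}}
            response = requests.get(MODELS_URL, **cls._args)
            try:
                response.raise_for_status()
            except requests.HTTPError:
                cls._args = None  # drop stale session state so the next call re-bootstraps
                raise
            cls.models = [model.get("name") for model in response.json().get("models", [])]
        return cls.models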
@@ -40,7 +40,7 @@ class PollinationsAI(OpenaiAPI):
     }
 
     @classmethod
-    def get_models(cls):
+    def get_models(cls, **kwargs):
         if not hasattr(cls, 'image_models'):
             cls.image_models = []
         if not cls.image_models:
@@ -14,7 +14,7 @@ class DeepInfra(OpenaiAPI):
     default_model = "meta-llama/Meta-Llama-3.1-70B-Instruct"
 
     @classmethod
-    def get_models(cls):
+    def get_models(cls, **kwargs):
         if not cls.models:
             url = 'https://api.deepinfra.com/models/featured'
             models = requests.get(url).json()
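These two one-line signature changes keep the overrides call-compatible with the reworked ProviderModelMixin.get_model(model, **kwargs) at the bottom of this commit, which forwards its keywords into get_models(**kwargs). A sketch of why the catch-all is needed (class names are illustrative):

class Base:
    models = []

    @classmethod
    def get_models(cls, **kwargs):
        return cls.models

    @classmethod
    def get_model(cls, model, **kwargs):
        if model not in cls.get_models(**kwargs):  # keywords pass straight through
            raise ValueError(f"Model is not supported: {model}")
        return model

class WithoutKwargs(Base):
    models = ["alpha"]

    @classmethod
    def get_models(cls):  # breaks once keywords are forwarded
        return cls.models

class WithKwargs(Base):
    models = ["alpha"]

    @classmethod
    def get_models(cls, **kwargs):  # accepts and ignores forwarded keywords
        return cls.models

WithKwargs.get_model("alpha", api_key="k")       # fine
# WithoutKwargs.get_model("alpha", api_key="k")  # TypeError: unexpected keyword argument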
@@ -2,30 +2,52 @@ from __future__ import annotations
 
 import base64
 import json
+import requests
 
 from aiohttp import ClientSession, BaseConnector
 
 from ...typing import AsyncResult, Messages, ImagesType
-from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ...image import to_bytes, is_accepted_format
 from ...errors import MissingAuthError
+from ...requests.raise_for_status import raise_for_status
+from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..helper import get_connector
+from ... import debug
 
 class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
     label = "Google Gemini API"
     url = "https://ai.google.dev"
+    api_base = "https://generativelanguage.googleapis.com/v1beta"
 
     working = True
     supports_message_history = True
     needs_auth = True
 
     default_model = "gemini-1.5-pro"
     default_vision_model = default_model
-    models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
+    fallback_models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
     model_aliases = {
         "gemini-flash": "gemini-1.5-flash",
         "gemini-flash": "gemini-1.5-flash-8b",
     }
 
+    @classmethod
+    def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
+        if not cls.models:
+            try:
+                response = requests.get(f"{api_base}/models?key={api_key}")
+                raise_for_status(response)
+                data = response.json()
+                cls.models = [
+                    model.get("name").split("/").pop()
+                    for model in data.get("models")
+                    if "generateContent" in model.get("supportedGenerationMethods")
+                ]
+                cls.models.sort()
+            except Exception as e:
+                debug.log(e)
+                cls.models = cls.fallback_models
+        return cls.models
+
     @classmethod
     async def create_async_generator(
         cls,
@@ -34,17 +56,17 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
         stream: bool = False,
         proxy: str = None,
         api_key: str = None,
-        api_base: str = "https://generativelanguage.googleapis.com/v1beta",
+        api_base: str = api_base,
         use_auth_header: bool = False,
         images: ImagesType = None,
         connector: BaseConnector = None,
         **kwargs
     ) -> AsyncResult:
-        model = cls.get_model(model)
-
         if not api_key:
             raise MissingAuthError('Add a "api_key"')
 
+        model = cls.get_model(model, api_key=api_key, api_base=api_base)
+
         headers = params = None
         if use_auth_header:
             headers = {"Authorization": f"Bearer {api_key}"}
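The new GeminiPro.get_models queries the public model listing of the Generative Language API and falls back to the hard-coded fallback_models on any failure. A standalone sketch of the same request, assuming a key in the GEMINI_API_KEY environment variable:

import os
import requests

api_base = "https://generativelanguage.googleapis.com/v1beta"
api_key = os.environ["GEMINI_API_KEY"]

response = requests.get(f"{api_base}/models?key={api_key}")
response.raise_for_status()

# Keep only models that support generateContent; strip the "models/" prefix from each name.
models = sorted(
    model["name"].split("/").pop()
    for model in response.json().get("models", [])
    if "generateContent" in model.get("supportedGenerationMethods", [])
)
print(models)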
@@ -23,13 +23,13 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
     fallback_models = []
 
     @classmethod
-    def get_models(cls, api_key: str = None):
+    def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
         if not cls.models:
             try:
                 headers = {}
                 if api_key is not None:
                     headers["authorization"] = f"Bearer {api_key}"
-                response = requests.get(f"{cls.api_base}/models", headers=headers)
+                response = requests.get(f"{api_base}/models", headers=headers)
                 raise_for_status(response)
                 data = response.json()
                 cls.models = [model.get("id") for model in data.get("data")]
@@ -82,7 +82,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
         ) as session:
             data = filter_none(
                 messages=messages,
-                model=cls.get_model(model),
+                model=cls.get_model(model, api_key=api_key, api_base=api_base),
                 temperature=temperature,
                 max_tokens=max_tokens,
                 top_p=top_p,
@@ -147,4 +147,4 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
                     if api_key is not None else {}
                 ),
                 **({} if headers is None else headers)
             }
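OpenaiAPI.get_models gains the same api_key/api_base parameters, and create_async_generator now threads them through get_model, so model validation hits the caller's endpoint rather than the class default. A standalone equivalent, assuming an OpenAI-compatible /models route and a key in OPENAI_API_KEY:

import os
import requests

api_base = "https://api.openai.com/v1"  # any OpenAI-compatible endpoint
api_key = os.environ.get("OPENAI_API_KEY")

headers = {}
if api_key is not None:
    headers["authorization"] = f"Bearer {api_key}"

response = requests.get(f"{api_base}/models", headers=headers)
response.raise_for_status()
print([model.get("id") for model in response.json().get("data", [])])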
@@ -243,19 +243,20 @@ class ProviderModelMixin:
     last_model: str = None
 
     @classmethod
-    def get_models(cls) -> list[str]:
+    def get_models(cls, **kwargs) -> list[str]:
         if not cls.models and cls.default_model is not None:
             return [cls.default_model]
         return cls.models
 
     @classmethod
-    def get_model(cls, model: str) -> str:
+    def get_model(cls, model: str, **kwargs) -> str:
         if not model and cls.default_model is not None:
             model = cls.default_model
         elif model in cls.model_aliases:
             model = cls.model_aliases[model]
-        elif model not in cls.get_models() and cls.models:
-            raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
+        else:
+            if model not in cls.get_models(**kwargs) and cls.models:
+                raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
         cls.last_model = model
         debug.last_model = model
         return model
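The mixin now resolves a model name in three steps: empty input falls back to default_model, a known alias is rewritten to its target, and anything else is validated against get_models(**kwargs), which lets providers check the name against a remote, key-gated list. A self-contained sketch of that resolution order (names are illustrative):

class Mixin:
    default_model = "base-v1"
    model_aliases = {"fast": "base-v1-mini"}
    models = ["base-v1", "base-v1-mini"]

    @classmethod
    def get_models(cls, **kwargs):
        return cls.models

    @classmethod
    def get_model(cls, model, **kwargs):
        if not model and cls.default_model is not None:
            model = cls.default_model
        elif model in cls.model_aliases:
            model = cls.model_aliases[model]
        else:
            if model not in cls.get_models(**kwargs) and cls.models:
                raise ValueError(f"Model is not supported: {model}")
        return model

assert Mixin.get_model("") == "base-v1"           # empty -> default
assert Mixin.get_model("fast") == "base-v1-mini"  # alias -> target
# Mixin.get_model("nope")  # raises: not in get_models()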