Update provider parameters, check for valid provider

Fix reading model list in GeminiPro
Fix content-type check in OpenaiAPI
hlohaus 2025-01-24 09:45:40 +01:00
parent 9d6777e239
commit fd5fa8a4eb
5 changed files with 13 additions and 8 deletions


@@ -23,6 +23,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
     working = True
     supports_message_history = True
     supports_system_message = True
+    needs_auth = True
     default_model = "gemini-1.5-pro"
@@ -39,7 +40,8 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
     def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
         if not cls.models:
             try:
-                response = requests.get(f"{api_base}/models?key={api_key}")
+                url = f"{cls.api_base if not api_base else api_base}/models"
+                response = requests.get(url, params={"key": api_key})
                 raise_for_status(response)
                 data = response.json()
                 cls.models = [
@@ -50,7 +52,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
                 cls.models.sort()
             except Exception as e:
                 debug.log(e)
-                cls.models = cls.fallback_models
+                return cls.fallback_models
         return cls.models
 
     @classmethod
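A note on the get_models change above, with an illustrative sketch (the key string is a placeholder and the URL is the public Generative Language API base, shown only as an example): an explicit api_base argument now takes precedence over cls.api_base, the key is sent via params instead of being interpolated into the URL, and a failed request returns fallback_models without caching it in cls.models, so a later call with a working key still refreshes the real list.

    # Illustrative usage only; "AIza..." is a placeholder API key.
    models = GeminiPro.get_models(api_key="AIza...")  # resolves against cls.api_base
    models = GeminiPro.get_models(
        api_key="AIza...",
        api_base="https://generativelanguage.googleapis.com/v1beta",  # explicit base wins
    )
    # On any exception the call now returns cls.fallback_models but leaves
    # cls.models empty, so the next call retries the HTTP request.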


@@ -108,7 +108,8 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
             if api_endpoint is None:
                 api_endpoint = f"{api_base.rstrip('/')}/chat/completions"
             async with session.post(api_endpoint, json=data) as response:
-                if response.headers.get("content-type", None if stream else "application/json") == "application/json":
+                content_type = response.headers.get("content-type", "text/event-stream" if stream else "application/json")
+                if content_type.startswith("application/json"):
                     data = await response.json()
                     cls.raise_error(data)
                     await raise_for_status(response)
@@ -122,7 +123,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
                     if "finish_reason" in choice and choice["finish_reason"] is not None:
                         yield FinishReason(choice["finish_reason"])
                         return
-                elif response.headers.get("content-type", "text/event-stream" if stream else None) == "text/event-stream":
+                elif content_type.startswith("text/event-stream"):
                     await raise_for_status(response)
                     first = True
                     async for line in response.iter_lines():
@@ -147,7 +148,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
                                 break
                 else:
                     await raise_for_status(response)
-                    raise ResponseError(f"Not supported content-type: {response.headers.get('content-type')}")
+                    raise ResponseError(f"Not supported content-type: {content_type}")
 
     @classmethod
     def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
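The move from strict equality to startswith in the hunks above matters because servers frequently append a charset to the Content-Type header; a minimal illustration (the header value is an example, not taken from this commit):

    content_type = "application/json; charset=utf-8"  # common real-world header
    content_type == "application/json"                # False: the old equality check falls through
    content_type.startswith("application/json")       # True: the new prefix check matches

Capturing the header once into content_type also keeps the default consistent across the JSON, event-stream, and error branches.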


@@ -839,7 +839,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
         await api("conversation", {
             id: message_id,
             conversation_id: window.conversation_id,
-            conversation: conversation.data && provider in conversation.data ? conversation.data[provider] : null,
+            conversation: provider && conversation.data && provider in conversation.data ? conversation.data[provider] : null,
             model: model,
             web_search: switchInput.checked,
             provider: provider,


@@ -62,7 +62,7 @@ class Api:
             "name": provider.__name__,
             "label": provider.label if hasattr(provider, "label") else provider.__name__,
             "parent": getattr(provider, "parent", None),
-            "image": getattr(provider, "image_models", None) is not None,
+            "image": bool(getattr(provider, "image_models", False)),
             "vision": getattr(provider, "default_vision_model", None) is not None,
             "auth": provider.needs_auth,
             "login_url": getattr(provider, "login_url", None),
@@ -157,7 +157,6 @@ class Api:
             **(provider_handler.get_parameters(as_json=True) if hasattr(provider_handler, "get_parameters") else {}),
             "model": model,
             "messages": kwargs.get("messages"),
-            "web_search": kwargs.get("web_search")
         }
         if isinstance(kwargs.get("conversation"), JsonConversation):
             params["conversation"] = kwargs.get("conversation").get_dict()


@@ -34,6 +34,7 @@ SAFE_PARAMETERS = [
     "api_key", "api_base", "seed", "width", "height",
     "proof_token", "max_retries", "web_search",
     "guidance_scale", "num_inference_steps", "randomize_seed",
+    "safe", "enhance", "private",
 ]
 
 BASIC_PARAMETERS = {
@@ -61,6 +62,8 @@ PARAMETER_EXAMPLES = {
     "max_new_tokens": 1024,
     "max_tokens": 4096,
     "seed": 42,
+    "stop": ["stop1", "stop2"],
+    "tools": [],
 }
 
 class AbstractProvider(BaseProvider):
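For orientation, a rough sketch of how these module-level lists are typically consumed: SAFE_PARAMETERS acts as a whitelist of keyword names that may be exposed, and PARAMETER_EXAMPLES supplies a sample value per name. The helper below is hypothetical and not part of the repository:

    # Hypothetical helper, for illustration only: pair each whitelisted
    # parameter a provider accepts with its example value.
    def describe_parameters(accepted: list[str]) -> dict:
        return {
            name: PARAMETER_EXAMPLES.get(name)
            for name in accepted
            if name in SAFE_PARAMETERS
        }

    describe_parameters(["seed", "private", "safe"])
    # e.g. {"seed": 42, ...} given the entries visible in the hunks above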