mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
Update provider parameters, check for valid provider
Fix reading model list in GeminiPro Fix check content-type in OpenaiAPI
This commit is contained in:
parent
9d6777e239
commit
fd5fa8a4eb
5 changed files with 13 additions and 8 deletions
|
|
@ -23,6 +23,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
|
||||
working = True
|
||||
supports_message_history = True
|
||||
supports_system_message = True
|
||||
needs_auth = True
|
||||
|
||||
default_model = "gemini-1.5-pro"
|
||||
|
|
@ -39,7 +40,8 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
|
||||
if not cls.models:
|
||||
try:
|
||||
response = requests.get(f"{api_base}/models?key={api_key}")
|
||||
url = f"{cls.api_base if not api_base else api_base}/models"
|
||||
response = requests.get(url, params={"key": api_key})
|
||||
raise_for_status(response)
|
||||
data = response.json()
|
||||
cls.models = [
|
||||
|
|
@ -50,7 +52,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
cls.models.sort()
|
||||
except Exception as e:
|
||||
debug.log(e)
|
||||
cls.models = cls.fallback_models
|
||||
return cls.fallback_models
|
||||
return cls.models
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -108,7 +108,8 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
|
|||
if api_endpoint is None:
|
||||
api_endpoint = f"{api_base.rstrip('/')}/chat/completions"
|
||||
async with session.post(api_endpoint, json=data) as response:
|
||||
if response.headers.get("content-type", None if stream else "application/json") == "application/json":
|
||||
content_type = response.headers.get("content-type", "text/event-stream" if stream else "application/json")
|
||||
if content_type.startswith("application/json"):
|
||||
data = await response.json()
|
||||
cls.raise_error(data)
|
||||
await raise_for_status(response)
|
||||
|
|
@ -122,7 +123,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
|
|||
if "finish_reason" in choice and choice["finish_reason"] is not None:
|
||||
yield FinishReason(choice["finish_reason"])
|
||||
return
|
||||
elif response.headers.get("content-type", "text/event-stream" if stream else None) == "text/event-stream":
|
||||
elif content_type.startswith("text/event-stream"):
|
||||
await raise_for_status(response)
|
||||
first = True
|
||||
async for line in response.iter_lines():
|
||||
|
|
@ -147,7 +148,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
|
|||
break
|
||||
else:
|
||||
await raise_for_status(response)
|
||||
raise ResponseError(f"Not supported content-type: {response.headers.get('content-type')}")
|
||||
raise ResponseError(f"Not supported content-type: {content_type}")
|
||||
|
||||
@classmethod
|
||||
def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
|
||||
|
|
|
|||
|
|
@ -839,7 +839,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
|
|||
await api("conversation", {
|
||||
id: message_id,
|
||||
conversation_id: window.conversation_id,
|
||||
conversation: conversation.data && provider in conversation.data ? conversation.data[provider] : null,
|
||||
conversation: provider && conversation.data && provider in conversation.data ? conversation.data[provider] : null,
|
||||
model: model,
|
||||
web_search: switchInput.checked,
|
||||
provider: provider,
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class Api:
|
|||
"name": provider.__name__,
|
||||
"label": provider.label if hasattr(provider, "label") else provider.__name__,
|
||||
"parent": getattr(provider, "parent", None),
|
||||
"image": getattr(provider, "image_models", None) is not None,
|
||||
"image": bool(getattr(provider, "image_models", False)),
|
||||
"vision": getattr(provider, "default_vision_model", None) is not None,
|
||||
"auth": provider.needs_auth,
|
||||
"login_url": getattr(provider, "login_url", None),
|
||||
|
|
@ -157,7 +157,6 @@ class Api:
|
|||
**(provider_handler.get_parameters(as_json=True) if hasattr(provider_handler, "get_parameters") else {}),
|
||||
"model": model,
|
||||
"messages": kwargs.get("messages"),
|
||||
"web_search": kwargs.get("web_search")
|
||||
}
|
||||
if isinstance(kwargs.get("conversation"), JsonConversation):
|
||||
params["conversation"] = kwargs.get("conversation").get_dict()
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ SAFE_PARAMETERS = [
|
|||
"api_key", "api_base", "seed", "width", "height",
|
||||
"proof_token", "max_retries", "web_search",
|
||||
"guidance_scale", "num_inference_steps", "randomize_seed",
|
||||
"safe", "enhance", "private",
|
||||
]
|
||||
|
||||
BASIC_PARAMETERS = {
|
||||
|
|
@ -61,6 +62,8 @@ PARAMETER_EXAMPLES = {
|
|||
"max_new_tokens": 1024,
|
||||
"max_tokens": 4096,
|
||||
"seed": 42,
|
||||
"stop": ["stop1", "stop2"],
|
||||
"tools": [],
|
||||
}
|
||||
|
||||
class AbstractProvider(BaseProvider):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue