mirror of https://github.com/xtekky/gpt4free.git
synced 2025-12-15 14:51:19 -08:00

Refactor Cohere provider to update API endpoint and improve model retrieval logic

This commit is contained in:
parent 86e6cd8c0c
commit a1c3ed72c2

1 changed file with 42 additions and 84 deletions
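For context on the endpoint change: Cohere's v1 /chat takes a single message plus preamble and chat_history, while v2 /chat accepts an OpenAI-style messages list directly, which is why the message-conversion code below is deleted. A minimal standalone sketch of the new request shape, assuming Cohere's documented v2 schema; the model name, key, and response-field access are illustrative placeholders, not values from this commit:

import requests

# Hypothetical one-off request against the v2 endpoint the provider now targets.
response = requests.post(
    "https://api.cohere.ai/v2/chat",
    headers={"Authorization": "Bearer COHERE_API_KEY"},  # placeholder key
    json={
        "model": "command-r-plus",
        # v2 takes the messages list as-is -- no splitting into
        # message / preamble / chat_history as the removed v1 code did.
        "messages": [
            {"role": "system", "content": "You are concise."},
            {"role": "user", "content": "Say hello."},
        ],
        "stream": False,
    },
)
# Response shape assumed from Cohere's v2 docs: message.content is a list of parts.
print(response.json()["message"]["content"][0]["text"])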
@@ -1,42 +1,37 @@
 from __future__ import annotations

-import json
+from typing import Optional
+import requests

 from ..helper import filter_none
 from ...typing import AsyncResult, Messages
-from ...requests import StreamSession, raise_for_status
+from ...requests import StreamSession, raise_for_status, sse_stream
 from ...providers.response import FinishReason, Usage
 from ...errors import MissingAuthError
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ...tools.run_tools import AuthManager
 from ... import debug

 class Cohere(AsyncGeneratorProvider, ProviderModelMixin):
     label = "Cohere API"
     url = "https://cohere.com"
     login_url = "https://dashboard.cohere.com/api-keys"
-    api_base = "https://api.cohere.ai/v1"
+    api_endpoint = "https://api.cohere.ai/v2/chat"
     working = True
     active_by_default = True
     needs_auth = True
     supports_stream = True
     supports_system_message = True
     supports_message_history = True

     default_model = "command-r-plus"
-    models = [
-        default_model,
-        "command-r",
-        "command",
-        "command-nightly",
-        "command-light",
-        "command-light-nightly",
-    ]
-
-    model_aliases = {
-        "command-r-plus-08-2024": "command-r-plus",
-        "command-r-08-2024": "command-r",
-    }

+    @classmethod
+    def get_models(cls, **kwargs):
+        if not cls.models:
+            url = "https://api.cohere.com/v1/models?page_size=500&endpoint=chat"
+            models = requests.get(url).json().get("models", [])
+            cls.models = [model.get("name") for model in models if "chat" in model.get("endpoints")]
+            cls.vision_models = {model.get("name") for model in models if model.get("supports_vision")}
+        return cls.models
+
     @classmethod
     async def create_async_generator(
@@ -51,43 +46,14 @@ class Cohere(AsyncGeneratorProvider, ProviderModelMixin):
         top_k: int = None,
         top_p: float = None,
         stop: list[str] = None,
-        stream: bool = False,
+        stream: bool = True,
         headers: dict = None,
         impersonate: str = None,
         **kwargs
     ) -> AsyncResult:
         if api_key is None:
             api_key = AuthManager.load_api_key(cls)
         if api_key is None:
             raise MissingAuthError('Add a "api_key"')

-        # Convert messages to Cohere format
-        system_message = None
-        chat_history = []
-        user_message = None
-
-        # Filter out system messages first
-        system_messages = [msg for msg in messages if msg.get("role") == "system"]
-        if system_messages:
-            system_message = "\n".join([msg.get("content", "") for msg in system_messages])
-
-        # Process conversation messages (non-system)
-        conversation_messages = [msg for msg in messages if msg.get("role") != "system"]
-
-        # The last message should be from user
-        if conversation_messages and conversation_messages[-1].get("role") == "user":
-            user_message = conversation_messages[-1].get("content", "")
-            # All previous messages become chat history
-            for msg in conversation_messages[:-1]:
-                role = msg.get("role")
-                content = msg.get("content", "")
-                if role == "user":
-                    chat_history.append({"role": "USER", "message": content})
-                elif role == "assistant":
-                    chat_history.append({"role": "CHATBOT", "message": content})
-        else:
-            raise ValueError("The last message must be from the user")
-
         async with StreamSession(
             proxy=proxy,
             headers=cls.get_headers(stream, api_key, headers),
@@ -95,19 +61,16 @@ class Cohere(AsyncGeneratorProvider, ProviderModelMixin):
             impersonate=impersonate,
         ) as session:
             data = filter_none(
-                message=user_message,
+                messages=messages,
                 model=cls.get_model(model, api_key=api_key),
                 temperature=temperature,
                 max_tokens=max_tokens,
                 k=top_k,
                 p=top_p,
                 stop_sequences=stop,
-                preamble=system_message,
-                chat_history=chat_history if chat_history else None,
                 stream=stream,
             )

-            async with session.post(f"{cls.api_base}/chat", json=data) as response:
+            async with session.post(cls.api_endpoint, json=data) as response:
                 await raise_for_status(response)

                 if not stream:
@@ -120,40 +83,35 @@
                         yield FinishReason("stop")
                     elif data["finish_reason"] == "MAX_TOKENS":
                         yield FinishReason("length")
-                    if "meta" in data and "tokens" in data["meta"]:
+                    if "usage" in data:
+                        tokens = data.get("usage", {}).get("tokens", {})
                         yield Usage(
-                            prompt_tokens=data["meta"]["tokens"]["input_tokens"],
-                            completion_tokens=data["meta"]["tokens"]["output_tokens"],
-                            total_tokens=data["meta"]["tokens"]["input_tokens"] + data["meta"]["tokens"]["output_tokens"]
+                            prompt_tokens=tokens.get("input_tokens"),
+                            completion_tokens=tokens.get("output_tokens"),
+                            total_tokens=tokens.get("input_tokens", 0) + tokens.get("output_tokens", 0),
+                            billed_units=data.get("usage", {}).get("billed_units")
                         )
                 else:
-                    async for line in response.iter_lines():
-                        if line.startswith(b"data: "):
-                            chunk = line[6:]
-                            if chunk == b"[DONE]":
-                                break
-                            try:
-                                data = json.loads(chunk)
-                                cls.raise_error(data)
-
-                                if "event_type" in data:
-                                    if data["event_type"] == "text-generation":
-                                        if "text" in data:
-                                            yield data["text"]
-                                    elif data["event_type"] == "stream-end":
-                                        if "finish_reason" in data:
-                                            if data["finish_reason"] == "COMPLETE":
-                                                yield FinishReason("stop")
-                                            elif data["finish_reason"] == "MAX_TOKENS":
-                                                yield FinishReason("length")
-                                        if "meta" in data and "tokens" in data["meta"]:
-                                            yield Usage(
-                                                prompt_tokens=data["meta"]["tokens"]["input_tokens"],
-                                                completion_tokens=data["meta"]["tokens"]["output_tokens"],
-                                                total_tokens=data["meta"]["tokens"]["input_tokens"] + data["meta"]["tokens"]["output_tokens"]
-                                            )
-                            except json.JSONDecodeError:
-                                continue
+                    async for data in sse_stream(response):
+                        cls.raise_error(data)
+                        if "type" in data:
+                            if data["type"] == "content-delta":
+                                yield data.get("delta", {}).get("message", {}).get("content", {}).get("text")
+                            elif data["type"] == "message-end":
+                                delta = data.get("delta", {})
+                                if "finish_reason" in delta:
+                                    if delta["finish_reason"] == "COMPLETE":
+                                        yield FinishReason("stop")
+                                    elif delta["finish_reason"] == "MAX_TOKENS":
+                                        yield FinishReason("length")
+                                if "usage" in delta:
+                                    tokens = delta.get("usage", {}).get("tokens", {})
+                                    yield Usage(
+                                        prompt_tokens=tokens.get("input_tokens"),
+                                        completion_tokens=tokens.get("output_tokens"),
+                                        total_tokens=tokens.get("input_tokens", 0) + tokens.get("output_tokens", 0),
+                                        billed_units=delta.get("usage", {}).get("billed_units")
+                                    )

     @classmethod
     def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
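The new streaming branch expects Cohere's v2 SSE events. A minimal sketch of the two event shapes the loop consumes, with field names taken from the accesses in the diff and values purely illustrative:

from typing import Optional

# Illustrative v2 stream events matching the fields the new loop reads.
content_delta = {
    "type": "content-delta",
    "delta": {"message": {"content": {"text": "Hel"}}},
}
message_end = {
    "type": "message-end",
    "delta": {
        "finish_reason": "COMPLETE",  # or "MAX_TOKENS"
        "usage": {
            "tokens": {"input_tokens": 12, "output_tokens": 34},
            "billed_units": {"input_tokens": 12, "output_tokens": 34},
        },
    },
}

def extract_text(event: dict) -> Optional[str]:
    # Mirrors the content-delta dispatch in the new code path.
    if event.get("type") == "content-delta":
        return event.get("delta", {}).get("message", {}).get("content", {}).get("text")
    return None

assert extract_text(content_delta) == "Hel"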
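The hard-coded model list is replaced by a lazy fetch against the public models endpoint. A standalone sketch of the fetch-and-cache pattern get_models uses, with the URL and field names taken from the diff; whether this endpoint needs an Authorization header is not shown in the commit, so treat the unauthenticated call as an assumption:

import requests

class ModelCache:
    # Class-level cache, as in the provider: empty until first use.
    models: list[str] = []

    @classmethod
    def get_models(cls) -> list[str]:
        if not cls.models:  # only hit the network once per process
            url = "https://api.cohere.com/v1/models?page_size=500&endpoint=chat"
            data = requests.get(url).json().get("models", [])
            # Expected shape, per the accesses in the diff:
            # {"models": [{"name": ..., "endpoints": ["chat", ...], "supports_vision": ...}]}
            cls.models = [m.get("name") for m in data if "chat" in m.get("endpoints", [])]
        return cls.models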
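Finally, a hedged end-to-end sketch of driving the refactored provider directly. The import path is assumed, since the diff only shows relative imports, and the keyword arguments are inferred from the signature fragments above:

import asyncio

# Assumed location inside the g4f package; the real path may differ.
from g4f.Provider import Cohere

async def main():
    async for chunk in Cohere.create_async_generator(
        model="command-r-plus",
        messages=[{"role": "user", "content": "Hello"}],
        api_key="YOUR_COHERE_KEY",  # placeholder
    ):
        # The generator also yields FinishReason and Usage objects; print text only.
        if isinstance(chunk, str):
            print(chunk, end="", flush=True)

asyncio.run(main())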