Refactor Cohere provider to update API endpoint and improve model retrieval logic

This commit is contained in:
hlohaus 2025-09-04 13:01:15 +02:00
parent 86e6cd8c0c
commit a1c3ed72c2

View file

@ -1,42 +1,37 @@
from __future__ import annotations
import json
from typing import Optional
import requests
from ..helper import filter_none
from ...typing import AsyncResult, Messages
from ...requests import StreamSession, raise_for_status
from ...requests import StreamSession, raise_for_status, sse_stream
from ...providers.response import FinishReason, Usage
from ...errors import MissingAuthError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ...tools.run_tools import AuthManager
from ... import debug
class Cohere(AsyncGeneratorProvider, ProviderModelMixin):
label = "Cohere API"
url = "https://cohere.com"
login_url = "https://dashboard.cohere.com/api-keys"
api_base = "https://api.cohere.ai/v1"
api_endpoint = "https://api.cohere.ai/v2/chat"
working = True
active_by_default = True
needs_auth = True
supports_stream = True
supports_system_message = True
supports_message_history = True
default_model = "command-r-plus"
models = [
default_model,
"command-r",
"command",
"command-nightly",
"command-light",
"command-light-nightly",
]
model_aliases = {
"command-r-plus-08-2024": "command-r-plus",
"command-r-08-2024": "command-r",
}
@classmethod
def get_models(cls, **kwargs):
if not cls.models:
url = "https://api.cohere.com/v1/models?page_size=500&endpoint=chat"
models = requests.get(url).json().get("models", [])
cls.models = [model.get("name") for model in models if "chat" in model.get("endpoints")]
cls.vision_models = {model.get("name") for model in models if model.get("supports_vision")}
return cls.models
@classmethod
async def create_async_generator(
@ -51,43 +46,14 @@ class Cohere(AsyncGeneratorProvider, ProviderModelMixin):
top_k: int = None,
top_p: float = None,
stop: list[str] = None,
stream: bool = False,
stream: bool = True,
headers: dict = None,
impersonate: str = None,
**kwargs
) -> AsyncResult:
if api_key is None:
api_key = AuthManager.load_api_key(cls)
if api_key is None:
raise MissingAuthError('Add a "api_key"')
# Convert messages to Cohere format
system_message = None
chat_history = []
user_message = None
# Filter out system messages first
system_messages = [msg for msg in messages if msg.get("role") == "system"]
if system_messages:
system_message = "\n".join([msg.get("content", "") for msg in system_messages])
# Process conversation messages (non-system)
conversation_messages = [msg for msg in messages if msg.get("role") != "system"]
# The last message should be from user
if conversation_messages and conversation_messages[-1].get("role") == "user":
user_message = conversation_messages[-1].get("content", "")
# All previous messages become chat history
for msg in conversation_messages[:-1]:
role = msg.get("role")
content = msg.get("content", "")
if role == "user":
chat_history.append({"role": "USER", "message": content})
elif role == "assistant":
chat_history.append({"role": "CHATBOT", "message": content})
else:
raise ValueError("The last message must be from the user")
async with StreamSession(
proxy=proxy,
headers=cls.get_headers(stream, api_key, headers),
@ -95,19 +61,16 @@ class Cohere(AsyncGeneratorProvider, ProviderModelMixin):
impersonate=impersonate,
) as session:
data = filter_none(
message=user_message,
messages=messages,
model=cls.get_model(model, api_key=api_key),
temperature=temperature,
max_tokens=max_tokens,
k=top_k,
p=top_p,
stop_sequences=stop,
preamble=system_message,
chat_history=chat_history if chat_history else None,
stream=stream,
)
async with session.post(f"{cls.api_base}/chat", json=data) as response:
async with session.post(cls.api_endpoint, json=data) as response:
await raise_for_status(response)
if not stream:
@ -120,40 +83,35 @@ class Cohere(AsyncGeneratorProvider, ProviderModelMixin):
yield FinishReason("stop")
elif data["finish_reason"] == "MAX_TOKENS":
yield FinishReason("length")
if "meta" in data and "tokens" in data["meta"]:
if "usage" in data:
tokens = data.get("usage", {}).get("tokens", {})
yield Usage(
prompt_tokens=data["meta"]["tokens"]["input_tokens"],
completion_tokens=data["meta"]["tokens"]["output_tokens"],
total_tokens=data["meta"]["tokens"]["input_tokens"] + data["meta"]["tokens"]["output_tokens"]
prompt_tokens=tokens.get("input_tokens"),
completion_tokens=tokens.get("output_tokens"),
total_tokens=tokens.get("input_tokens", 0) + tokens.get("output_tokens", 0),
billed_units=data.get("usage", {}).get("billed_units")
)
else:
async for line in response.iter_lines():
if line.startswith(b"data: "):
chunk = line[6:]
if chunk == b"[DONE]":
break
try:
data = json.loads(chunk)
cls.raise_error(data)
if "event_type" in data:
if data["event_type"] == "text-generation":
if "text" in data:
yield data["text"]
elif data["event_type"] == "stream-end":
if "finish_reason" in data:
if data["finish_reason"] == "COMPLETE":
yield FinishReason("stop")
elif data["finish_reason"] == "MAX_TOKENS":
yield FinishReason("length")
if "meta" in data and "tokens" in data["meta"]:
yield Usage(
prompt_tokens=data["meta"]["tokens"]["input_tokens"],
completion_tokens=data["meta"]["tokens"]["output_tokens"],
total_tokens=data["meta"]["tokens"]["input_tokens"] + data["meta"]["tokens"]["output_tokens"]
)
except json.JSONDecodeError:
continue
async for data in sse_stream(response):
cls.raise_error(data)
if "type" in data:
if data["type"] == "content-delta":
yield data.get("delta", {}).get("message", {}).get("content", {}).get("text")
elif data["type"] == "message-end":
delta = data.get("delta", {})
if "finish_reason" in delta:
if delta["finish_reason"] == "COMPLETE":
yield FinishReason("stop")
elif delta["finish_reason"] == "MAX_TOKENS":
yield FinishReason("length")
if "usage" in delta:
tokens = delta.get("usage", {}).get("tokens", {})
yield Usage(
prompt_tokens=tokens.get("input_tokens"),
completion_tokens=tokens.get("output_tokens"),
total_tokens=tokens.get("input_tokens", 0) + tokens.get("output_tokens", 0),
billed_units=delta.get("usage", {}).get("billed_units")
)
@classmethod
def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict: