mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
Enhance Perplexity provider with additional models and improved conversation handling; add JsonRequest and JsonResponse classes for better response management
This commit is contained in:
parent
877d5cce42
commit
ddadc36fa8
6 changed files with 179 additions and 138 deletions
|
|
@ -6,7 +6,7 @@ import uuid
|
||||||
from ..typing import AsyncResult, Messages, Cookies
|
from ..typing import AsyncResult, Messages, Cookies
|
||||||
from ..requests import StreamSession, raise_for_status, sse_stream
|
from ..requests import StreamSession, raise_for_status, sse_stream
|
||||||
from ..cookies import get_cookies
|
from ..cookies import get_cookies
|
||||||
from ..providers.response import ProviderInfo
|
from ..providers.response import ProviderInfo, JsonConversation, JsonRequest, JsonResponse, Reasoning
|
||||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||||
from .. import debug
|
from .. import debug
|
||||||
|
|
||||||
|
|
@ -21,21 +21,63 @@ class Perplexity(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
models = [
|
models = [
|
||||||
default_model,
|
default_model,
|
||||||
"turbo",
|
"turbo",
|
||||||
|
"gpt41",
|
||||||
|
"gpt5",
|
||||||
|
"gpt5_thinking",
|
||||||
|
"o3",
|
||||||
|
"o3pro",
|
||||||
|
"claude2",
|
||||||
|
"claude37sonnetthinking",
|
||||||
|
"claude40opus",
|
||||||
|
"claude40opusthinking",
|
||||||
|
"claude41opusthinking",
|
||||||
|
"claude45sonnet",
|
||||||
|
"claude45sonnetthinking",
|
||||||
|
"experimental",
|
||||||
|
"grok",
|
||||||
|
"grok4",
|
||||||
|
"gemini2flash",
|
||||||
"pplx_pro",
|
"pplx_pro",
|
||||||
"gpt-5",
|
"pplx_pro_upgraded",
|
||||||
|
"pplx_alpha",
|
||||||
|
"pplx_beta",
|
||||||
|
"comet_max_assistant",
|
||||||
|
"o3_research",
|
||||||
|
"o3pro_research",
|
||||||
|
"claude40sonnet_research",
|
||||||
|
"claude40sonnetthinking_research",
|
||||||
|
"claude40opus_research",
|
||||||
|
"claude40opusthinking_research",
|
||||||
|
"o3_labs",
|
||||||
|
"o3pro_labs",
|
||||||
|
"claude40sonnetthinking_labs",
|
||||||
|
"claude40opusthinking_labs",
|
||||||
|
"o4mini",
|
||||||
|
"o1",
|
||||||
|
"gpt4o",
|
||||||
|
"gpt45",
|
||||||
|
"gpt4",
|
||||||
|
"o3mini",
|
||||||
|
"claude35haiku",
|
||||||
|
"llama_x_large",
|
||||||
|
"mistral",
|
||||||
|
"claude3opus",
|
||||||
|
"gemini",
|
||||||
|
"pplx_reasoning",
|
||||||
|
"r1"
|
||||||
]
|
]
|
||||||
model_aliases = {
|
model_aliases = {
|
||||||
"gpt-5": "gpt5",
|
"gpt-5": "gpt5",
|
||||||
|
"gpt-5-thinking": "gpt5_thinking",
|
||||||
}
|
}
|
||||||
|
|
||||||
_user_id = None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create_async_generator(
|
async def create_async_generator(
|
||||||
cls,
|
cls,
|
||||||
model: str,
|
model: str,
|
||||||
messages: Messages,
|
messages: Messages,
|
||||||
cookies: Cookies = None,
|
cookies: Cookies = None,
|
||||||
|
conversation: JsonConversation = None,
|
||||||
proxy: str = None,
|
proxy: str = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> AsyncResult:
|
) -> AsyncResult:
|
||||||
|
|
@ -43,13 +85,13 @@ class Perplexity(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
model = cls.default_model
|
model = cls.default_model
|
||||||
if cookies is None:
|
if cookies is None:
|
||||||
cookies = get_cookies(cls.cookie_domain, False)
|
cookies = get_cookies(cls.cookie_domain, False)
|
||||||
else:
|
if conversation is None:
|
||||||
cls._user_id = None
|
conversation = JsonConversation(
|
||||||
|
frontend_uid=str(uuid.uuid4()),
|
||||||
# Generate UUIDs for request tracking
|
frontend_context_uuid=str(uuid.uuid4()),
|
||||||
frontend_uuid = str(uuid.uuid4())
|
visitor_id=str(uuid.uuid4()),
|
||||||
frontend_context_uuid = str(uuid.uuid4())
|
user_id=None,
|
||||||
visitor_id = str(uuid.uuid4())
|
)
|
||||||
request_id = str(uuid.uuid4())
|
request_id = str(uuid.uuid4())
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
|
|
@ -72,18 +114,19 @@ class Perplexity(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
break
|
break
|
||||||
|
|
||||||
async with StreamSession(headers=headers, cookies=cookies, proxy=proxy, impersonate="chrome") as session:
|
async with StreamSession(headers=headers, cookies=cookies, proxy=proxy, impersonate="chrome") as session:
|
||||||
if cls._user_id is None:
|
if conversation.user_id is None:
|
||||||
async with session.get(f"{cls.url}/api/auth/session") as response:
|
async with session.get(f"{cls.url}/api/auth/session") as response:
|
||||||
await raise_for_status(response)
|
await raise_for_status(response)
|
||||||
user = await response.json()
|
user = await response.json()
|
||||||
cls._user_id = user.get("user", {}).get("id")
|
conversation.user_id = user.get("user", {}).get("id")
|
||||||
debug.log(f"Perplexity user id: {cls._user_id}")
|
debug.log(f"Perplexity user id: {conversation.user_id}")
|
||||||
|
yield conversation
|
||||||
if model == "auto":
|
if model == "auto":
|
||||||
model = "pplx_pro" if cls._user_id else "turbo"
|
model = "pplx_pro" if conversation.user_id else "turbo"
|
||||||
yield ProviderInfo(**cls.get_dict(), model=model)
|
yield ProviderInfo(**cls.get_dict(), model=model)
|
||||||
if model in cls.model_aliases:
|
if model in cls.model_aliases:
|
||||||
model = cls.model_aliases[model]
|
model = cls.model_aliases[model]
|
||||||
if cls._user_id is None:
|
if conversation.user_id is None:
|
||||||
data = {
|
data = {
|
||||||
"params": {
|
"params": {
|
||||||
"attachments": [],
|
"attachments": [],
|
||||||
|
|
@ -92,13 +135,13 @@ class Perplexity(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
"search_focus": "internet",
|
"search_focus": "internet",
|
||||||
"sources": ["web"],
|
"sources": ["web"],
|
||||||
"search_recency_filter": None,
|
"search_recency_filter": None,
|
||||||
"frontend_uuid": frontend_uuid,
|
"frontend_uuid": conversation.frontend_uid,
|
||||||
"mode": "concise",
|
"mode": "concise",
|
||||||
"model_preference": model,
|
"model_preference": model,
|
||||||
"is_related_query": False,
|
"is_related_query": False,
|
||||||
"is_sponsored": False,
|
"is_sponsored": False,
|
||||||
"visitor_id": visitor_id,
|
"visitor_id": conversation.visitor_id,
|
||||||
"frontend_context_uuid": frontend_context_uuid,
|
"frontend_context_uuid": conversation.frontend_context_uuid,
|
||||||
"prompt_source": "user",
|
"prompt_source": "user",
|
||||||
"query_source": "home",
|
"query_source": "home",
|
||||||
"is_incognito": False,
|
"is_incognito": False,
|
||||||
|
|
@ -144,53 +187,92 @@ class Perplexity(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
data = {
|
data = {
|
||||||
"params": {
|
"params": {
|
||||||
"last_backend_uuid": None,
|
"last_backend_uuid": None,
|
||||||
"read_write_token": "457a2d3d-c53f-4065-8554-7645a36fc220",
|
"read_write_token": None,
|
||||||
"attachments": [],
|
"attachments": [],
|
||||||
"language": "en-US",
|
"language": "en-US",
|
||||||
"timezone": "America/New_York",
|
"timezone": "America/New_York",
|
||||||
"search_focus": "internet",
|
"search_focus": "internet",
|
||||||
"sources": ["web"],
|
"sources": [
|
||||||
"frontend_uuid": frontend_uuid,
|
"web"
|
||||||
|
],
|
||||||
|
"frontend_uuid": conversation.frontend_uid,
|
||||||
"mode": "copilot",
|
"mode": "copilot",
|
||||||
"model_preference": "gpt5",
|
"model_preference": model,
|
||||||
"is_related_query": False,
|
"is_related_query": False,
|
||||||
"is_sponsored": False,
|
"is_sponsored": False,
|
||||||
"visitor_id": visitor_id,
|
"visitor_id": conversation.visitor_id,
|
||||||
"user_nextauth_id": cls._user_id,
|
"user_nextauth_id": conversation.user_id,
|
||||||
"prompt_source": "user",
|
"prompt_source": "user",
|
||||||
"query_source":"followup",
|
"query_source": "followup",
|
||||||
"is_incognito": False,
|
"is_incognito": False,
|
||||||
"time_from_first_type": random.randint(0, 1000),
|
"time_from_first_type": random.randint(0, 1000),
|
||||||
"local_search_enabled": False,
|
"local_search_enabled": False,
|
||||||
"use_schematized_api": True,
|
"use_schematized_api": True,
|
||||||
"send_back_text_in_streaming_api": False,
|
"send_back_text_in_streaming_api": False,
|
||||||
"supported_block_use_cases": ["answer_modes", "media_items", "knowledge_cards", "inline_entity_cards", "place_widgets", "finance_widgets", "sports_widgets", "shopping_widgets", "jobs_widgets", "search_result_widgets", "clarification_responses", "inline_images", "inline_assets", "inline_finance_widgets", "placeholder_cards", "diff_blocks", "inline_knowledge_cards", "entity_group_v2", "refinement_filters", "canvas_mode"],
|
"supported_block_use_cases": [
|
||||||
|
"answer_modes",
|
||||||
|
"media_items",
|
||||||
|
"knowledge_cards",
|
||||||
|
"inline_entity_cards",
|
||||||
|
"place_widgets",
|
||||||
|
"finance_widgets",
|
||||||
|
"sports_widgets",
|
||||||
|
"shopping_widgets",
|
||||||
|
"jobs_widgets",
|
||||||
|
"search_result_widgets",
|
||||||
|
"clarification_responses",
|
||||||
|
"inline_images",
|
||||||
|
"inline_assets",
|
||||||
|
"inline_finance_widgets",
|
||||||
|
"placeholder_cards",
|
||||||
|
"diff_blocks",
|
||||||
|
"inline_knowledge_cards",
|
||||||
|
"entity_group_v2",
|
||||||
|
"refinement_filters",
|
||||||
|
"canvas_mode"
|
||||||
|
],
|
||||||
"client_coordinates": None,
|
"client_coordinates": None,
|
||||||
"mentions": [],
|
"mentions": [],
|
||||||
"skip_search_enabled": True,
|
"skip_search_enabled": True,
|
||||||
"is_nav_suggestions_disabled": False,
|
"is_nav_suggestions_disabled": False,
|
||||||
"followup_source": "link",
|
"followup_source": "link",
|
||||||
|
"always_search_override": False,
|
||||||
|
"override_no_search": False,
|
||||||
|
"comet_max_assistant_enabled": False,
|
||||||
"version": "2.18"
|
"version": "2.18"
|
||||||
},
|
},
|
||||||
"query_str": query
|
"query_str": query
|
||||||
}
|
}
|
||||||
|
yield JsonRequest.from_dict(data)
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{cls.url}/rest/sse/perplexity_ask",
|
f"{cls.url}/rest/sse/perplexity_ask",
|
||||||
json=data,
|
json=data,
|
||||||
) as response:
|
) as response:
|
||||||
await raise_for_status(response)
|
await raise_for_status(response)
|
||||||
full_response = ""
|
full_response = ""
|
||||||
last_response = ""
|
full_reasoning = ""
|
||||||
async for json_data in sse_stream(response):
|
async for json_data in sse_stream(response):
|
||||||
|
yield JsonResponse.from_dict(json_data)
|
||||||
for block in json_data.get("blocks", []):
|
for block in json_data.get("blocks", []):
|
||||||
for patch in block.get("diff_block", {}).get("patches", []):
|
for patch in block.get("diff_block", {}).get("patches", []):
|
||||||
|
if patch.get("path") == "/progress":
|
||||||
|
continue
|
||||||
value = patch.get("value", "")
|
value = patch.get("value", "")
|
||||||
|
if patch.get("path").startswith("/goals"):
|
||||||
|
if isinstance(value, str):
|
||||||
|
if value.startswith(full_reasoning):
|
||||||
|
value = value[len(full_reasoning):]
|
||||||
|
yield Reasoning(value)
|
||||||
|
full_reasoning += value
|
||||||
|
else:
|
||||||
|
yield Reasoning(status="")
|
||||||
|
continue
|
||||||
|
if block.get("diff_block").get("field") != "markdown_block":
|
||||||
|
continue
|
||||||
value = value.get("answer", "") if isinstance(value, dict) else value
|
value = value.get("answer", "") if isinstance(value, dict) else value
|
||||||
if value:
|
if value and isinstance(value, str):
|
||||||
if value.startswith(full_response):
|
if value.startswith(full_response):
|
||||||
value = value[len(full_response):]
|
value = value[len(full_response):]
|
||||||
if value.startswith(last_response):
|
if value:
|
||||||
value = value[len(last_response):]
|
|
||||||
last_response = value
|
|
||||||
full_response += value
|
full_response += value
|
||||||
yield value
|
yield value
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from ..errors import MissingAuthError
|
||||||
from ..requests.raise_for_status import raise_for_status
|
from ..requests.raise_for_status import raise_for_status
|
||||||
from ..requests.aiohttp import get_connector
|
from ..requests.aiohttp import get_connector
|
||||||
from ..image import use_aspect_ratio
|
from ..image import use_aspect_ratio
|
||||||
from ..providers.response import ImageResponse, Reasoning, TitleGeneration, SuggestedFollowups
|
from ..providers.response import ImageResponse, Reasoning, TitleGeneration, SuggestedFollowups, JsonRequest
|
||||||
from ..tools.media import render_messages
|
from ..tools.media import render_messages
|
||||||
from ..config import STATIC_URL
|
from ..config import STATIC_URL
|
||||||
from .template.OpenaiTemplate import read_response
|
from .template.OpenaiTemplate import read_response
|
||||||
|
|
@ -461,6 +461,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
headers = {"referer": referrer}
|
headers = {"referer": referrer}
|
||||||
if api_key:
|
if api_key:
|
||||||
headers["authorization"] = f"Bearer {api_key}"
|
headers["authorization"] = f"Bearer {api_key}"
|
||||||
|
yield JsonRequest.from_dict(data)
|
||||||
async with session.post(cls.openai_endpoint, json=data, headers=headers) as response:
|
async with session.post(cls.openai_endpoint, json=data, headers=headers) as response:
|
||||||
if response.status in (400, 500):
|
if response.status in (400, 500):
|
||||||
debug.error(f"Error: {response.status} - Bad Request: {data}")
|
debug.error(f"Error: {response.status} - Bad Request: {data}")
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from ...typing import Union, AsyncResult, Messages, MediaListType
|
||||||
from ...requests import StreamSession, StreamResponse, raise_for_status, sse_stream
|
from ...requests import StreamSession, StreamResponse, raise_for_status, sse_stream
|
||||||
from ...image import use_aspect_ratio
|
from ...image import use_aspect_ratio
|
||||||
from ...image.copy_images import save_response_media
|
from ...image.copy_images import save_response_media
|
||||||
from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse, ProviderInfo, AudioResponse, Reasoning, JsonConversation
|
from ...providers.response import *
|
||||||
from ...tools.media import render_messages
|
from ...tools.media import render_messages
|
||||||
from ...tools.run_tools import AuthManager
|
from ...tools.run_tools import AuthManager
|
||||||
from ...errors import MissingAuthError
|
from ...errors import MissingAuthError
|
||||||
|
|
@ -150,6 +150,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||||
api_endpoint = f"{api_base.rstrip('/')}/chat/completions"
|
api_endpoint = f"{api_base.rstrip('/')}/chat/completions"
|
||||||
if api_endpoint is None:
|
if api_endpoint is None:
|
||||||
api_endpoint = cls.api_endpoint
|
api_endpoint = cls.api_endpoint
|
||||||
|
yield JsonRequest.from_dict(data)
|
||||||
async with session.post(api_endpoint, json=data, ssl=cls.ssl) as response:
|
async with session.post(api_endpoint, json=data, ssl=cls.ssl) as response:
|
||||||
async for chunk in read_response(response, stream, prompt, cls.get_dict(), download_media):
|
async for chunk in read_response(response, stream, prompt, cls.get_dict(), download_media):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
@ -170,6 +171,7 @@ async def read_response(response: StreamResponse, stream: bool, prompt: str, pro
|
||||||
content_type = response.headers.get("content-type", "text/event-stream" if stream else "application/json")
|
content_type = response.headers.get("content-type", "text/event-stream" if stream else "application/json")
|
||||||
if content_type.startswith("application/json"):
|
if content_type.startswith("application/json"):
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
|
yield JsonResponse.from_dict(data)
|
||||||
OpenaiTemplate.raise_error(data, response.status)
|
OpenaiTemplate.raise_error(data, response.status)
|
||||||
await raise_for_status(response)
|
await raise_for_status(response)
|
||||||
model = data.get("model")
|
model = data.get("model")
|
||||||
|
|
@ -206,6 +208,7 @@ async def read_response(response: StreamResponse, stream: bool, prompt: str, pro
|
||||||
first = True
|
first = True
|
||||||
model_returned = False
|
model_returned = False
|
||||||
async for data in sse_stream(response):
|
async for data in sse_stream(response):
|
||||||
|
yield JsonResponse.from_dict(data)
|
||||||
OpenaiTemplate.raise_error(data)
|
OpenaiTemplate.raise_error(data)
|
||||||
model = data.get("model")
|
model = data.get("model")
|
||||||
if not model_returned and model:
|
if not model_returned and model:
|
||||||
|
|
|
||||||
|
|
@ -261,53 +261,6 @@ async def async_iter_response(
|
||||||
finally:
|
finally:
|
||||||
await safe_aclose(response)
|
await safe_aclose(response)
|
||||||
|
|
||||||
async def async_response(
|
|
||||||
response: AsyncIterator[Union[str, ResponseType]]
|
|
||||||
) -> ClientResponse:
|
|
||||||
content = ""
|
|
||||||
response_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
|
|
||||||
idx = 0
|
|
||||||
usage = None
|
|
||||||
provider: ProviderInfo = None
|
|
||||||
conversation: JsonConversation = None
|
|
||||||
|
|
||||||
async for chunk in response:
|
|
||||||
if isinstance(chunk, FinishReason):
|
|
||||||
continue
|
|
||||||
elif isinstance(chunk, JsonConversation):
|
|
||||||
conversation = chunk
|
|
||||||
continue
|
|
||||||
elif isinstance(chunk, ToolCalls):
|
|
||||||
continue
|
|
||||||
elif isinstance(chunk, Usage):
|
|
||||||
usage = chunk
|
|
||||||
continue
|
|
||||||
elif isinstance(chunk, ProviderInfo):
|
|
||||||
provider = chunk
|
|
||||||
continue
|
|
||||||
elif isinstance(chunk, HiddenResponse):
|
|
||||||
continue
|
|
||||||
elif isinstance(chunk, Exception):
|
|
||||||
continue
|
|
||||||
|
|
||||||
content = add_chunk(content, chunk)
|
|
||||||
if not content:
|
|
||||||
continue
|
|
||||||
idx += 1
|
|
||||||
|
|
||||||
if usage is None:
|
|
||||||
usage = UsageModel.model_construct(completion_tokens=idx, total_tokens=idx)
|
|
||||||
else:
|
|
||||||
usage = UsageModel.model_construct(**usage.get_dict())
|
|
||||||
|
|
||||||
response = ClientResponse.model_construct(
|
|
||||||
content, response_id, int(time.time()), usage=usage, conversation=conversation
|
|
||||||
)
|
|
||||||
if provider is not None:
|
|
||||||
response.provider = provider.name
|
|
||||||
response.model = provider.model
|
|
||||||
return response
|
|
||||||
|
|
||||||
async def async_iter_append_model_and_provider(
|
async def async_iter_append_model_and_provider(
|
||||||
response: AsyncChatCompletionResponseType,
|
response: AsyncChatCompletionResponseType,
|
||||||
last_model: str,
|
last_model: str,
|
||||||
|
|
@ -361,6 +314,7 @@ class Completions:
|
||||||
stop: Optional[Union[list[str], str]] = None,
|
stop: Optional[Union[list[str], str]] = None,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
ignore_stream: Optional[bool] = False,
|
ignore_stream: Optional[bool] = False,
|
||||||
|
raw: Optional[bool] = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> ChatCompletion:
|
) -> ChatCompletion:
|
||||||
if isinstance(messages, str):
|
if isinstance(messages, str):
|
||||||
|
|
@ -392,11 +346,20 @@ class Completions:
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if raw:
|
||||||
|
def filter_raw(response):
|
||||||
|
for chunk in response:
|
||||||
|
if isinstance(chunk, JsonResponse):
|
||||||
|
yield chunk
|
||||||
|
raw_response = filter_raw(response)
|
||||||
|
if stream:
|
||||||
|
return raw_response
|
||||||
|
return next(raw_response)
|
||||||
|
|
||||||
response = iter_response(response, stream, response_format, max_tokens, stop)
|
response = iter_response(response, stream, response_format, max_tokens, stop)
|
||||||
response = iter_append_model_and_provider(response, model, provider)
|
response = iter_append_model_and_provider(response, model, provider)
|
||||||
if stream:
|
if stream:
|
||||||
return response
|
return response
|
||||||
else:
|
|
||||||
return next(response)
|
return next(response)
|
||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
|
|
@ -655,7 +618,6 @@ class AsyncClient(BaseClient):
|
||||||
self.models: ClientModels = ClientModels(self, provider, media_provider)
|
self.models: ClientModels = ClientModels(self, provider, media_provider)
|
||||||
self.images: AsyncImages = AsyncImages(self, media_provider)
|
self.images: AsyncImages = AsyncImages(self, media_provider)
|
||||||
self.media: AsyncImages = self.images
|
self.media: AsyncImages = self.images
|
||||||
self.responses: AsyncResponses = AsyncResponses(self, provider)
|
|
||||||
|
|
||||||
class AsyncChat:
|
class AsyncChat:
|
||||||
completions: AsyncCompletions
|
completions: AsyncCompletions
|
||||||
|
|
@ -682,6 +644,7 @@ class AsyncCompletions:
|
||||||
stop: Optional[Union[list[str], str]] = None,
|
stop: Optional[Union[list[str], str]] = None,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
ignore_stream: Optional[bool] = False,
|
ignore_stream: Optional[bool] = False,
|
||||||
|
raw: Optional[bool] = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> Awaitable[ChatCompletion]:
|
) -> Awaitable[ChatCompletion]:
|
||||||
if isinstance(messages, str):
|
if isinstance(messages, str):
|
||||||
|
|
@ -713,12 +676,21 @@ class AsyncCompletions:
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if raw:
|
||||||
|
async def filter_raw(response):
|
||||||
|
async for chunk in response:
|
||||||
|
if isinstance(chunk, JsonResponse):
|
||||||
|
yield chunk
|
||||||
|
raw_response = filter_raw(response)
|
||||||
|
if stream:
|
||||||
|
return raw_response
|
||||||
|
return next(raw_response)
|
||||||
|
|
||||||
response = async_iter_response(response, stream, response_format, max_tokens, stop)
|
response = async_iter_response(response, stream, response_format, max_tokens, stop)
|
||||||
response = async_iter_append_model_and_provider(response, model, provider)
|
response = async_iter_append_model_and_provider(response, model, provider)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return response
|
return response
|
||||||
else:
|
|
||||||
return anext(response)
|
return anext(response)
|
||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
|
|
@ -755,51 +727,3 @@ class AsyncImages(Images):
|
||||||
return await self.async_create_variation(
|
return await self.async_create_variation(
|
||||||
image=image, model=model, provider=provider, response_format=response_format, **kwargs
|
image=image, model=model, provider=provider, response_format=response_format, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
class AsyncResponses():
|
|
||||||
def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
|
|
||||||
self.client: AsyncClient = client
|
|
||||||
self.provider: ProviderType = provider
|
|
||||||
|
|
||||||
async def create(
|
|
||||||
self,
|
|
||||||
input: str,
|
|
||||||
model: str = "",
|
|
||||||
provider: Optional[ProviderType] = None,
|
|
||||||
instructions: Optional[str] = None,
|
|
||||||
proxy: Optional[str] = None,
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
**kwargs
|
|
||||||
) -> ClientResponse:
|
|
||||||
if isinstance(input, str):
|
|
||||||
input = [{"role": "user", "content": input}]
|
|
||||||
if instructions is not None:
|
|
||||||
input = [{"role": "developer", "content": instructions}] + input
|
|
||||||
for idx, message in enumerate(input):
|
|
||||||
if isinstance(message["content"], list):
|
|
||||||
for key, value in enumerate(message["content"]):
|
|
||||||
if isinstance(value, dict) and value.get("type") == "input_text":
|
|
||||||
message["content"][key] = {"type": "text", "text": value.get("text")}
|
|
||||||
input[idx] = {"role": message["role"], "content": message["content"]}
|
|
||||||
resolve_media(kwargs)
|
|
||||||
if hasattr(model, "name"):
|
|
||||||
model = model.get_long_name()
|
|
||||||
if provider is None:
|
|
||||||
provider = self.provider
|
|
||||||
if provider is None:
|
|
||||||
provider = AnyProvider
|
|
||||||
if isinstance(provider, str):
|
|
||||||
provider = convert_to_provider(provider)
|
|
||||||
|
|
||||||
response = async_iter_run_tools(
|
|
||||||
provider,
|
|
||||||
model=model,
|
|
||||||
messages=input,
|
|
||||||
**filter_none(
|
|
||||||
proxy=self.client.proxy if proxy is None else proxy,
|
|
||||||
api_key=self.client.api_key if api_key is None else api_key
|
|
||||||
),
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
return await async_response(response)
|
|
||||||
|
|
|
||||||
|
|
@ -274,6 +274,10 @@ class Api:
|
||||||
yield self._format_json("continue", chunk.log)
|
yield self._format_json("continue", chunk.log)
|
||||||
elif isinstance(chunk, RawResponse):
|
elif isinstance(chunk, RawResponse):
|
||||||
yield self._format_json(chunk.type, **chunk.get_dict())
|
yield self._format_json(chunk.type, **chunk.get_dict())
|
||||||
|
elif isinstance(chunk, JsonRequest):
|
||||||
|
yield self._format_json("request", chunk.get_dict())
|
||||||
|
elif isinstance(chunk, JsonResponse):
|
||||||
|
yield self._format_json("response", chunk.get_dict())
|
||||||
else:
|
else:
|
||||||
yield self._format_json("content", str(chunk))
|
yield self._format_json("content", str(chunk))
|
||||||
except MissingAuthError as e:
|
except MissingAuthError as e:
|
||||||
|
|
|
||||||
|
|
@ -151,6 +151,33 @@ class JsonMixin:
|
||||||
class RawResponse(ResponseType, JsonMixin):
|
class RawResponse(ResponseType, JsonMixin):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class ObjectMixin:
|
||||||
|
def __init__(self, **kwargs) -> None:
|
||||||
|
"""Initialize with keyword arguments as attributes."""
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
setattr(self, key, ObjectMixin.from_dict(value) if isinstance(value, dict) else [ObjectMixin.from_dict(v) if isinstance(v, dict) else v for v in value] if isinstance(value, list) else value)
|
||||||
|
|
||||||
|
def get_dict(self) -> Dict:
|
||||||
|
"""Return a dictionary of non-private attributes."""
|
||||||
|
return {
|
||||||
|
key: value.get_dict() if isinstance(value, ObjectMixin) else [v.get_dict() if isinstance(v, ObjectMixin) else v for v in value] if isinstance(value, list) else value
|
||||||
|
for key, value in self.__dict__.items()
|
||||||
|
if not key.startswith("__")
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict) -> JsonMixin:
|
||||||
|
"""Create an instance from a dictionary."""
|
||||||
|
return cls(**data)
|
||||||
|
|
||||||
|
class JsonResponse(ResponseType, ObjectMixin):
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return str(self.get_dict())
|
||||||
|
|
||||||
|
class JsonRequest(ResponseType, ObjectMixin):
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return str(self.get_dict())
|
||||||
|
|
||||||
class HiddenResponse(ResponseType):
|
class HiddenResponse(ResponseType):
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
"""Hidden responses return an empty string."""
|
"""Hidden responses return an empty string."""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue