Enhance the Perplexity provider to yield additional response types (sources, media items, and suggested follow-ups), and update the formatting in the `Sources`, `ImageResponse`, and `PreviewResponse` response classes to handle the new data.

This commit is contained in:
hlohaus 2025-10-31 15:48:26 +01:00
parent 23218c4aa3
commit 35e3fa95f3
3 changed files with 26 additions and 7 deletions

View file

@ -6,7 +6,7 @@ import uuid
from ..typing import AsyncResult, Messages, Cookies from ..typing import AsyncResult, Messages, Cookies
from ..requests import StreamSession, raise_for_status, sse_stream from ..requests import StreamSession, raise_for_status, sse_stream
from ..cookies import get_cookies from ..cookies import get_cookies
from ..providers.response import ProviderInfo, JsonConversation, JsonRequest, JsonResponse, Reasoning from ..providers.response import ProviderInfo, JsonConversation, JsonRequest, JsonResponse, Reasoning, Sources, SuggestedFollowups, ImageResponse, PreviewResponse, YouTubeResponse
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .. import debug from .. import debug
@ -254,6 +254,19 @@ class Perplexity(AsyncGeneratorProvider, ProviderModelMixin):
async for json_data in sse_stream(response): async for json_data in sse_stream(response):
yield JsonResponse.from_dict(json_data) yield JsonResponse.from_dict(json_data)
for block in json_data.get("blocks", []): for block in json_data.get("blocks", []):
if block.get("intended_usage") == "sources_answer_mode":
yield Sources(block.get("sources_mode_block", {}).get("web_results", []))
continue
if block.get("intended_usage") == "media_items":
yield PreviewResponse([
ImageResponse(item.get("url"), item.get("name"), {
"height": item.get("image_height"),
"width": item.get("image_width"),
**item
}) if item.get("medium") == "image" else YouTubeResponse(item.get("url").split("=").pop())
for item in block.get("media_block", {}).get("media_items", [])
])
continue
for patch in block.get("diff_block", {}).get("patches", []): for patch in block.get("diff_block", {}).get("patches", []):
if patch.get("path") == "/progress": if patch.get("path") == "/progress":
continue continue
@ -278,3 +291,8 @@ class Perplexity(AsyncGeneratorProvider, ProviderModelMixin):
if value: if value:
full_response += value full_response += value
yield value yield value
if "related_query_items" in json_data:
followups = []
for item in json_data["related_query_items"]:
followups.append(item.get("text", ""))
yield SuggestedFollowups(followups)

View file

@ -25,6 +25,7 @@ from starlette.status import (
HTTP_404_NOT_FOUND, HTTP_404_NOT_FOUND,
HTTP_401_UNAUTHORIZED, HTTP_401_UNAUTHORIZED,
HTTP_403_FORBIDDEN, HTTP_403_FORBIDDEN,
HTTP_429_TOO_MANY_REQUESTS,
HTTP_500_INTERNAL_SERVER_ERROR, HTTP_500_INTERNAL_SERVER_ERROR,
) )
from starlette.staticfiles import NotModifiedResponse from starlette.staticfiles import NotModifiedResponse
@ -442,7 +443,7 @@ class Api:
current_most_wanted = next(iter(most_wanted.values()), 0) current_most_wanted = next(iter(most_wanted.values()), 0)
is_most_wanted = False is_most_wanted = False
if x_forwarded_for in most_wanted: if x_forwarded_for in most_wanted:
if failure_counts.get(x_forwarded_for, 0) > 0: if failure_counts.get(x_forwarded_for, 0) > 1:
failure_counts[x_forwarded_for] -= 1 failure_counts[x_forwarded_for] -= 1
most_wanted[x_forwarded_for] += 1 most_wanted[x_forwarded_for] += 1
elif most_wanted[x_forwarded_for] >= current_most_wanted: elif most_wanted[x_forwarded_for] >= current_most_wanted:
@ -457,7 +458,7 @@ class Api:
sorted_most_wanted = dict(sorted(most_wanted.items(), key=lambda item: item[1], reverse=True)) sorted_most_wanted = dict(sorted(most_wanted.items(), key=lambda item: item[1], reverse=True))
debug.log(f"Most wanted IPs: {sorted_most_wanted}") debug.log(f"Most wanted IPs: {sorted_most_wanted}")
if is_most_wanted: if is_most_wanted:
raise RateLimitError("You are most wanted! Please wait before making another request.") return ErrorResponse.from_message("You are most wanted! Please wait before making another request.", status_code=HTTP_429_TOO_MANY_REQUESTS)
if provider is not None and provider not in Provider.__map__: if provider is not None and provider not in Provider.__map__:
if provider in model_map: if provider in model_map:
config.model = provider config.model = provider

View file

@ -303,7 +303,7 @@ class Sources(ResponseType):
if not self.list: if not self.list:
return "" return ""
return "\n\n\n\n" + ("\n>\n".join([ return "\n\n\n\n" + ("\n>\n".join([
f"> [{idx}] {format_link(link['url'], link.get('title', None))}" f"> [{idx}] {format_link(link['url'], link.get('title', link.get('name', None)))}"
for idx, link in enumerate(self.list) for idx, link in enumerate(self.list)
])) ]))
@ -413,8 +413,8 @@ class ImageResponse(MediaResponse):
"""Return images as markdown.""" """Return images as markdown."""
if self.get("width") and self.get("height"): if self.get("width") and self.get("height"):
return "\n".join([ return "\n".join([
f'<a href="{html.escape(url)}" data-width="{self.get("width")}" data-height="{self.get("height")}" data-source="{html.escape(self.get("source_url", ""))}">' f'<a href="{html.escape(url)}" data-src="{self.get("image", url)}" data-width="{self.get("width")}" data-height="{self.get("height")}" data-source="{html.escape(self.get("source_url", ""))}">'
+ f'<img src="{url.replace("/media/", "/thumbnail/")}" alt="{html.escape(" ".join(self.alt.split()))}"></a>' + f'<img src="{self.get("thumbnail", url.replace("/media/", "/thumbnail/"))}" alt="{html.escape(self.alt)}" width="{html.escape(str(self.get("thumbnail_width", "")))}" height="{html.escape(str(self.get("thumbnail_height", "")))}"></a>'
for url in self.get_list() for url in self.get_list()
]) ])
return format_images_markdown(self.urls, self.alt, self.get("preview")) return format_images_markdown(self.urls, self.alt, self.get("preview"))
@ -442,7 +442,7 @@ class PreviewResponse(HiddenResponse):
def to_string(self) -> str: def to_string(self) -> str:
"""Return data as a string.""" """Return data as a string."""
return self.data return "".join([str(item) for item in self.data]) if isinstance(self.data, list) else str(self.data)
class Parameters(ResponseType, JsonMixin): class Parameters(ResponseType, JsonMixin):
def __str__(self) -> str: def __str__(self) -> str: