fix: use ref_index and ref_type to match sources. add image, video, forecast references support

This commit is contained in:
GravityTwoG 2025-07-04 20:34:55 +07:00
parent 2fc26939db
commit 8d7a31a32c

View file

@ -8,7 +8,7 @@ import json
import base64
import time
import random
from typing import AsyncIterator, Iterator, Optional, Generator, Dict, Union
from typing import AsyncIterator, Iterator, Optional, Generator, Dict, Union, List, Any
from copy import copy
try:
@ -24,8 +24,8 @@ from ...requests import StreamSession
from ...requests import get_nodriver
from ...image import ImageRequest, to_image, to_bytes, is_accepted_format
from ...errors import MissingAuthError, NoValidHarFileError, ModelNotFoundError
from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult, ImageResponse, ImagePreview
from ...providers.response import Sources, TitleGeneration, RequestLogin, Reasoning
from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult, ImageResponse, ImagePreview, ResponseType, format_link
from ...providers.response import TitleGeneration, RequestLogin, Reasoning
from ...tools.media import merge_media
from ..helper import format_cookies, format_media_prompt, to_string
from ..openai.models import default_model, default_image_model, models, image_models, text_models, model_aliases
@ -370,7 +370,8 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
if cls._api_key is None:
auto_continue = False
conversation.finish_reason = None
sources = Sources([])
sources = OpenAISources([])
references = ContentReferences()
while conversation.finish_reason is None:
async with session.post(
f"{cls.url}/backend-anon/sentinel/chat-requirements"
@ -475,29 +476,99 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
generated_image = await cls.get_generated_image(session, auth_result, match.group(0), prompt)
if generated_image is not None:
yield generated_image
async for chunk in cls.iter_messages_line(session, auth_result, line, conversation, sources):
async for chunk in cls.iter_messages_line(session, auth_result, line, conversation, sources, references):
if isinstance(chunk, str):
chunk = chunk.replace("\ue203", "").replace("\ue204", "").replace("\ue206", "")
buffer += chunk
if buffer.find(u"\ue200") != -1:
if buffer.find(u"\ue201") != -1:
buffer = buffer.replace("\ue200", "").replace("\ue202", "\n").replace("\ue201", "")
buffer = buffer.replace("navlist\n", "#### ")
def replacer(match):
link = None
if len(sources.list) > int(match.group(1)):
link = sources.list[int(match.group(1))]["url"]
return f"[[{int(match.group(1))+1}]]({link})"
return f" [{int(match.group(1))+1}]"
buffer = re.sub(r'(?:cite\nturn[0-9]+|turn[0-9]+)(?:search|news|view)(\d+)', replacer, buffer)
def sequence_replacer(match):
def citation_replacer(match: re.Match[str]):
ref_type = match.group(1)
ref_index = int(match.group(2))
if ((ref_type == "image" and is_image_embedding) or
is_video_embedding or
ref_type == "forecast"):
reference = references.get_reference({
"ref_index": ref_index,
"ref_type": ref_type
})
if not reference:
return ""
if ref_type == "forecast":
if reference.get("alt"):
return reference.get("alt")
if reference.get("prompt_text"):
return reference.get("prompt_text")
if is_image_embedding and reference.get("content_url", ""):
return f"![{reference.get('title', '')}]({reference.get('content_url')})"
if is_video_embedding:
if reference.get("url", "") and reference.get("thumbnail_url", ""):
return f"[![{reference.get('title', '')}]({reference['thumbnail_url']})]({reference['url']})"
video_match = re.match(r"video\n(.*?)\nturn[0-9]+", match.group(0))
if video_match:
return video_match.group(1)
return ""
source_index = sources.get_index({
"ref_index": ref_index,
"ref_type": ref_type
})
if source_index is not None and len(sources.list) > source_index:
link = sources.list[source_index]["url"]
return f"[[{source_index+1}]]({link})"
return f""
def products_replacer(match: re.Match[str]):
try:
products_data = json.loads(match.group(1))
products_str = ""
for idx, _ in enumerate(products_data.get("selections", []) or []):
name = products_data.get('selections', [])[idx][1]
tags = products_data.get('tags', [])[idx]
products_str += f"{name} - {tags}\n\n"
return products_str
except:
return ""
sequence_content = match.group(1)
sequence_content = sequence_content.replace("\ue200", "").replace("\ue202", "\n").replace("\ue201", "")
sequence_content = sequence_content.replace("navlist\n", "#### ")
# Handle search, news, view and image citations
is_image_embedding = sequence_content.startswith("i\nturn")
is_video_embedding = sequence_content.startswith("video\n")
sequence_content = re.sub(
r'(?:cite\nturn[0-9]+|forecast\nturn[0-9]+|video\n.*?\nturn[0-9]+|i?\n?turn[0-9]+)(search|news|view|image|forecast)(\d+)',
citation_replacer,
sequence_content
)
sequence_content = re.sub(r'products\n(.*)', products_replacer, sequence_content)
sequence_content = re.sub(r'product_entity\n\[".*","(.*)"\]', lambda x: x.group(1), sequence_content)
return sequence_content
# process only completed sequences and do not touch start of next not completed sequence
buffer = re.sub(r'\ue200(.*?)\ue201', sequence_replacer, buffer, flags=re.DOTALL)
if buffer.find(u"\ue200") != -1: # still have uncompleted sequence
continue
else:
# do not yield to consume rest part of special sequence
continue
yield buffer
buffer = ""
else:
yield chunk
if conversation.finish_reason is not None:
break
if buffer:
yield buffer
if sources.list:
yield sources
if conversation.generated_images:
@ -521,7 +592,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
yield FinishReason(conversation.finish_reason)
@classmethod
async def iter_messages_line(cls, session: StreamSession, auth_result: AuthResult, line: bytes, fields: Conversation, sources: Sources) -> AsyncIterator:
async def iter_messages_line(cls, session: StreamSession, auth_result: AuthResult, line: bytes, fields: Conversation, sources: OpenAISources, references: ContentReferences) -> AsyncIterator:
if not line.startswith(b"data: "):
return
elif line.startswith(b"data: [DONE]"):
@ -553,31 +624,84 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
elif "p" not in line or line.get("p") == "/message/content/parts/0":
yield Reasoning(token=v) if fields.is_thinking else v
elif isinstance(v, list):
buffer = ""
for m in v:
if m.get("p") == "/message/content/parts/0" and fields.recipient == "all":
yield m.get("v")
buffer += m.get("v")
elif m.get("p") == "/message/metadata/image_gen_title":
fields.prompt = m.get("v")
elif m.get("p") == "/message/content/parts/0/asset_pointer":
generated_images = fields.generated_images = await cls.get_generated_image(session, auth_result, m.get("v"), fields.prompt, fields.conversation_id)
if generated_images is not None:
if buffer:
yield buffer
yield generated_images
elif m.get("p") == "/message/metadata/search_result_groups":
for entry in [p.get("entries") for p in m.get("v")]:
for link in entry:
sources.add_source(link)
elif m.get("p") == "/message/metadata/content_references":
elif m.get("p") == "/message/metadata/content_references" and not isinstance(m.get("v"), int):
for entry in m.get("v"):
for link in entry.get("sources", []):
sources.add_source(link)
for link in entry.get("items", []):
sources.add_source(link)
for link in entry.get("fallback_items", []) or []:
sources.add_source(link)
if m.get("o", None) == "append":
references.add_reference(entry)
elif m.get("p") and re.match(r"^/message/metadata/content_references/\d+$", m.get("p")):
sources.add_source(m.get("v"))
if "url" in m.get("v") or "link" in m.get("v"):
sources.add_source(m.get("v"))
for link in m.get("v").get("fallback_items", []) or []:
sources.add_source(link)
match = re.match(r"^/message/metadata/content_references/(\d+)$", m.get("p"))
if match and m.get("o") == "append" and isinstance(m.get("v"), dict):
idx = int(match.group(1))
references.merge_reference(idx, m.get("v"))
elif m.get("p") and re.match(r"^/message/metadata/content_references/\d+/fallback_items$", m.get("p")) and isinstance(m.get("v"), list):
for link in m.get("v", []) or []:
sources.add_source(link)
elif m.get("p") and re.match(r"^/message/metadata/content_references/\d+/items$", m.get("p")) and isinstance(m.get("v"), list):
for link in m.get("v", []) or []:
sources.add_source(link)
elif m.get("p") and re.match(r"^/message/metadata/content_references/\d+/refs$", m.get("p")) and isinstance(m.get("v"), list):
match = re.match(r"^/message/metadata/content_references/(\d+)/refs$", m.get("p"))
if match:
idx = int(match.group(1))
references.update_reference(idx, m.get("o"), "refs", m.get("v"))
elif m.get("p") and re.match(r"^/message/metadata/content_references/\d+/alt$", m.get("p")) and isinstance(m.get("v"), list):
match = re.match(r"^/message/metadata/content_references/(\d+)/alt$", m.get("p"))
if match:
idx = int(match.group(1))
references.update_reference(idx, m.get("o"), "alt", m.get("v"))
elif m.get("p") and re.match(r"^/message/metadata/content_references/\d+/prompt_text$", m.get("p")) and isinstance(m.get("v"), list):
match = re.match(r"^/message/metadata/content_references/(\d+)/prompt_text$", m.get("p"))
if match:
idx = int(match.group(1))
references.update_reference(idx, m.get("o"), "prompt_text", m.get("v"))
elif m.get("p") and re.match(r"^/message/metadata/content_references/\d+/refs/\d+$", m.get("p")) and isinstance(m.get("v"), dict):
match = re.match(r"^/message/metadata/content_references/(\d+)/refs/(\d+)$", m.get("p"))
if match:
reference_idx = int(match.group(1))
ref_idx = int(match.group(2))
references.update_reference(reference_idx, m.get("o"), "refs", m.get("v"), ref_idx)
elif m.get("p") and re.match(r"^/message/metadata/content_references/\d+/images$", m.get("p")) and isinstance(m.get("v"), list):
match = re.match(r"^/message/metadata/content_references/(\d+)/images$", m.get("p"))
if match:
idx = int(match.group(1))
references.update_reference(idx, m.get("o"), "images", m.get("v"))
elif m.get("p") == "/message/metadata/finished_text":
fields.is_thinking = False
if buffer:
yield buffer
yield Reasoning(status=m.get("v"))
elif m.get("p") == "/message/metadata" and fields.recipient == "all":
fields.finish_reason = m.get("v", {}).get("finish_details", {}).get("type")
break
yield buffer
elif isinstance(v, dict):
if fields.conversation_id is None:
fields.conversation_id = v.get("conversation_id")
@ -794,3 +918,166 @@ def get_cookies(
}
json = yield cmd_dict
return {c["name"]: c["value"] for c in json['cookies']} if 'cookies' in json else {}
class OpenAISources(ResponseType):
    """Ordered, de-duplicated collection of sources cited in a response.

    Every stored source is a dict that carries at least a ``url`` key.
    A source is identified first by its citation reference (``ref_index``
    plus ``ref_type``) and, failing that, by its URL.  Adding a source that
    matches an existing one replaces the existing entry in place, so the
    entry keeps its position in ``list``.
    """

    # The attribute is named ``list`` because callers read ``sources.list``.
    list: List[Dict[str, str]]

    def __init__(self, sources: List[Dict[str, str]]) -> None:
        """Initialize with a list of source dictionaries."""
        self.list = []
        for source in sources:
            self.add_source(source)

    def add_source(self, source: Union[Dict[str, str], str]) -> None:
        """Add a source to the list, cleaning the URL if necessary.

        ``source`` may be a dict (its URL is read from ``url``, falling back
        to ``link`` when ``url`` is absent) or a bare URL string.  Sources
        without a usable string URL are ignored.  Note that a dict argument
        is updated in place: its ``url`` key is set to the cleaned URL.
        """
        if not isinstance(source, dict):
            source = {"url": source}
        url = source.get("url", source.get("link", None))
        # A non-string URL would make re.sub raise; treat it like a missing one.
        if not url or not isinstance(url, str):
            return
        # Drop the utm_source tracking parameter and everything after it.
        source["url"] = re.sub(r"[&?]utm_source=.+", "", url)

        # Prefer matching on the citation reference so that an updated
        # payload for the same citation overwrites the earlier entry.
        ref_info = self.get_ref_info(source)
        if ref_info:
            _, idx = self.find_by_ref_info(ref_info)
            if idx is not None:
                self.list[idx] = source
                return

        # Otherwise fall back to de-duplicating by URL.
        _, idx = self.find_by_url(source["url"])
        if idx is not None:
            self.list[idx] = source
            return

        self.list.append(source)

    def __str__(self) -> str:
        """Return formatted sources as a string."""
        if not self.list:
            return ""
        return "\n\n\n\n" + ("\n>\n".join([
            f"> [{idx+1}] {format_link(link['url'], link.get('title', ''))}"
            for idx, link in enumerate(self.list)
        ]))

    def get_ref_info(self, source: Dict[str, Any]) -> Optional[Dict[str, Union[str, int]]]:
        """Return the first citation reference found on ``source``.

        ``source["ref_id"]`` is checked first, then each entry of
        ``source["refs"]``.  A candidate qualifies when it is a dict whose
        ``ref_index`` is an int.  Returns a dict with ``ref_index`` and
        ``ref_type`` keys, or ``None`` when no candidate qualifies.
        """
        refs = source.get("refs")
        if not isinstance(refs, list):
            refs = []
        # ref_id may be missing, None or malformed; the isinstance check
        # below skips anything that is not a dict instead of crashing.
        for candidate in [source.get("ref_id"), *refs]:
            if not isinstance(candidate, dict):
                continue
            ref_index = candidate.get("ref_index", None)
            if isinstance(ref_index, int):
                return {
                    "ref_index": ref_index,
                    "ref_type": candidate.get("ref_type", None),
                }
        return None

    def find_by_ref_info(self, ref_info: Dict[str, Union[str, int]]):
        """Return ``(source, index)`` of the first stored source whose
        citation reference equals ``ref_info``, or ``(None, None)``."""
        for idx, source in enumerate(self.list):
            source_ref_info = self.get_ref_info(source)
            if (source_ref_info and
                    source_ref_info["ref_index"] == ref_info["ref_index"] and
                    source_ref_info["ref_type"] == ref_info["ref_type"]):
                return source, idx
        return None, None

    def find_by_url(self, url: str):
        """Return ``(source, index)`` of the first stored source with the
        given URL, or ``(None, None)``."""
        for idx, source in enumerate(self.list):
            if source["url"] == url:
                return source, idx
        return None, None

    def get_index(self, ref_info: Dict[str, Union[str, int]]) -> Optional[int]:
        """Return the position in ``list`` of the source matching
        ``ref_info``, or ``None`` when there is no such source."""
        _, index = self.find_by_ref_info(ref_info)
        return index
class ContentReferences:
    """Accumulates content-reference entries that arrive as incremental patches.

    ``list`` holds one dict per reference.  Entries can be appended whole,
    merged key-by-key at a given position, or have a single field patched.
    ``get_reference`` then looks an entry up by the ``ref_index`` /
    ``ref_type`` pair stored in the entry's ``refs`` list.
    """

    def __init__(self) -> None:
        self.list: List[Dict[str, Any]] = []

    def _ensure_slot(self, idx: int) -> None:
        """Pad ``list`` with empty dicts until position ``idx`` exists."""
        while len(self.list) <= idx:
            self.list.append({})

    def add_reference(self, reference_part: dict) -> None:
        """Append ``reference_part`` as a new entry."""
        self.list.append(reference_part)

    def merge_reference(self, idx: int, reference_part: dict):
        """Merge ``reference_part`` into the entry at ``idx``.

        Missing positions up to ``idx`` are created as empty dicts.  Keys in
        ``reference_part`` override keys already present on the entry.
        """
        self._ensure_slot(idx)
        self.list[idx] = {**self.list[idx], **reference_part}

    def update_reference(self, idx: int, operation: str, field: str, value: Any, ref_idx=None) -> None:
        """Apply a patch to ``field`` of the entry at ``idx``.

        * ``"append"`` / ``"add"``: ``field`` is treated as a list (it is
          reset to an empty list if it is not one); a list ``value`` extends
          it, any other ``value`` is appended.
        * ``"replace"`` with ``ref_idx`` given: the element at ``ref_idx`` of
          the list stored in ``field`` is replaced, or ``value`` is appended
          when the list is too short.  If ``field`` does not hold a list it
          is set to ``value``.

        Any other combination leaves the entry untouched.
        NOTE(review): a ``"replace"`` without ``ref_idx`` is ignored, as in
        the original code - confirm that whole-field replacement is not
        expected here.
        """
        self._ensure_slot(idx)
        entry = self.list[idx]

        if operation in ("append", "add"):
            if not isinstance(entry.get(field, None), list):
                entry[field] = []
            if isinstance(value, list):
                entry[field].extend(value)
            else:
                entry[field].append(value)

        if operation == "replace" and ref_idx is not None:
            if field == "refs" and not isinstance(entry.get(field, None), list):
                entry[field] = []
            # .get() so that a field which is not there yet falls through to
            # the plain assignment below instead of raising KeyError.
            current = entry.get(field, None)
            if isinstance(current, list):
                if len(current) <= ref_idx:
                    current.append(value)
                else:
                    current[ref_idx] = value
            else:
                entry[field] = value

    def get_ref_info(
        self,
        source: Dict[str, Any],
        target_ref_info: Dict[str, Union[str, int]]
    ) -> Optional[Dict[str, Union[str, int]]]:
        """Find a ref on ``source`` matching ``target_ref_info``.

        Scans ``source["refs"]`` for a dict whose ``ref_index`` is an int and
        whose ``ref_type`` is a str.  When ``target_ref_info`` is empty the
        first such ref is returned; otherwise both values must equal the
        target's.  The result carries ``ref_index``, ``ref_type`` and ``idx``
        (the position of the ref inside ``refs``).  Returns ``None`` when
        nothing matches.
        """
        # ``or []`` must guard the value itself: refs may be present but None.
        for idx, ref_info in enumerate(source.get("refs") or []):
            if not isinstance(ref_info, dict):
                continue
            ref_index = ref_info.get("ref_index", None)
            ref_type = ref_info.get("ref_type", None)
            if not (isinstance(ref_index, int) and isinstance(ref_type, str)):
                continue
            if (not target_ref_info or
                    (target_ref_info["ref_index"] == ref_index and
                     target_ref_info["ref_type"] == ref_type)):
                return {
                    "ref_index": ref_index,
                    "ref_type": ref_type,
                    "idx": idx,
                }
        return None

    def get_reference(self, ref_info: Dict[str, Union[str, int]]) -> Any:
        """Return the stored reference addressed by ``ref_info``.

        For every ``ref_type`` except ``"image"`` the whole matching entry is
        returned.  For ``"image"`` the element of the entry's ``images`` list
        at the same position as the matching ref is returned; if that element
        does not exist the search continues with the next entry.  Returns
        ``None`` when nothing is found.
        """
        for reference in self.list:
            matched = self.get_ref_info(reference, ref_info)
            if (not matched or
                    matched["ref_index"] != ref_info["ref_index"] or
                    matched["ref_type"] != ref_info["ref_type"]):
                continue

            if ref_info["ref_type"] != "image":
                return reference

            images = reference.get("images", [])
            if isinstance(images, list) and len(images) > matched["idx"]:
                return images[matched["idx"]]

        return None