mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-15 14:51:19 -08:00
fix: use ref_index and ref_type to match sources; add support for image, video, and forecast references
This commit is contained in:
parent
2fc26939db
commit
8d7a31a32c
1 changed files with 305 additions and 18 deletions
|
|
@ -8,7 +8,7 @@ import json
|
|||
import base64
|
||||
import time
|
||||
import random
|
||||
from typing import AsyncIterator, Iterator, Optional, Generator, Dict, Union
|
||||
from typing import AsyncIterator, Iterator, Optional, Generator, Dict, Union, List, Any
|
||||
from copy import copy
|
||||
|
||||
try:
|
||||
|
|
@ -24,8 +24,8 @@ from ...requests import StreamSession
|
|||
from ...requests import get_nodriver
|
||||
from ...image import ImageRequest, to_image, to_bytes, is_accepted_format
|
||||
from ...errors import MissingAuthError, NoValidHarFileError, ModelNotFoundError
|
||||
from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult, ImageResponse, ImagePreview
|
||||
from ...providers.response import Sources, TitleGeneration, RequestLogin, Reasoning
|
||||
from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult, ImageResponse, ImagePreview, ResponseType, format_link
|
||||
from ...providers.response import TitleGeneration, RequestLogin, Reasoning
|
||||
from ...tools.media import merge_media
|
||||
from ..helper import format_cookies, format_media_prompt, to_string
|
||||
from ..openai.models import default_model, default_image_model, models, image_models, text_models, model_aliases
|
||||
|
|
@ -370,7 +370,8 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
if cls._api_key is None:
|
||||
auto_continue = False
|
||||
conversation.finish_reason = None
|
||||
sources = Sources([])
|
||||
sources = OpenAISources([])
|
||||
references = ContentReferences()
|
||||
while conversation.finish_reason is None:
|
||||
async with session.post(
|
||||
f"{cls.url}/backend-anon/sentinel/chat-requirements"
|
||||
|
|
@ -475,29 +476,99 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
generated_image = await cls.get_generated_image(session, auth_result, match.group(0), prompt)
|
||||
if generated_image is not None:
|
||||
yield generated_image
|
||||
async for chunk in cls.iter_messages_line(session, auth_result, line, conversation, sources):
|
||||
async for chunk in cls.iter_messages_line(session, auth_result, line, conversation, sources, references):
|
||||
if isinstance(chunk, str):
|
||||
chunk = chunk.replace("\ue203", "").replace("\ue204", "").replace("\ue206", "")
|
||||
buffer += chunk
|
||||
if buffer.find(u"\ue200") != -1:
|
||||
if buffer.find(u"\ue201") != -1:
|
||||
buffer = buffer.replace("\ue200", "").replace("\ue202", "\n").replace("\ue201", "")
|
||||
buffer = buffer.replace("navlist\n", "#### ")
|
||||
def replacer(match):
|
||||
link = None
|
||||
if len(sources.list) > int(match.group(1)):
|
||||
link = sources.list[int(match.group(1))]["url"]
|
||||
return f"[[{int(match.group(1))+1}]]({link})"
|
||||
return f" [{int(match.group(1))+1}]"
|
||||
buffer = re.sub(r'(?:cite\nturn[0-9]+|turn[0-9]+)(?:search|news|view)(\d+)', replacer, buffer)
|
||||
def sequence_replacer(match):
|
||||
def citation_replacer(match: re.Match[str]):
|
||||
ref_type = match.group(1)
|
||||
ref_index = int(match.group(2))
|
||||
if ((ref_type == "image" and is_image_embedding) or
|
||||
is_video_embedding or
|
||||
ref_type == "forecast"):
|
||||
|
||||
reference = references.get_reference({
|
||||
"ref_index": ref_index,
|
||||
"ref_type": ref_type
|
||||
})
|
||||
if not reference:
|
||||
return ""
|
||||
|
||||
if ref_type == "forecast":
|
||||
if reference.get("alt"):
|
||||
return reference.get("alt")
|
||||
if reference.get("prompt_text"):
|
||||
return reference.get("prompt_text")
|
||||
|
||||
if is_image_embedding and reference.get("content_url", ""):
|
||||
return f"})"
|
||||
|
||||
if is_video_embedding:
|
||||
if reference.get("url", "") and reference.get("thumbnail_url", ""):
|
||||
return f"[]({reference['url']})"
|
||||
video_match = re.match(r"video\n(.*?)\nturn[0-9]+", match.group(0))
|
||||
if video_match:
|
||||
return video_match.group(1)
|
||||
return ""
|
||||
|
||||
source_index = sources.get_index({
|
||||
"ref_index": ref_index,
|
||||
"ref_type": ref_type
|
||||
})
|
||||
if source_index is not None and len(sources.list) > source_index:
|
||||
link = sources.list[source_index]["url"]
|
||||
return f"[[{source_index+1}]]({link})"
|
||||
return f""
|
||||
|
||||
def products_replacer(match: re.Match[str]):
|
||||
try:
|
||||
products_data = json.loads(match.group(1))
|
||||
products_str = ""
|
||||
for idx, _ in enumerate(products_data.get("selections", []) or []):
|
||||
name = products_data.get('selections', [])[idx][1]
|
||||
tags = products_data.get('tags', [])[idx]
|
||||
products_str += f"{name} - {tags}\n\n"
|
||||
|
||||
return products_str
|
||||
except:
|
||||
return ""
|
||||
|
||||
sequence_content = match.group(1)
|
||||
sequence_content = sequence_content.replace("\ue200", "").replace("\ue202", "\n").replace("\ue201", "")
|
||||
sequence_content = sequence_content.replace("navlist\n", "#### ")
|
||||
|
||||
# Handle search, news, view and image citations
|
||||
is_image_embedding = sequence_content.startswith("i\nturn")
|
||||
is_video_embedding = sequence_content.startswith("video\n")
|
||||
sequence_content = re.sub(
|
||||
r'(?:cite\nturn[0-9]+|forecast\nturn[0-9]+|video\n.*?\nturn[0-9]+|i?\n?turn[0-9]+)(search|news|view|image|forecast)(\d+)',
|
||||
citation_replacer,
|
||||
sequence_content
|
||||
)
|
||||
sequence_content = re.sub(r'products\n(.*)', products_replacer, sequence_content)
|
||||
sequence_content = re.sub(r'product_entity\n\[".*","(.*)"\]', lambda x: x.group(1), sequence_content)
|
||||
return sequence_content
|
||||
|
||||
# process only completed sequences and do not touch start of next not completed sequence
|
||||
buffer = re.sub(r'\ue200(.*?)\ue201', sequence_replacer, buffer, flags=re.DOTALL)
|
||||
|
||||
if buffer.find(u"\ue200") != -1: # still have uncompleted sequence
|
||||
continue
|
||||
else:
|
||||
# do not yield to consume rest part of special sequence
|
||||
continue
|
||||
|
||||
yield buffer
|
||||
buffer = ""
|
||||
else:
|
||||
yield chunk
|
||||
if conversation.finish_reason is not None:
|
||||
break
|
||||
if buffer:
|
||||
yield buffer
|
||||
if sources.list:
|
||||
yield sources
|
||||
if conversation.generated_images:
|
||||
|
|
@ -521,7 +592,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
yield FinishReason(conversation.finish_reason)
|
||||
|
||||
@classmethod
|
||||
async def iter_messages_line(cls, session: StreamSession, auth_result: AuthResult, line: bytes, fields: Conversation, sources: Sources) -> AsyncIterator:
|
||||
async def iter_messages_line(cls, session: StreamSession, auth_result: AuthResult, line: bytes, fields: Conversation, sources: OpenAISources, references: ContentReferences) -> AsyncIterator:
|
||||
if not line.startswith(b"data: "):
|
||||
return
|
||||
elif line.startswith(b"data: [DONE]"):
|
||||
|
|
@ -553,31 +624,84 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
elif "p" not in line or line.get("p") == "/message/content/parts/0":
|
||||
yield Reasoning(token=v) if fields.is_thinking else v
|
||||
elif isinstance(v, list):
|
||||
buffer = ""
|
||||
for m in v:
|
||||
if m.get("p") == "/message/content/parts/0" and fields.recipient == "all":
|
||||
yield m.get("v")
|
||||
buffer += m.get("v")
|
||||
elif m.get("p") == "/message/metadata/image_gen_title":
|
||||
fields.prompt = m.get("v")
|
||||
elif m.get("p") == "/message/content/parts/0/asset_pointer":
|
||||
generated_images = fields.generated_images = await cls.get_generated_image(session, auth_result, m.get("v"), fields.prompt, fields.conversation_id)
|
||||
if generated_images is not None:
|
||||
if buffer:
|
||||
yield buffer
|
||||
yield generated_images
|
||||
elif m.get("p") == "/message/metadata/search_result_groups":
|
||||
for entry in [p.get("entries") for p in m.get("v")]:
|
||||
for link in entry:
|
||||
sources.add_source(link)
|
||||
elif m.get("p") == "/message/metadata/content_references":
|
||||
elif m.get("p") == "/message/metadata/content_references" and not isinstance(m.get("v"), int):
|
||||
for entry in m.get("v"):
|
||||
for link in entry.get("sources", []):
|
||||
sources.add_source(link)
|
||||
for link in entry.get("items", []):
|
||||
sources.add_source(link)
|
||||
for link in entry.get("fallback_items", []) or []:
|
||||
sources.add_source(link)
|
||||
if m.get("o", None) == "append":
|
||||
references.add_reference(entry)
|
||||
elif m.get("p") and re.match(r"^/message/metadata/content_references/\d+$", m.get("p")):
|
||||
sources.add_source(m.get("v"))
|
||||
if "url" in m.get("v") or "link" in m.get("v"):
|
||||
sources.add_source(m.get("v"))
|
||||
for link in m.get("v").get("fallback_items", []) or []:
|
||||
sources.add_source(link)
|
||||
|
||||
match = re.match(r"^/message/metadata/content_references/(\d+)$", m.get("p"))
|
||||
if match and m.get("o") == "append" and isinstance(m.get("v"), dict):
|
||||
idx = int(match.group(1))
|
||||
references.merge_reference(idx, m.get("v"))
|
||||
elif m.get("p") and re.match(r"^/message/metadata/content_references/\d+/fallback_items$", m.get("p")) and isinstance(m.get("v"), list):
|
||||
for link in m.get("v", []) or []:
|
||||
sources.add_source(link)
|
||||
elif m.get("p") and re.match(r"^/message/metadata/content_references/\d+/items$", m.get("p")) and isinstance(m.get("v"), list):
|
||||
for link in m.get("v", []) or []:
|
||||
sources.add_source(link)
|
||||
elif m.get("p") and re.match(r"^/message/metadata/content_references/\d+/refs$", m.get("p")) and isinstance(m.get("v"), list):
|
||||
match = re.match(r"^/message/metadata/content_references/(\d+)/refs$", m.get("p"))
|
||||
if match:
|
||||
idx = int(match.group(1))
|
||||
references.update_reference(idx, m.get("o"), "refs", m.get("v"))
|
||||
elif m.get("p") and re.match(r"^/message/metadata/content_references/\d+/alt$", m.get("p")) and isinstance(m.get("v"), list):
|
||||
match = re.match(r"^/message/metadata/content_references/(\d+)/alt$", m.get("p"))
|
||||
if match:
|
||||
idx = int(match.group(1))
|
||||
references.update_reference(idx, m.get("o"), "alt", m.get("v"))
|
||||
elif m.get("p") and re.match(r"^/message/metadata/content_references/\d+/prompt_text$", m.get("p")) and isinstance(m.get("v"), list):
|
||||
match = re.match(r"^/message/metadata/content_references/(\d+)/prompt_text$", m.get("p"))
|
||||
if match:
|
||||
idx = int(match.group(1))
|
||||
references.update_reference(idx, m.get("o"), "prompt_text", m.get("v"))
|
||||
elif m.get("p") and re.match(r"^/message/metadata/content_references/\d+/refs/\d+$", m.get("p")) and isinstance(m.get("v"), dict):
|
||||
match = re.match(r"^/message/metadata/content_references/(\d+)/refs/(\d+)$", m.get("p"))
|
||||
if match:
|
||||
reference_idx = int(match.group(1))
|
||||
ref_idx = int(match.group(2))
|
||||
references.update_reference(reference_idx, m.get("o"), "refs", m.get("v"), ref_idx)
|
||||
elif m.get("p") and re.match(r"^/message/metadata/content_references/\d+/images$", m.get("p")) and isinstance(m.get("v"), list):
|
||||
match = re.match(r"^/message/metadata/content_references/(\d+)/images$", m.get("p"))
|
||||
if match:
|
||||
idx = int(match.group(1))
|
||||
references.update_reference(idx, m.get("o"), "images", m.get("v"))
|
||||
elif m.get("p") == "/message/metadata/finished_text":
|
||||
fields.is_thinking = False
|
||||
if buffer:
|
||||
yield buffer
|
||||
yield Reasoning(status=m.get("v"))
|
||||
elif m.get("p") == "/message/metadata" and fields.recipient == "all":
|
||||
fields.finish_reason = m.get("v", {}).get("finish_details", {}).get("type")
|
||||
break
|
||||
|
||||
yield buffer
|
||||
elif isinstance(v, dict):
|
||||
if fields.conversation_id is None:
|
||||
fields.conversation_id = v.get("conversation_id")
|
||||
|
|
@ -794,3 +918,166 @@ def get_cookies(
|
|||
}
|
||||
json = yield cmd_dict
|
||||
return {c["name"]: c["value"] for c in json['cookies']} if 'cookies' in json else {}
|
||||
|
||||
class OpenAISources(ResponseType):
    """Ordered, de-duplicated collection of sources cited by a response.

    Every stored source is a dict that carries at least a ``"url"`` key.
    A newly added source replaces an existing one when both share the same
    reference identity (``ref_index`` + ``ref_type``) or, failing that, the
    same cleaned URL; otherwise it is appended.  The position of a source in
    ``list`` is what citation markers are numbered from.
    """

    # The attribute is named ``list`` because callers read ``sources.list``.
    list: List[Dict[str, str]]

    def __init__(self, sources: List[Dict[str, str]]) -> None:
        """Initialize with a list of source dictionaries."""
        self.list = []
        for source in sources:
            self.add_source(source)

    def add_source(self, source: Union[Dict[str, str], str]) -> None:
        """Add a source to the list, cleaning the URL if necessary.

        Args:
            source: Either a source dict (with a ``"url"`` or ``"link"`` key)
                or a bare URL string.  A dict is stored as-is and gets its
                ``"url"`` key overwritten with the cleaned URL.

        Sources without a usable string URL are ignored.
        """
        source = source if isinstance(source, dict) else {"url": source}
        # Fall back to "link" when "url" is missing *or* empty.
        url = source.get("url") or source.get("link")
        if not url or not isinstance(url, str):
            return

        # Drop the utm_source parameter and everything that follows it.
        url = re.sub(r"[&?]utm_source=.+", "", url)
        source["url"] = url

        # Prefer matching on reference identity, then fall back to the URL.
        ref_info = self.get_ref_info(source)
        if ref_info:
            _, idx = self.find_by_ref_info(ref_info)
            if idx is not None:
                self.list[idx] = source
                return

        _, idx = self.find_by_url(url)
        if idx is not None:
            self.list[idx] = source
            return

        self.list.append(source)

    def __str__(self) -> str:
        """Return formatted sources as a string."""
        if not self.list:
            return ""
        return "\n\n\n\n" + ("\n>\n".join([
            f"> [{idx+1}] {format_link(link['url'], link.get('title', ''))}"
            for idx, link in enumerate(self.list)
        ]))

    def get_ref_info(self, source: Dict[str, Any]) -> Optional[Dict[str, Union[str, int]]]:
        """Extract the reference identity of ``source``.

        Looks at ``source["ref_id"]`` first and then at each entry of
        ``source["refs"]``; the first one with an integer ``ref_index`` wins.
        Missing, ``None`` or non-dict candidates are skipped instead of
        raising.

        Returns:
            ``{"ref_index": int, "ref_type": ...}`` or ``None`` when the
            source carries no usable reference identity.
        """
        refs = source.get("refs")
        if not isinstance(refs, (list, tuple)):
            refs = []

        for candidate in (source.get("ref_id"), *refs):
            if not isinstance(candidate, dict):
                continue
            ref_index = candidate.get("ref_index")
            if isinstance(ref_index, int):
                return {
                    "ref_index": ref_index,
                    "ref_type": candidate.get("ref_type"),
                }

        return None

    def find_by_ref_info(self, ref_info: Dict[str, Union[str, int]]):
        """Find the stored source whose reference identity equals ``ref_info``.

        Returns:
            A ``(source, index)`` pair, or ``(None, None)`` when nothing
            matches.
        """
        for idx, source in enumerate(self.list):
            source_ref_info = self.get_ref_info(source)
            if (source_ref_info and
                    source_ref_info["ref_index"] == ref_info["ref_index"] and
                    source_ref_info["ref_type"] == ref_info["ref_type"]):
                return source, idx

        return None, None

    def find_by_url(self, url: str):
        """Find the stored source whose ``"url"`` equals ``url``.

        Returns:
            A ``(source, index)`` pair, or ``(None, None)`` when nothing
            matches.
        """
        for idx, source in enumerate(self.list):
            if source["url"] == url:
                return source, idx
        return None, None

    def get_index(self, ref_info: Dict[str, Union[str, int]]) -> Optional[int]:
        """Return the list index of the source matching ``ref_info``, or ``None``."""
        _, index = self.find_by_ref_info(ref_info)
        return index
|
||||
|
||||
class ContentReferences:
    """Accumulates content-reference entries that arrive as incremental patches.

    ``list`` holds one dict per reference, addressed by position.  Entries
    are built up through whole-entry appends (``add_reference``), shallow
    merges (``merge_reference``) and per-field updates (``update_reference``);
    ``get_reference`` then resolves a ``ref_index``/``ref_type`` pair back to
    the stored entry (or to one of its images).
    """

    def __init__(self) -> None:
        self.list: List[Dict[str, Any]] = []

    def _pad_to(self, idx: int) -> None:
        """Grow ``self.list`` with empty dicts until ``idx`` is a valid index."""
        while len(self.list) <= idx:
            self.list.append({})

    def add_reference(self, reference_part: dict) -> None:
        """Append ``reference_part`` as a new entry at the end of the list."""
        self.list.append(reference_part)

    def merge_reference(self, idx: int, reference_part: dict):
        """Shallow-merge ``reference_part`` into the entry at ``idx``.

        Missing entries up to ``idx`` are created as empty dicts; keys in
        ``reference_part`` override existing keys.
        """
        self._pad_to(idx)
        self.list[idx] = {**self.list[idx], **reference_part}

    def update_reference(self, idx: int, operation: str, field: str, value: Any, ref_idx = None) -> None:
        """Apply a single-field patch to the entry at ``idx``.

        Args:
            idx: Position of the entry; missing entries are created empty.
            operation: ``"append"``/``"add"`` extend (list value) or append
                (any other value) to the list stored under ``field``,
                creating it if needed.  ``"replace"`` is only applied when
                ``ref_idx`` is given.  Any other operation is ignored.
            field: Key inside the entry to update.
            value: Patch payload.
            ref_idx: For ``"replace"``: position inside the ``field`` list to
                overwrite.  If the list is too short the value is appended.
                If ``field`` does not hold a list, the field itself is set
                to ``value``.
        """
        self._pad_to(idx)
        entry = self.list[idx]

        if operation == "append" or operation == "add":
            current = entry.get(field)
            if not isinstance(current, list):
                current = entry[field] = []
            if isinstance(value, list):
                current.extend(value)
            else:
                current.append(value)
        elif operation == "replace" and ref_idx is not None:
            if field == "refs" and not isinstance(entry.get(field), list):
                entry[field] = []

            # ``.get`` so that a field that does not exist yet is simply set
            # instead of raising KeyError.
            current = entry.get(field)
            if isinstance(current, list):
                if len(current) <= ref_idx:
                    current.append(value)
                else:
                    current[ref_idx] = value
            else:
                entry[field] = value

    def get_ref_info(
        self,
        source: Dict[str, Any],
        target_ref_info: Dict[str, Union[str, int]]
    ) -> Optional[Dict[str, Union[str, int]]]:
        """Find the entry of ``source["refs"]`` that matches ``target_ref_info``.

        Only dict entries with an integer ``ref_index`` and a string
        ``ref_type`` are considered.  When ``target_ref_info`` is empty the
        first such entry is returned.

        Returns:
            ``{"ref_index", "ref_type", "idx"}`` where ``idx`` is the
            position of the match inside ``refs``, or ``None``.
        """
        # ``or []`` must apply to the refs value, not to the enumerate object,
        # so that a missing or ``None`` refs field yields no matches.
        for idx, ref_info in enumerate(source.get("refs") or []):
            if not isinstance(ref_info, dict):
                continue

            ref_index = ref_info.get("ref_index", None)
            ref_type = ref_info.get("ref_type", None)
            if isinstance(ref_index, int) and isinstance(ref_type, str):
                if (not target_ref_info or
                        (target_ref_info["ref_index"] == ref_index and
                         target_ref_info["ref_type"] == ref_type)):
                    return {
                        "ref_index": ref_index,
                        "ref_type": ref_type,
                        "idx": idx
                    }

        return None

    def get_reference(self, ref_info: Dict[str, Union[str, int]]) -> Any:
        """Resolve ``ref_info`` to a stored reference.

        For ``ref_type == "image"`` the result is the element of the entry's
        ``images`` list at the same position the matching ref has inside
        ``refs``; for every other type it is the whole entry.

        Returns:
            The matching entry/image, or ``None`` when nothing matches.
        """
        for reference in self.list:
            if not isinstance(reference, dict):
                continue

            reference_ref_info = self.get_ref_info(reference, ref_info)
            if (not reference_ref_info or
                    reference_ref_info["ref_index"] != ref_info.get("ref_index") or
                    reference_ref_info["ref_type"] != ref_info.get("ref_type")):
                continue

            if ref_info["ref_type"] != "image":
                return reference

            images = reference.get("images", [])
            if isinstance(images, list) and len(images) > reference_ref_info["idx"]:
                return images[reference_ref_info["idx"]]

        return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue