Fix generate image in OpenaiChat

Add HarProvider, disable LMArenaProvider
This commit is contained in:
hlohaus 2025-04-23 02:52:43 +02:00
parent eda3f69d4f
commit 9aba62733a
9 changed files with 1675 additions and 42 deletions

View file

@ -24,10 +24,10 @@ from ...requests import StreamSession
from ...requests import get_nodriver
from ...image import ImageRequest, to_image, to_bytes, is_accepted_format
from ...errors import MissingAuthError, NoValidHarFileError
from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult, ImageResponse
from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult, ImageResponse, ImagePreview
from ...providers.response import Sources, TitleGeneration, RequestLogin, Reasoning
from ...tools.media import merge_media
from ..helper import format_cookies, get_last_user_message
from ..helper import format_cookies, format_image_prompt
from ..openai.models import default_model, default_image_model, models, image_models, text_models
from ..openai.har_file import get_request_config
from ..openai.har_file import RequestConfig, arkReq, arkose_url, start_url, conversation_url, backend_url, backend_anon_url
@ -254,31 +254,26 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
return messages
@classmethod
async def get_generated_image(cls, session: StreamSession, auth_result: AuthResult, element: dict, prompt: str, conversation_id: str) -> ImageResponse:
try:
prompt = element["metadata"]["dalle"]["prompt"]
except IndexError:
pass
try:
file_id = element["asset_pointer"]
if "file-service://" in file_id:
file_id = file_id.split("file-service://", 1)[-1]
url = f"{cls.url}/backend-api/files/{file_id}/download"
else:
file_id = file_id.split("sediment://")[-1]
url = f"{cls.url}/backend-api/conversation/{conversation_id}/attachment/{file_id}/download"
except TypeError:
return
except Exception as e:
raise RuntimeError(f"No Image: {element} - {e}")
try:
async def get_generated_images(cls, session: StreamSession, auth_result: AuthResult, parts: list, prompt: str, conversation_id: str) -> AsyncIterator:
download_urls = []
for element in [parts] if isinstance(parts, str) else parts:
if isinstance(element, dict) and element.get("content_type") == "image_asset_pointer":
if not prompt:
prompt = element["metadata"]["dalle"]["prompt"]
element = element["asset_pointer"]
element = element.split("sediment://")[-1]
url = f"{cls.url}/backend-api/conversation/{conversation_id}/attachment/{element}/download"
debug.log(f"OpenaiChat: Downloading image: {url}")
async with session.get(url, headers=auth_result.headers) as response:
cls._update_request_args(auth_result, session)
await raise_for_status(response)
download_url = (await response.json())["download_url"]
return ImageResponse(download_url, prompt)
except Exception as e:
raise RuntimeError(f"Error in downloading image: {e}")
data = await response.json()
download_url = data.get("download_url")
if download_url is None:
print(data)
else:
download_urls.append(download_url)
return ImagePreview(download_urls, prompt)
@classmethod
async def create_authed(
@ -394,10 +389,8 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
#f"Proofofwork: {'False' if proofofwork is None else proofofwork[:12]+'...'}",
#f"AccessToken: {'False' if cls._api_key is None else cls._api_key[:12]+'...'}",
)]
if action is None or action == "variant" or action == "continue" and conversation.message_id is None:
action = "next"
data = {
"action": action,
"action": "next",
"parent_message_id": conversation.message_id,
"model": model,
"timezone_offset_min":-60,
@ -413,7 +406,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
if conversation.conversation_id is not None:
data["conversation_id"] = conversation.conversation_id
debug.log(f"OpenaiChat: Use conversation: {conversation.conversation_id}")
prompt = get_last_user_message(messages) if prompt is None else prompt
conversation.prompt = format_image_prompt(messages, prompt)
if action != "continue":
data["parent_message_id"] = getattr(conversation, "parent_message_id", conversation.message_id)
conversation.parent_message_id = None
@ -444,7 +437,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
await raise_for_status(response)
buffer = u""
async for line in response.iter_lines():
async for chunk in cls.iter_messages_line(session, auth_result, line, conversation, sources, prompt):
async for chunk in cls.iter_messages_line(session, auth_result, line, conversation, sources):
if isinstance(chunk, str):
chunk = chunk.replace("\ue203", "").replace("\ue204", "").replace("\ue206", "")
buffer += chunk
@ -469,6 +462,10 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
break
if sources.list:
yield sources
if conversation.generated_images:
yield ImageResponse(conversation.generated_images.urls, conversation.prompt)
conversation.generated_images = None
conversation.prompt = None
if return_conversation:
yield conversation
if auth_result.api_key is not None:
@ -486,7 +483,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
yield FinishReason(conversation.finish_reason)
@classmethod
async def iter_messages_line(cls, session: StreamSession, auth_result: AuthResult, line: bytes, fields: Conversation, sources: Sources, prompt: str) -> AsyncIterator:
async def iter_messages_line(cls, session: StreamSession, auth_result: AuthResult, line: bytes, fields: Conversation, sources: Sources) -> AsyncIterator:
if not line.startswith(b"data: "):
return
elif line.startswith(b"data: [DONE]"):
@ -519,6 +516,10 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
for m in v:
if m.get("p") == "/message/content/parts/0" and fields.recipient == "all":
yield m.get("v")
elif m.get("p") == "/message/metadata/image_gen_title":
fields.prompt = m.get("v")
elif m.get("p") == "/message/content/parts/0/asset_pointer":
fields.generated_images = await cls.get_generated_images(session, auth_result, m.get("v"), fields.prompt, fields.conversation_id)
elif m.get("p") == "/message/metadata/search_result_groups":
for entry in [p.get("entries") for p in m.get("v")]:
for link in entry:
@ -547,14 +548,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
fields.is_thinking = True
yield Reasoning(status=m.get("metadata", {}).get("initial_text"))
if c.get("content_type") == "multimodal_text":
generated_images = []
for element in c.get("parts"):
if isinstance(element, dict) and element.get("content_type") == "image_asset_pointer":
image = cls.get_generated_image(session, auth_result, element, prompt, fields.conversation_id)
generated_images.append(image)
for image_response in await asyncio.gather(*generated_images):
if image_response is not None:
yield image_response
yield await cls.get_generated_images(session, auth_result, c.get("parts"), fields.prompt, fields.conversation_id)
if m.get("author", {}).get("role") == "assistant":
if fields.parent_message_id is None:
fields.parent_message_id = v.get("message", {}).get("id")
@ -738,6 +732,8 @@ class Conversation(JsonConversation):
self.is_thinking = is_thinking
self.p = None
self.thoughts_summary = ""
self.prompt = None
self.generated_images: ImagePreview = None
def get_cookies(
urls: Optional[Iterator[str]] = None