mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-15 14:51:19 -08:00
Fix generate image in OpenaiChat
Add HarProvider, disable LMArenaProvider
This commit is contained in:
parent
eda3f69d4f
commit
9aba62733a
9 changed files with 1675 additions and 42 deletions
|
|
@ -24,10 +24,10 @@ from ...requests import StreamSession
|
|||
from ...requests import get_nodriver
|
||||
from ...image import ImageRequest, to_image, to_bytes, is_accepted_format
|
||||
from ...errors import MissingAuthError, NoValidHarFileError
|
||||
from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult, ImageResponse
|
||||
from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult, ImageResponse, ImagePreview
|
||||
from ...providers.response import Sources, TitleGeneration, RequestLogin, Reasoning
|
||||
from ...tools.media import merge_media
|
||||
from ..helper import format_cookies, get_last_user_message
|
||||
from ..helper import format_cookies, format_image_prompt
|
||||
from ..openai.models import default_model, default_image_model, models, image_models, text_models
|
||||
from ..openai.har_file import get_request_config
|
||||
from ..openai.har_file import RequestConfig, arkReq, arkose_url, start_url, conversation_url, backend_url, backend_anon_url
|
||||
|
|
@ -254,31 +254,26 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
return messages
|
||||
|
||||
@classmethod
|
||||
async def get_generated_image(cls, session: StreamSession, auth_result: AuthResult, element: dict, prompt: str, conversation_id: str) -> ImageResponse:
|
||||
try:
|
||||
prompt = element["metadata"]["dalle"]["prompt"]
|
||||
except IndexError:
|
||||
pass
|
||||
try:
|
||||
file_id = element["asset_pointer"]
|
||||
if "file-service://" in file_id:
|
||||
file_id = file_id.split("file-service://", 1)[-1]
|
||||
url = f"{cls.url}/backend-api/files/{file_id}/download"
|
||||
else:
|
||||
file_id = file_id.split("sediment://")[-1]
|
||||
url = f"{cls.url}/backend-api/conversation/{conversation_id}/attachment/{file_id}/download"
|
||||
except TypeError:
|
||||
return
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"No Image: {element} - {e}")
|
||||
try:
|
||||
async def get_generated_images(cls, session: StreamSession, auth_result: AuthResult, parts: list, prompt: str, conversation_id: str) -> AsyncIterator:
|
||||
download_urls = []
|
||||
for element in [parts] if isinstance(parts, str) else parts:
|
||||
if isinstance(element, dict) and element.get("content_type") == "image_asset_pointer":
|
||||
if not prompt:
|
||||
prompt = element["metadata"]["dalle"]["prompt"]
|
||||
element = element["asset_pointer"]
|
||||
element = element.split("sediment://")[-1]
|
||||
url = f"{cls.url}/backend-api/conversation/{conversation_id}/attachment/{element}/download"
|
||||
debug.log(f"OpenaiChat: Downloading image: {url}")
|
||||
async with session.get(url, headers=auth_result.headers) as response:
|
||||
cls._update_request_args(auth_result, session)
|
||||
await raise_for_status(response)
|
||||
download_url = (await response.json())["download_url"]
|
||||
return ImageResponse(download_url, prompt)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error in downloading image: {e}")
|
||||
data = await response.json()
|
||||
download_url = data.get("download_url")
|
||||
if download_url is None:
|
||||
print(data)
|
||||
else:
|
||||
download_urls.append(download_url)
|
||||
return ImagePreview(download_urls, prompt)
|
||||
|
||||
@classmethod
|
||||
async def create_authed(
|
||||
|
|
@ -394,10 +389,8 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
#f"Proofofwork: {'False' if proofofwork is None else proofofwork[:12]+'...'}",
|
||||
#f"AccessToken: {'False' if cls._api_key is None else cls._api_key[:12]+'...'}",
|
||||
)]
|
||||
if action is None or action == "variant" or action == "continue" and conversation.message_id is None:
|
||||
action = "next"
|
||||
data = {
|
||||
"action": action,
|
||||
"action": "next",
|
||||
"parent_message_id": conversation.message_id,
|
||||
"model": model,
|
||||
"timezone_offset_min":-60,
|
||||
|
|
@ -413,7 +406,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
if conversation.conversation_id is not None:
|
||||
data["conversation_id"] = conversation.conversation_id
|
||||
debug.log(f"OpenaiChat: Use conversation: {conversation.conversation_id}")
|
||||
prompt = get_last_user_message(messages) if prompt is None else prompt
|
||||
conversation.prompt = format_image_prompt(messages, prompt)
|
||||
if action != "continue":
|
||||
data["parent_message_id"] = getattr(conversation, "parent_message_id", conversation.message_id)
|
||||
conversation.parent_message_id = None
|
||||
|
|
@ -444,7 +437,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
await raise_for_status(response)
|
||||
buffer = u""
|
||||
async for line in response.iter_lines():
|
||||
async for chunk in cls.iter_messages_line(session, auth_result, line, conversation, sources, prompt):
|
||||
async for chunk in cls.iter_messages_line(session, auth_result, line, conversation, sources):
|
||||
if isinstance(chunk, str):
|
||||
chunk = chunk.replace("\ue203", "").replace("\ue204", "").replace("\ue206", "")
|
||||
buffer += chunk
|
||||
|
|
@ -469,6 +462,10 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
break
|
||||
if sources.list:
|
||||
yield sources
|
||||
if conversation.generated_images:
|
||||
yield ImageResponse(conversation.generated_images.urls, conversation.prompt)
|
||||
conversation.generated_images = None
|
||||
conversation.prompt = None
|
||||
if return_conversation:
|
||||
yield conversation
|
||||
if auth_result.api_key is not None:
|
||||
|
|
@ -486,7 +483,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
yield FinishReason(conversation.finish_reason)
|
||||
|
||||
@classmethod
|
||||
async def iter_messages_line(cls, session: StreamSession, auth_result: AuthResult, line: bytes, fields: Conversation, sources: Sources, prompt: str) -> AsyncIterator:
|
||||
async def iter_messages_line(cls, session: StreamSession, auth_result: AuthResult, line: bytes, fields: Conversation, sources: Sources) -> AsyncIterator:
|
||||
if not line.startswith(b"data: "):
|
||||
return
|
||||
elif line.startswith(b"data: [DONE]"):
|
||||
|
|
@ -519,6 +516,10 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
for m in v:
|
||||
if m.get("p") == "/message/content/parts/0" and fields.recipient == "all":
|
||||
yield m.get("v")
|
||||
elif m.get("p") == "/message/metadata/image_gen_title":
|
||||
fields.prompt = m.get("v")
|
||||
elif m.get("p") == "/message/content/parts/0/asset_pointer":
|
||||
fields.generated_images = await cls.get_generated_images(session, auth_result, m.get("v"), fields.prompt, fields.conversation_id)
|
||||
elif m.get("p") == "/message/metadata/search_result_groups":
|
||||
for entry in [p.get("entries") for p in m.get("v")]:
|
||||
for link in entry:
|
||||
|
|
@ -547,14 +548,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
fields.is_thinking = True
|
||||
yield Reasoning(status=m.get("metadata", {}).get("initial_text"))
|
||||
if c.get("content_type") == "multimodal_text":
|
||||
generated_images = []
|
||||
for element in c.get("parts"):
|
||||
if isinstance(element, dict) and element.get("content_type") == "image_asset_pointer":
|
||||
image = cls.get_generated_image(session, auth_result, element, prompt, fields.conversation_id)
|
||||
generated_images.append(image)
|
||||
for image_response in await asyncio.gather(*generated_images):
|
||||
if image_response is not None:
|
||||
yield image_response
|
||||
yield await cls.get_generated_images(session, auth_result, c.get("parts"), fields.prompt, fields.conversation_id)
|
||||
if m.get("author", {}).get("role") == "assistant":
|
||||
if fields.parent_message_id is None:
|
||||
fields.parent_message_id = v.get("message", {}).get("id")
|
||||
|
|
@ -738,6 +732,8 @@ class Conversation(JsonConversation):
|
|||
self.is_thinking = is_thinking
|
||||
self.p = None
|
||||
self.thoughts_summary = ""
|
||||
self.prompt = None
|
||||
self.generated_images: ImagePreview = None
|
||||
|
||||
def get_cookies(
|
||||
urls: Optional[Iterator[str]] = None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue