Mirror of https://github.com/xtekky/gpt4free.git (synced 2025-12-06 10:40:43 -08:00)

Commit fe24210db2 (parent 9423a23003)

add upload_file

- add upload_file
- add conversation_mode
- add temporary

2 changed files with 195 additions and 36 deletions
Changed file 1 of 2 — the OpenaiChat provider:

@@ -22,7 +22,7 @@ from ...typing import AsyncResult, Messages, Cookies, MediaListType
 from ...requests.raise_for_status import raise_for_status
 from ...requests import StreamSession
 from ...requests import get_nodriver
-from ...image import ImageRequest, to_image, to_bytes, is_accepted_format
+from ...image import ImageRequest, to_image, to_bytes, is_accepted_format, detect_file_type
 from ...errors import MissingAuthError, NoValidHarFileError, ModelNotFoundError
 from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult, ImageResponse, ImagePreview, ResponseType, format_link
 from ...providers.response import TitleGeneration, RequestLogin, Reasoning
@@ -126,54 +126,66 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
         )

     @classmethod
-    async def upload_images(
+    async def upload_files(
         cls,
         session: StreamSession,
         auth_result: AuthResult,
         media: MediaListType,
-    ) -> ImageRequest:
+    ) -> list[ImageRequest]:
         """
         Upload an image to the service and get the download URL

         Args:
             session: The StreamSession object to use for requests
             headers: The headers to include in the requests
-            media: The images to upload, either a PIL Image object or a bytes object
+            media: The files to upload, either a PIL Image object or a bytes object

         Returns:
             An ImageRequest object that contains the download URL, file name, and other data
         """
-        async def upload_image(image, image_name):
-            debug.log(f"Uploading image: {image_name}")
-            # Convert the image to a PIL Image object and get the extension
-            data_bytes = to_bytes(image)
-            image = to_image(data_bytes)
-            extension = image.format.lower()
+        async def upload_file(file, image_name=None):
+            debug.log(f"Uploading file: {image_name}")
+            file_data = {}
+
+            data_bytes = to_bytes(file)
+            extension, mime_type = detect_file_type(data_bytes)
+            if "image" in mime_type:
+                # Convert the image to a PIL Image object
+                file = to_image(data_bytes)
+                use_case = "multimodal"
+                file_data.update({"height": file.height, "width": file.width})
+            else:
+                use_case = "my_files"
+            image_name = (
+                f"file-{len(data_bytes)}{extension}"
+                if image_name is None
+                else image_name
+            )
             data = {
-                "file_name": "" if image_name is None else image_name,
+                "file_name": image_name,
                 "file_size": len(data_bytes),
-                "use_case": "multimodal"
+                "use_case": use_case,
             }
             # Post the image data to the service and get the image data
             async with session.post(f"{cls.url}/backend-api/files", json=data, headers=cls._headers) as response:
                 cls._update_request_args(auth_result, session)
                 await raise_for_status(response, "Create file failed")
-                image_data = {
-                    **data,
-                    **await response.json(),
-                    "mime_type": is_accepted_format(data_bytes),
-                    "extension": extension,
-                    "height": image.height,
-                    "width": image.width
-                }
+                file_data.update(
+                    {
+                        **data,
+                        **await response.json(),
+                        "mime_type": mime_type,
+                        "extension": extension,
+                    }
+                )
             # Put the image bytes to the upload URL and check the status
             await asyncio.sleep(1)
             async with session.put(
-                image_data["upload_url"],
+                file_data["upload_url"],
                 data=data_bytes,
                 headers={
                     **UPLOAD_HEADERS,
-                    "Content-Type": image_data["mime_type"],
+                    "Content-Type": file_data["mime_type"],
                     "x-ms-blob-type": "BlockBlob",
                     "x-ms-version": "2020-04-08",
                     "Origin": "https://chatgpt.com",
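Note: the `if "image" in mime_type` branch above decides both the `use_case` sent to `/backend-api/files` and whether image dimensions are recorded. A minimal, self-contained sketch of that selection logic follows; `sniff_type` is a stripped-down stand-in for `detect_file_type`, and Pillow stands in for the provider's `to_image`/`to_bytes` helpers, so this is illustrative only, not the provider's own code.

import io
from PIL import Image  # Pillow, used here only to read image dimensions

def sniff_type(data: bytes) -> tuple[str, str]:
    # Stand-in for detect_file_type(): recognizes PNG/JPEG, else "other".
    if data.startswith(b"\x89PNG\r\n\x1a\n"):
        return ".png", "image/png"
    if data.startswith(b"\xff\xd8\xff"):
        return ".jpg", "image/jpeg"
    return ".bin", "application/octet-stream"

def build_file_payload(data_bytes: bytes, file_name: str | None = None) -> dict:
    extension, mime_type = sniff_type(data_bytes)
    payload = {}
    if "image" in mime_type:
        # Images go through the "multimodal" flow and carry their dimensions.
        image = Image.open(io.BytesIO(data_bytes))
        payload.update({"height": image.height, "width": image.width})
        use_case = "multimodal"
    else:
        # Anything else is uploaded as a plain file.
        use_case = "my_files"
    payload.update({
        "file_name": file_name or f"file-{len(data_bytes)}{extension}",
        "file_size": len(data_bytes),
        "use_case": use_case,
    })
    return payload

print(build_file_payload(b"%PDF-1.7 fake", "report.pdf"))
# {'file_name': 'report.pdf', 'file_size': 13, 'use_case': 'my_files'}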
@@ -182,15 +194,22 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
                 await raise_for_status(response)
             # Post the file ID to the service and get the download URL
             async with session.post(
-                f"{cls.url}/backend-api/files/{image_data['file_id']}/uploaded",
+                f"{cls.url}/backend-api/files/{file_data['file_id']}/uploaded",
                 json={},
                 headers=auth_result.headers
             ) as response:
                 cls._update_request_args(auth_result, session)
                 await raise_for_status(response, "Get download url failed")
-                image_data["download_url"] = (await response.json())["download_url"]
-            return ImageRequest(image_data)
-        return [await upload_image(image, image_name) for image, image_name in media]
+                uploaded_data = await response.json()
+            file_data["download_url"] = uploaded_data["download_url"]
+            return ImageRequest(file_data)
+
+        medias = []
+        for item in media:
+            item = item if isinstance(item, tuple) else (item,)
+            __uploaded_media = await upload_file(*item)
+            medias.append(__uploaded_media)
+        return medias

     @classmethod
     def create_messages(cls, messages: Messages, image_requests: ImageRequest = None, system_hints: list = None):
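The loop that replaces the old list comprehension accepts media entries either as a bare file or as a (file, name) tuple. A small runnable sketch of just that normalization (the `upload_file` here is a placeholder that echoes its arguments, not the real uploader):

import asyncio

async def upload_file(file, image_name=None):
    # Placeholder for the per-file upload coroutine.
    return {"size": len(file), "name": image_name}

async def upload_all(media):
    # Bare items become 1-tuples, so upload_file() is called
    # as upload_file(file) or upload_file(file, name).
    results = []
    for item in media:
        item = item if isinstance(item, tuple) else (item,)
        results.append(await upload_file(*item))
    return results

media = [b"\x89PNG\r\n\x1a\n\x00\x00", (b"%PDF-1.7 fake", "report.pdf")]
print(asyncio.run(upload_all(media)))
# [{'size': 10, 'name': None}, {'size': 13, 'name': 'report.pdf'}]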
@@ -237,18 +256,27 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
                     "size_bytes": image_request.get("file_size"),
                     "width": image_request.get("width"),
                 }
-                for image_request in image_requests],
+                for image_request in image_requests
+                # Add For Images Only
+                if image_request.get("use_case") == "multimodal"
+                ],
                 messages[-1]["content"]["parts"][0]]
             }
             # Add the metadata object with the attachments
             messages[-1]["metadata"] = {
                 "attachments": [{
-                    "height": image_request.get("height"),
                     "id": image_request.get("file_id"),
                     "mimeType": image_request.get("mime_type"),
                     "name": image_request.get("file_name"),
                     "size": image_request.get("file_size"),
-                    "width": image_request.get("width"),
+                    **(
+                        {
+                            "height": image_request.get("height"),
+                            "width": image_request.get("width"),
+                        }
+                        if image_request.get("use_case") == "multimodal"
+                        else {}
+                    ),
                 }
                 for image_request in image_requests]
             }
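The attachment metadata above now merges height/width only for image uploads, using a conditional dict unpacking. The same pattern in isolation (field names follow the diff; the function itself is illustrative, not provider code):

def build_attachment(file_info: dict) -> dict:
    # height/width only make sense for "multimodal" (image) uploads,
    # so they are merged in conditionally.
    return {
        "id": file_info.get("file_id"),
        "mimeType": file_info.get("mime_type"),
        "name": file_info.get("file_name"),
        "size": file_info.get("file_size"),
        **(
            {"height": file_info.get("height"), "width": file_info.get("width")}
            if file_info.get("use_case") == "multimodal"
            else {}
        ),
    }

print(build_attachment({"file_id": "f-1", "mime_type": "image/png", "file_name": "a.png",
                        "file_size": 123, "use_case": "multimodal", "height": 64, "width": 64}))
print(build_attachment({"file_id": "f-2", "mime_type": "application/pdf", "file_name": "b.pdf",
                        "file_size": 456, "use_case": "my_files"}))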
@@ -308,6 +336,8 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
         return_conversation: bool = True,
         web_search: bool = False,
         prompt: str = None,
+        conversation_mode=None,
+        temporary=False,
         **kwargs
     ) -> AsyncResult:
         """
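Both new parameters are plain pass-through options on the provider's entry point. Assuming extra keyword arguments are forwarded to the provider as usual in g4f, a call could look like the sketch below (model name and values are hypothetical; `temporary=True` maps to `history_and_training_disabled` later in this diff, and `conversation_mode` defaults to {"kind": "primary_assistant"} when left unset):

import g4f
from g4f.Provider import OpenaiChat

# Hypothetical invocation; requires a working OpenaiChat login (HAR/cookies).
response = g4f.ChatCompletion.create(
    model="gpt-4o",
    provider=OpenaiChat,
    messages=[{"role": "user", "content": "Summarize the attached file."}],
    temporary=True,                                   # do not keep chat history
    conversation_mode={"kind": "primary_assistant"},  # explicit default
)
print(response)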
@@ -353,11 +383,12 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
             async with session.get(cls.url, headers=cls._headers) as response:
                 cls._update_request_args(auth_result, session)
                 await raise_for_status(response)
-            try:
-                image_requests = await cls.upload_images(session, auth_result, media)
-            except Exception as e:
-                debug.error("OpenaiChat: Upload image failed")
-                debug.error(e)
+
+            # try:
+            image_requests = await cls.upload_files(session, auth_result, media)
+            # except Exception as e:
+            # debug.error("OpenaiChat: Upload image failed")
+            # debug.error(e)
             try:
                 model = cls.get_model(model)
             except ModelNotFoundError:
@@ -370,6 +401,10 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
                 conversation = Conversation(None, str(uuid.uuid4()), getattr(auth_result, "cookies", {}).get("oai-did"))
             else:
                 conversation = copy(conversation)
+
+            if conversation_mode is None:
+                conversation_mode = {"kind": "primary_assistant"}
+
             if getattr(auth_result, "cookies", {}).get("oai-did") != getattr(conversation, "user_id", None):
                 conversation = Conversation(None, str(uuid.uuid4()))
             if cls._api_key is None:
@@ -394,6 +429,9 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
                     "supports_buffering": True,
                     "supported_encodings": ["v1"]
                 }
+                if temporary:
+                    data["history_and_training_disabled"] = True
+
                 async with session.post(
                     prepare_url,
                     json=data,
@@ -434,11 +472,11 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
                     user_agent=user_agent,
                     proof_token=proof_token
                 )
-                [debug.log(text) for text in (
+                # [debug.log(text) for text in (
                     #f"Arkose: {'False' if not need_arkose else auth_result.arkose_token[:12]+'...'}",
                     #f"Proofofwork: {'False' if proofofwork is None else proofofwork[:12]+'...'}",
                     #f"AccessToken: {'False' if cls._api_key is None else cls._api_key[:12]+'...'}",
-                )]
+                # )]
                 data = {
                     "action": "next",
                     "parent_message_id": conversation.message_id,
@@ -453,6 +491,9 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
                     "client_contextual_info":{"is_dark_mode":False,"time_since_loaded":random.randint(20, 500),"page_height":578,"page_width":1850,"pixel_ratio":1,"screen_height":1080,"screen_width":1920},
                     "paragen_cot_summary_display_override":"allow"
                 }
+                if temporary:
+                    data["history_and_training_disabled"] = True
+
                 if conversation.conversation_id is not None:
                     data["conversation_id"] = conversation.conversation_id
                     debug.log(f"OpenaiChat: Use conversation: {conversation.conversation_id}")
Changed file 2 of 2 — the image helper module (imported above as ...image):

@@ -191,6 +191,124 @@ def is_accepted_format(binary_data: bytes) -> str:
     else:
         raise ValueError("Invalid image format (from magic code).")

+
+
+def detect_file_type(binary_data: bytes) -> tuple[str, str] | None:
+    """
+    Detect file type from magic number / header signature.
+
+    Args:
+        binary_data (bytes): File binary data
+
+    Returns:
+        tuple: (extension, MIME type)
+
+    Raises:
+        ValueError: If file type is unknown
+    """
+
+    # ---- Images ----
+    if binary_data.startswith(b"\xff\xd8\xff"):
+        return ".jpg", "image/jpeg"
+    elif binary_data.startswith(b"\x89PNG\r\n\x1a\n"):
+        return ".png", "image/png"
+    elif binary_data.startswith((b"GIF87a", b"GIF89a")):
+        return ".gif", "image/gif"
+    elif binary_data.startswith(b"RIFF") and binary_data[8:12] == b"WEBP":
+        return ".webp", "image/webp"
+    elif binary_data.startswith(b"BM"):
+        return ".bmp", "image/bmp"
+    elif binary_data.startswith(b"II*\x00") or binary_data.startswith(b"MM\x00*"):
+        return ".tiff", "image/tiff"
+    elif binary_data.startswith(b"\x00\x00\x01\x00"):
+        return ".ico", "image/x-icon"
+    elif binary_data.startswith(b"\x00\x00\x00\x0cjP \x0d\x0a\x87\x0a"):
+        return ".jp2", "image/jp2"
+    elif len(binary_data) > 12 and binary_data[4:8] == b"ftyp":
+        brand = binary_data[8:12]
+        if brand in [b"heic", b"heix", b"hevc", b"hevx", b"mif1", b"msf1"]:
+            return ".heic", "image/heif"
+        elif brand in [b"avif"]:
+            return ".avif", "image/avif"
+    elif binary_data.lstrip().startswith((b"<?xml", b"<svg")):
+        return ".svg", "image/svg+xml"
+
+    # ---- Documents ----
+    elif binary_data.startswith(b"%PDF"):
+        return ".pdf", "application/pdf"
+    elif binary_data.startswith(b"PK\x03\x04"):
+        return ".zip", "application/zip-based"
+        # could be docx/xlsx/pptx/jar/apk/odt
+    elif binary_data.startswith(b"\xd0\xcf\x11\xe0"):
+        return ".doc", "application/vnd.ms-office"
+    elif binary_data.startswith(b"{\\rtf"):
+        return ".rtf", "application/rtf"
+    elif binary_data.startswith(b"7z\xbc\xaf\x27\x1c"):
+        return ".7z", "application/x-7z-compressed"
+    elif binary_data.startswith(b"Rar!\x1a\x07\x00"):
+        return ".rar", "application/vnd.rar"
+    elif binary_data.startswith(b"\x1f\x8b"):
+        return ".gz", "application/gzip"
+    elif binary_data.startswith(b"BZh"):
+        return ".bz2", "application/x-bzip2"
+    elif binary_data.startswith(b"\xfd7zXZ\x00"):
+        return ".xz", "application/x-xz"
+
+    # ---- Executables / Libraries ----
+    elif binary_data.startswith(b"MZ"):
+        return ".exe", "application/x-msdownload"
+    elif binary_data.startswith(b"\x7fELF"):
+        return ".elf", "application/x-elf"
+    elif binary_data.startswith(b"\xca\xfe\xba\xbe") or binary_data.startswith(
+        b"\xca\xfe\xd0\x0d"
+    ):
+        return ".class", "application/java-vm"
+    elif (
+        binary_data.startswith(b"\x50\x4b\x03\x04")
+        and b"META-INF" in binary_data[:200]
+    ):
+        return ".jar", "application/java-archive"
+
+    # ---- Audio ----
+    elif binary_data.startswith(b"ID3") or binary_data[0:2] == b"\xff\xfb":
+        return ".mp3", "audio/mpeg"
+    elif binary_data.startswith(b"OggS"):
+        return ".ogg", "audio/ogg"
+    elif binary_data.startswith(b"fLaC"):
+        return ".flac", "audio/flac"
+    elif binary_data.startswith(b"RIFF") and binary_data[8:12] == b"WAVE":
+        return ".wav", "audio/wav"
+    elif binary_data.startswith(b"MThd"):
+        return ".mid", "audio/midi"
+
+    # ---- Video ----
+    elif binary_data.startswith(b"\x00\x00\x00") and b"ftyp" in binary_data[4:12]:
+        return ".mp4", "video/mp4"
+    elif binary_data.startswith(b"RIFF") and binary_data[8:12] == b"AVI ":
+        return ".avi", "video/x-msvideo"
+    elif binary_data.startswith(b"OggS"):
+        return ".ogv", "video/ogg"
+    elif binary_data.startswith(b"\x1a\x45\xdf\xa3"):
+        return ".mkv", "video/webm"
+    elif binary_data.startswith(b"\x00\x00\x01\xba"):
+        return ".mpg", "video/mpeg"
+
+    # ---- Text / Scripts ----
+    elif binary_data.lstrip().startswith(b"#!"):
+        return ".sh", "text/x-script"
+    elif binary_data.lstrip().startswith((b"{", b"[")):
+        return ".json", "application/json"
+    elif binary_data.lstrip().startswith((b"<", b"<!DOCTYPE")):
+        return ".html", "text/html"
+    elif binary_data.lstrip().startswith(b"<?xml"):
+        return ".xml", "application/xml"
+    elif all(32 <= b <= 127 or b in (9, 10, 13) for b in binary_data[:100]):
+        return ".txt", "text/plain"
+
+    else:
+        raise ValueError("Unknown or unsupported file type")
+
+
 def extract_data_uri(data_uri: str) -> bytes:
     """
     Extracts the binary data from the given data URI.
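A quick, illustrative check of the new helper on synthetic headers (assuming it is importable as g4f.image.detect_file_type, which matches the provider-side import at the top of this diff):

from g4f.image import detect_file_type

samples = {
    "png":  b"\x89PNG\r\n\x1a\n" + b"\x00" * 8,
    "pdf":  b"%PDF-1.7\n%fake",
    "zip":  b"PK\x03\x04" + b"\x00" * 8,
    "text": b"hello world",
}
for label, head in samples.items():
    print(label, detect_file_type(head))
# Expected along the lines of:
#   png  ('.png', 'image/png')
#   pdf  ('.pdf', 'application/pdf')
#   zip  ('.zip', 'application/zip-based')
#   text ('.txt', 'text/plain')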