mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
add upload_file
- add upload_file - add conversation_mode - add temporary
This commit is contained in:
parent
9423a23003
commit
fe24210db2
2 changed files with 195 additions and 36 deletions
|
|
@ -22,7 +22,7 @@ from ...typing import AsyncResult, Messages, Cookies, MediaListType
|
|||
from ...requests.raise_for_status import raise_for_status
|
||||
from ...requests import StreamSession
|
||||
from ...requests import get_nodriver
|
||||
from ...image import ImageRequest, to_image, to_bytes, is_accepted_format
|
||||
from ...image import ImageRequest, to_image, to_bytes, is_accepted_format, detect_file_type
|
||||
from ...errors import MissingAuthError, NoValidHarFileError, ModelNotFoundError
|
||||
from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult, ImageResponse, ImagePreview, ResponseType, format_link
|
||||
from ...providers.response import TitleGeneration, RequestLogin, Reasoning
|
||||
|
|
@ -126,54 +126,66 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
async def upload_images(
|
||||
async def upload_files(
|
||||
cls,
|
||||
session: StreamSession,
|
||||
auth_result: AuthResult,
|
||||
media: MediaListType,
|
||||
) -> ImageRequest:
|
||||
) -> list[ImageRequest]:
|
||||
"""
|
||||
Upload an image to the service and get the download URL
|
||||
|
||||
Args:
|
||||
session: The StreamSession object to use for requests
|
||||
headers: The headers to include in the requests
|
||||
media: The images to upload, either a PIL Image object or a bytes object
|
||||
media: The files to upload, either a PIL Image object or a bytes object
|
||||
|
||||
Returns:
|
||||
An ImageRequest object that contains the download URL, file name, and other data
|
||||
"""
|
||||
async def upload_image(image, image_name):
|
||||
debug.log(f"Uploading image: {image_name}")
|
||||
# Convert the image to a PIL Image object and get the extension
|
||||
data_bytes = to_bytes(image)
|
||||
image = to_image(data_bytes)
|
||||
extension = image.format.lower()
|
||||
async def upload_file(file, image_name=None):
|
||||
debug.log(f"Uploading file: {image_name}")
|
||||
file_data = {}
|
||||
|
||||
data_bytes = to_bytes(file)
|
||||
extension, mime_type = detect_file_type(data_bytes)
|
||||
if "image" in mime_type:
|
||||
# Convert the image to a PIL Image object
|
||||
file = to_image(data_bytes)
|
||||
use_case = "multimodal"
|
||||
file_data.update({"height": file.height, "width": file.width})
|
||||
else:
|
||||
use_case = "my_files"
|
||||
image_name = (
|
||||
f"file-{len(data_bytes)}{extension}"
|
||||
if image_name is None
|
||||
else image_name
|
||||
)
|
||||
data = {
|
||||
"file_name": "" if image_name is None else image_name,
|
||||
"file_name": image_name,
|
||||
"file_size": len(data_bytes),
|
||||
"use_case": "multimodal"
|
||||
"use_case": use_case,
|
||||
}
|
||||
# Post the image data to the service and get the image data
|
||||
async with session.post(f"{cls.url}/backend-api/files", json=data, headers=cls._headers) as response:
|
||||
cls._update_request_args(auth_result, session)
|
||||
await raise_for_status(response, "Create file failed")
|
||||
image_data = {
|
||||
file_data.update(
|
||||
{
|
||||
**data,
|
||||
**await response.json(),
|
||||
"mime_type": is_accepted_format(data_bytes),
|
||||
"mime_type": mime_type,
|
||||
"extension": extension,
|
||||
"height": image.height,
|
||||
"width": image.width
|
||||
}
|
||||
)
|
||||
# Put the image bytes to the upload URL and check the status
|
||||
await asyncio.sleep(1)
|
||||
async with session.put(
|
||||
image_data["upload_url"],
|
||||
file_data["upload_url"],
|
||||
data=data_bytes,
|
||||
headers={
|
||||
**UPLOAD_HEADERS,
|
||||
"Content-Type": image_data["mime_type"],
|
||||
"Content-Type": file_data["mime_type"],
|
||||
"x-ms-blob-type": "BlockBlob",
|
||||
"x-ms-version": "2020-04-08",
|
||||
"Origin": "https://chatgpt.com",
|
||||
|
|
@ -182,15 +194,22 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
await raise_for_status(response)
|
||||
# Post the file ID to the service and get the download URL
|
||||
async with session.post(
|
||||
f"{cls.url}/backend-api/files/{image_data['file_id']}/uploaded",
|
||||
f"{cls.url}/backend-api/files/{file_data['file_id']}/uploaded",
|
||||
json={},
|
||||
headers=auth_result.headers
|
||||
) as response:
|
||||
cls._update_request_args(auth_result, session)
|
||||
await raise_for_status(response, "Get download url failed")
|
||||
image_data["download_url"] = (await response.json())["download_url"]
|
||||
return ImageRequest(image_data)
|
||||
return [await upload_image(image, image_name) for image, image_name in media]
|
||||
uploaded_data = await response.json()
|
||||
file_data["download_url"] = uploaded_data["download_url"]
|
||||
return ImageRequest(file_data)
|
||||
|
||||
medias = []
|
||||
for item in media:
|
||||
item = item if isinstance(item, tuple) else (item,)
|
||||
__uploaded_media = await upload_file(*item)
|
||||
medias.append(__uploaded_media)
|
||||
return medias
|
||||
|
||||
@classmethod
|
||||
def create_messages(cls, messages: Messages, image_requests: ImageRequest = None, system_hints: list = None):
|
||||
|
|
@ -237,19 +256,28 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
"size_bytes": image_request.get("file_size"),
|
||||
"width": image_request.get("width"),
|
||||
}
|
||||
for image_request in image_requests],
|
||||
for image_request in image_requests
|
||||
# Add For Images Only
|
||||
if image_request.get("use_case") == "multimodal"
|
||||
],
|
||||
messages[-1]["content"]["parts"][0]]
|
||||
}
|
||||
# Add the metadata object with the attachments
|
||||
messages[-1]["metadata"] = {
|
||||
"attachments": [{
|
||||
"height": image_request.get("height"),
|
||||
"id": image_request.get("file_id"),
|
||||
"mimeType": image_request.get("mime_type"),
|
||||
"name": image_request.get("file_name"),
|
||||
"size": image_request.get("file_size"),
|
||||
**(
|
||||
{
|
||||
"height": image_request.get("height"),
|
||||
"width": image_request.get("width"),
|
||||
}
|
||||
if image_request.get("use_case") == "multimodal"
|
||||
else {}
|
||||
),
|
||||
}
|
||||
for image_request in image_requests]
|
||||
}
|
||||
return messages
|
||||
|
|
@ -308,6 +336,8 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
return_conversation: bool = True,
|
||||
web_search: bool = False,
|
||||
prompt: str = None,
|
||||
conversation_mode=None,
|
||||
temporary=False,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
"""
|
||||
|
|
@ -353,11 +383,12 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
async with session.get(cls.url, headers=cls._headers) as response:
|
||||
cls._update_request_args(auth_result, session)
|
||||
await raise_for_status(response)
|
||||
try:
|
||||
image_requests = await cls.upload_images(session, auth_result, media)
|
||||
except Exception as e:
|
||||
debug.error("OpenaiChat: Upload image failed")
|
||||
debug.error(e)
|
||||
|
||||
# try:
|
||||
image_requests = await cls.upload_files(session, auth_result, media)
|
||||
# except Exception as e:
|
||||
# debug.error("OpenaiChat: Upload image failed")
|
||||
# debug.error(e)
|
||||
try:
|
||||
model = cls.get_model(model)
|
||||
except ModelNotFoundError:
|
||||
|
|
@ -370,6 +401,10 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
conversation = Conversation(None, str(uuid.uuid4()), getattr(auth_result, "cookies", {}).get("oai-did"))
|
||||
else:
|
||||
conversation = copy(conversation)
|
||||
|
||||
if conversation_mode is None:
|
||||
conversation_mode = {"kind": "primary_assistant"}
|
||||
|
||||
if getattr(auth_result, "cookies", {}).get("oai-did") != getattr(conversation, "user_id", None):
|
||||
conversation = Conversation(None, str(uuid.uuid4()))
|
||||
if cls._api_key is None:
|
||||
|
|
@ -394,6 +429,9 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
"supports_buffering": True,
|
||||
"supported_encodings": ["v1"]
|
||||
}
|
||||
if temporary:
|
||||
data["history_and_training_disabled"] = True
|
||||
|
||||
async with session.post(
|
||||
prepare_url,
|
||||
json=data,
|
||||
|
|
@ -434,11 +472,11 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
user_agent=user_agent,
|
||||
proof_token=proof_token
|
||||
)
|
||||
[debug.log(text) for text in (
|
||||
# [debug.log(text) for text in (
|
||||
#f"Arkose: {'False' if not need_arkose else auth_result.arkose_token[:12]+'...'}",
|
||||
#f"Proofofwork: {'False' if proofofwork is None else proofofwork[:12]+'...'}",
|
||||
#f"AccessToken: {'False' if cls._api_key is None else cls._api_key[:12]+'...'}",
|
||||
)]
|
||||
# )]
|
||||
data = {
|
||||
"action": "next",
|
||||
"parent_message_id": conversation.message_id,
|
||||
|
|
@ -453,6 +491,9 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
"client_contextual_info":{"is_dark_mode":False,"time_since_loaded":random.randint(20, 500),"page_height":578,"page_width":1850,"pixel_ratio":1,"screen_height":1080,"screen_width":1920},
|
||||
"paragen_cot_summary_display_override":"allow"
|
||||
}
|
||||
if temporary:
|
||||
data["history_and_training_disabled"] = True
|
||||
|
||||
if conversation.conversation_id is not None:
|
||||
data["conversation_id"] = conversation.conversation_id
|
||||
debug.log(f"OpenaiChat: Use conversation: {conversation.conversation_id}")
|
||||
|
|
|
|||
|
|
@ -191,6 +191,124 @@ def is_accepted_format(binary_data: bytes) -> str:
|
|||
else:
|
||||
raise ValueError("Invalid image format (from magic code).")
|
||||
|
||||
|
||||
|
||||
def detect_file_type(binary_data: bytes) -> tuple[str, str]:
    """
    Detect file type from magic number / header signature.

    Signatures are tested in a fixed order and the first match wins, so a
    specific container check (e.g. JAR) must run before the generic one it
    overlaps with (e.g. ZIP).

    Args:
        binary_data (bytes): File binary data

    Returns:
        tuple: (extension, MIME type); the extension includes the leading dot

    Raises:
        ValueError: If file type is unknown
    """

    # ---- Images ----
    if binary_data.startswith(b"\xff\xd8\xff"):
        return ".jpg", "image/jpeg"
    if binary_data.startswith(b"\x89PNG\r\n\x1a\n"):
        return ".png", "image/png"
    if binary_data.startswith((b"GIF87a", b"GIF89a")):
        return ".gif", "image/gif"
    if binary_data.startswith(b"RIFF") and binary_data[8:12] == b"WEBP":
        return ".webp", "image/webp"
    if binary_data.startswith(b"BM"):
        return ".bmp", "image/bmp"
    if binary_data.startswith((b"II*\x00", b"MM\x00*")):
        return ".tiff", "image/tiff"
    if binary_data.startswith(b"\x00\x00\x01\x00"):
        return ".ico", "image/x-icon"
    if binary_data.startswith(b"\x00\x00\x00\x0cjP \x0d\x0a\x87\x0a"):
        return ".jp2", "image/jp2"
    if binary_data[4:8] == b"ftyp":
        # "ftyp" at offset 4 is followed by a four-byte brand. Only the image
        # brands are decided here; any other brand must keep falling through
        # so the MP4 check further down can claim it (previously such input
        # made the function return None).
        brand = binary_data[8:12]
        if brand in (b"heic", b"heix", b"hevc", b"hevx", b"mif1", b"msf1"):
            return ".heic", "image/heif"
        if brand == b"avif":
            return ".avif", "image/avif"

    # Text-based signatures tolerate leading whitespace.
    stripped = binary_data.lstrip()
    # NOTE(review): every document starting with "<?xml" is reported as SVG,
    # so generic XML never gets its own type. Kept as-is to preserve current
    # behaviour; confirm whether XML needs to be told apart from SVG.
    if stripped.startswith((b"<?xml", b"<svg")):
        return ".svg", "image/svg+xml"

    # ---- Documents ----
    if binary_data.startswith(b"%PDF"):
        return ".pdf", "application/pdf"
    if binary_data.startswith(b"PK\x03\x04"):
        # A JAR is a ZIP archive; it has to be recognised before the generic
        # ZIP result is returned, otherwise this check can never match.
        if b"META-INF" in binary_data[:200]:
            return ".jar", "application/java-archive"
        # could be docx/xlsx/pptx/apk/odt
        return ".zip", "application/zip-based"
    if binary_data.startswith(b"\xd0\xcf\x11\xe0"):
        return ".doc", "application/vnd.ms-office"
    if binary_data.startswith(b"{\\rtf"):
        return ".rtf", "application/rtf"
    if binary_data.startswith(b"7z\xbc\xaf\x27\x1c"):
        return ".7z", "application/x-7z-compressed"
    if binary_data.startswith(b"Rar!\x1a\x07\x00"):
        return ".rar", "application/vnd.rar"
    if binary_data.startswith(b"\x1f\x8b"):
        return ".gz", "application/gzip"
    if binary_data.startswith(b"BZh"):
        return ".bz2", "application/x-bzip2"
    if binary_data.startswith(b"\xfd7zXZ\x00"):
        return ".xz", "application/x-xz"

    # ---- Executables / Libraries ----
    if binary_data.startswith(b"MZ"):
        return ".exe", "application/x-msdownload"
    if binary_data.startswith(b"\x7fELF"):
        return ".elf", "application/x-elf"
    if binary_data.startswith((b"\xca\xfe\xba\xbe", b"\xca\xfe\xd0\x0d")):
        return ".class", "application/java-vm"

    # ---- Audio ----
    if binary_data.startswith(b"ID3") or binary_data[0:2] == b"\xff\xfb":
        return ".mp3", "audio/mpeg"
    if binary_data.startswith(b"OggS"):
        # The "OggS" prefix alone cannot tell Ogg audio from Ogg video, so
        # every Ogg stream is reported as audio.
        return ".ogg", "audio/ogg"
    if binary_data.startswith(b"fLaC"):
        return ".flac", "audio/flac"
    if binary_data.startswith(b"RIFF") and binary_data[8:12] == b"WAVE":
        return ".wav", "audio/wav"
    if binary_data.startswith(b"MThd"):
        return ".mid", "audio/midi"

    # ---- Video ----
    if binary_data.startswith(b"\x00\x00\x00") and b"ftyp" in binary_data[4:12]:
        return ".mp4", "video/mp4"
    if binary_data.startswith(b"RIFF") and binary_data[8:12] == b"AVI ":
        return ".avi", "video/x-msvideo"
    if binary_data.startswith(b"\x1a\x45\xdf\xa3"):
        return ".mkv", "video/webm"
    if binary_data.startswith(b"\x00\x00\x01\xba"):
        return ".mpg", "video/mpeg"

    # ---- Text / Scripts ----
    if stripped.startswith(b"#!"):
        return ".sh", "text/x-script"
    if stripped.startswith((b"{", b"[")):
        return ".json", "application/json"
    if stripped.startswith(b"<"):
        return ".html", "text/html"
    # Fallback: treat the data as plain text when its first 100 bytes are all
    # tab/LF/CR or in the 32..127 range.
    if all(32 <= b <= 127 or b in (9, 10, 13) for b in binary_data[:100]):
        return ".txt", "text/plain"

    raise ValueError("Unknown or unsupported file type")
|
||||
|
||||
|
||||
def extract_data_uri(data_uri: str) -> bytes:
|
||||
"""
|
||||
Extracts the binary data from the given data URI.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue