mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
feat: improve file handling and streamline provider implementations
- Added file upload usage example with bucket_id in docs/file.md - Fixed ARTA provider by refactoring error handling with a new raise_error function - Simplified aspect_ratio handling in ARTA with proper default - Improved AllenAI provider by cleaning up image handling logic - Fixed Blackbox provider's media handling to properly process images - Updated file tools to handle URL downloads correctly - Fixed bucket_id pattern matching in ToolHandler.process_bucket_tool - Cleaned up imports in typing.py by removing unnecessary sys import - Fixed inconsistent function parameters in g4f/tools/files.py - Fixed return value of upload_and_process function to return bucket_id
This commit is contained in:
parent
fa36dccf16
commit
c083f85206
7 changed files with 82 additions and 104 deletions
31
docs/file.md
31
docs/file.md
|
|
@ -75,16 +75,45 @@ def upload_and_process(files_or_urls, bucket_id=None):
|
|||
else:
|
||||
print(f"Unhandled SSE event: {line}")
|
||||
response.close()
|
||||
return bucket_id
|
||||
|
||||
# Example with URLs
|
||||
urls = [{"url": "https://github.com/xtekky/gpt4free/issues"}]
|
||||
bucket_id = upload_and_process(urls)
|
||||
|
||||
# Example with files
|
||||
files = {'files': open('document.pdf', 'rb'), 'files': open('data.json', 'rb')}
|
||||
files = {'files': ('document.pdf', open('document.pdf', 'rb'))}
|
||||
bucket_id = upload_and_process(files)
|
||||
```
|
||||
|
||||
**Usage of Uploaded Files:**
|
||||
```python
|
||||
from g4f.client import Client
|
||||
|
||||
# Enable debug mode
|
||||
import g4f.debug
|
||||
g4f.debug.logging = True
|
||||
|
||||
client = Client()
|
||||
|
||||
# Upload example file
|
||||
files = {'files': ('demo.docx', open('demo.docx', 'rb'))}
|
||||
bucket_id = upload_and_process(files)
|
||||
|
||||
# Send request with file:
|
||||
response = client.chat.completions.create(
|
||||
[{"role": "user", "content": [
|
||||
{"type": "text", "text": "Describe this file."},
|
||||
{"bucket_id": bucket_id}
|
||||
]}],
|
||||
)
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
**Example Output:**
|
||||
```
|
||||
This document is a demonstration of the DOCX Input plugin capabilities in the software ...
|
||||
```
|
||||
|
||||
**Example Usage (JavaScript):**
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import time
|
|||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
from aiohttp import ClientSession
|
||||
from aiohttp import ClientSession, ClientResponse
|
||||
import asyncio
|
||||
|
||||
from ..typing import AsyncResult, Messages
|
||||
|
|
@ -92,17 +92,8 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
# Step 1: Generate Authentication Token
|
||||
auth_payload = {"clientType": "CLIENT_TYPE_ANDROID"}
|
||||
async with session.post(cls.auth_url, json=auth_payload, proxy=proxy) as auth_response:
|
||||
if auth_response.status >= 400:
|
||||
error_text = await auth_response.text()
|
||||
raise ResponseError(f"Failed to obtain authentication token. Status: {auth_response.status}, Response: {error_text}")
|
||||
|
||||
try:
|
||||
auth_data = await auth_response.json()
|
||||
except Exception as e:
|
||||
error_text = await auth_response.text()
|
||||
content_type = auth_response.headers.get('Content-Type', 'unknown')
|
||||
raise ResponseError(f"Failed to parse auth response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
|
||||
|
||||
await raise_error(f"Failed to obtain authentication token", auth_response)
|
||||
auth_data = await auth_response.json()
|
||||
auth_token = auth_data.get("idToken")
|
||||
#refresh_token = auth_data.get("refreshToken")
|
||||
if not auth_token:
|
||||
|
|
@ -118,17 +109,8 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
"refresh_token": refresh_token,
|
||||
}
|
||||
async with session.post(cls.token_refresh_url, data=payload, proxy=proxy) as response:
|
||||
if response.status >= 400:
|
||||
error_text = await response.text()
|
||||
raise ResponseError(f"Failed to refresh token. Status: {response.status}, Response: {error_text}")
|
||||
|
||||
try:
|
||||
response_data = await response.json()
|
||||
except Exception as e:
|
||||
error_text = await response.text()
|
||||
content_type = response.headers.get('Content-Type', 'unknown')
|
||||
raise ResponseError(f"Failed to parse token refresh response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
|
||||
|
||||
await raise_error(f"Failed to refresh token", response)
|
||||
response_data = await response.json()
|
||||
return response_data.get("id_token"), response_data.get("refresh_token")
|
||||
|
||||
@classmethod
|
||||
|
|
@ -156,7 +138,7 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
n: int = 1,
|
||||
guidance_scale: int = 7,
|
||||
num_inference_steps: int = 30,
|
||||
aspect_ratio: str = "1:1",
|
||||
aspect_ratio: str = None,
|
||||
seed: int = None,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
|
|
@ -179,7 +161,7 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
"images_num": str(n),
|
||||
"cfg_scale": str(guidance_scale),
|
||||
"steps": str(num_inference_steps),
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"aspect_ratio": "1:1" if aspect_ratio is None else aspect_ratio,
|
||||
"seed": str(seed),
|
||||
}
|
||||
|
||||
|
|
@ -188,45 +170,26 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
}
|
||||
|
||||
async with session.post(cls.image_generation_url, data=image_payload, headers=headers, proxy=proxy) as image_response:
|
||||
if image_response.status >= 400:
|
||||
error_text = await image_response.text()
|
||||
raise ResponseError(f"Failed to initiate image generation. Status: {image_response.status}, Response: {error_text}")
|
||||
|
||||
try:
|
||||
image_data = await image_response.json()
|
||||
except Exception as e:
|
||||
error_text = await image_response.text()
|
||||
content_type = image_response.headers.get('Content-Type', 'unknown')
|
||||
raise ResponseError(f"Failed to parse response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
|
||||
|
||||
await raise_error(f"Failed to initiate image generation", image_response)
|
||||
image_data = await image_response.json()
|
||||
record_id = image_data.get("record_id")
|
||||
if not record_id:
|
||||
raise ResponseError(f"Failed to initiate image generation: {image_data}")
|
||||
|
||||
# Step 3: Check Generation Status
|
||||
status_url = cls.status_check_url.format(record_id=record_id)
|
||||
counter = 4
|
||||
start_time = time.time()
|
||||
last_status = None
|
||||
while True:
|
||||
async with session.get(status_url, headers=headers, proxy=proxy) as status_response:
|
||||
if status_response.status >= 400:
|
||||
error_text = await status_response.text()
|
||||
raise ResponseError(f"Failed to check image generation status. Status: {status_response.status}, Response: {error_text}")
|
||||
|
||||
try:
|
||||
status_data = await status_response.json()
|
||||
except Exception as e:
|
||||
error_text = await status_response.text()
|
||||
content_type = status_response.headers.get('Content-Type', 'unknown')
|
||||
raise ResponseError(f"Failed to parse status response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
|
||||
|
||||
await raise_error(f"Failed to check image generation status", status_response)
|
||||
status_data = await status_response.json()
|
||||
status = status_data.get("status")
|
||||
|
||||
if status == "DONE":
|
||||
image_urls = [image["url"] for image in status_data.get("response", [])]
|
||||
duration = time.time() - start_time
|
||||
yield Reasoning(label="Generated", status=f"{n} image(s) in {duration:.2f}s")
|
||||
yield Reasoning(label="Generated", status=f"{n} image in {duration:.2f}s" if n == 1 else f"{n} images in {duration:.2f}s")
|
||||
yield ImageResponse(urls=image_urls, alt=prompt)
|
||||
return
|
||||
elif status in ("IN_QUEUE", "IN_PROGRESS"):
|
||||
|
|
@ -239,3 +202,10 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
await asyncio.sleep(2) # Poll every 2 seconds
|
||||
else:
|
||||
raise ResponseError(f"Image generation failed with status: {status}")
|
||||
|
||||
async def raise_error(message: str, response: "ClientResponse"):
    """Raise ``ResponseError`` when *response* is not a successful HTTP reply.

    Args:
        message: Short description of the step that failed; it becomes the
            prefix of the raised error text.
        response: The HTTP response to inspect.

    Raises:
        ResponseError: If ``response.ok`` is false. The error text contains
            *message*, the response's ``Content-Type`` header (or
            ``'unknown'`` when absent) and the response body text.
    """
    # Callers invoke this as raise_error(message, response). The previous
    # declaration listed the parameters the other way round, so ``.ok`` was
    # looked up on the message string. Keep accepting the old positional
    # order (response, message) so no existing caller breaks.
    if isinstance(response, str):
        message, response = response, message
    if response.ok:
        return
    error_text = await response.text()
    content_type = response.headers.get('Content-Type', 'unknown')
    raise ResponseError(f"{message}. Content-Type: {content_type}, Response: {error_text}")
|
||||
|
|
@ -83,17 +83,7 @@ class AllenAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
) -> AsyncResult:
|
||||
actual_model = cls.get_model(model)
|
||||
|
||||
# Use format_image_prompt for vision models when media is provided
|
||||
if media is not None and len(media) > 0:
|
||||
# For vision models, use format_image_prompt
|
||||
if actual_model in cls.vision_models:
|
||||
prompt = format_image_prompt(messages)
|
||||
else:
|
||||
# For non-vision models with images, still use the last user message
|
||||
prompt = get_last_user_message(messages)
|
||||
else:
|
||||
# For text-only messages, use the standard format
|
||||
prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
|
||||
prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
|
||||
|
||||
# Determine the correct host for the model
|
||||
if host is None:
|
||||
|
|
@ -157,18 +147,16 @@ class AllenAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
if media is not None and len(media) > 0:
|
||||
conversation = Conversation(actual_model)
|
||||
|
||||
# Add image if provided
|
||||
if media is not None and len(media) > 0:
|
||||
# For each image in the media list (using merge_media to handle different formats)
|
||||
for image, image_name in merge_media(media, messages):
|
||||
image_bytes = to_bytes(image)
|
||||
form_data.extend([
|
||||
f'--{boundary}\r\n'
|
||||
f'Content-Disposition: form-data; name="files"; filename="{image_name}"\r\n'
|
||||
f'Content-Type: {is_accepted_format(image_bytes)}\r\n\r\n'
|
||||
])
|
||||
form_data.append(image_bytes.decode('latin1'))
|
||||
form_data.append('\r\n')
|
||||
# For each image in the media list (using merge_media to handle different formats)
|
||||
for image, image_name in merge_media(media, messages):
|
||||
image_bytes = to_bytes(image)
|
||||
form_data.extend([
|
||||
f'--{boundary}\r\n'
|
||||
f'Content-Disposition: form-data; name="files"; filename="{image_name}"\r\n'
|
||||
f'Content-Type: {is_accepted_format(image_bytes)}\r\n\r\n'
|
||||
])
|
||||
form_data.append(image_bytes.decode('latin1'))
|
||||
form_data.append('\r\n')
|
||||
|
||||
form_data.append(f'--{boundary}--\r\n')
|
||||
data = "".join(form_data).encode('latin1')
|
||||
|
|
@ -182,11 +170,7 @@ class AllenAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
await raise_for_status(response)
|
||||
current_parent = None
|
||||
|
||||
async for chunk in response.content:
|
||||
if not chunk:
|
||||
continue
|
||||
decoded = chunk.decode(errors="ignore")
|
||||
for line in decoded.splitlines():
|
||||
async for line in response.content:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -581,14 +581,15 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
}
|
||||
current_messages.append(current_msg)
|
||||
|
||||
if media is not None:
|
||||
media = list(merge_media(media, messages))
|
||||
if media:
|
||||
current_messages[-1]['data'] = {
|
||||
"imagesData": [
|
||||
{
|
||||
"filePath": f"/{image_name}",
|
||||
"contents": to_data_uri(image)
|
||||
}
|
||||
for image, image_name in merge_media(media, messages)
|
||||
for image, image_name in media
|
||||
],
|
||||
"fileText": "",
|
||||
"title": ""
|
||||
|
|
|
|||
|
|
@ -518,11 +518,12 @@ async def async_read_and_download_urls(bucket_dir: Path, delete_files: bool = Fa
|
|||
if urls:
|
||||
count = 0
|
||||
with open(os.path.join(bucket_dir, FILE_LIST), 'a') as f:
|
||||
async for filename in download_urls(bucket_dir, urls):
|
||||
f.write(f"{filename}\n")
|
||||
if event_stream:
|
||||
count += 1
|
||||
yield f'data: {json.dumps({"action": "download", "count": count})}\n\n'
|
||||
for url in urls:
|
||||
async for filename in download_urls(bucket_dir, **url):
|
||||
f.write(f"{filename}\n")
|
||||
if event_stream:
|
||||
count += 1
|
||||
yield f'data: {json.dumps({"action": "download", "count": count})}\n\n'
|
||||
|
||||
def stream_chunks(bucket_dir: Path, delete_files: bool = False, refine_chunks_with_spacy: bool = False, event_stream: bool = False) -> Iterator[str]:
|
||||
size = 0
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ class ToolHandler:
|
|||
has_bucket = False
|
||||
for message in messages:
|
||||
if "content" in message and isinstance(message["content"], str):
|
||||
new_message_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"])
|
||||
new_message_content = re.sub(r'{"bucket_id":\s*"([^"]*)"}', on_bucket, message["content"])
|
||||
if new_message_content != message["content"]:
|
||||
has_bucket = True
|
||||
message["content"] = new_message_content
|
||||
|
|
@ -120,7 +120,6 @@ class ToolHandler:
|
|||
|
||||
return messages, sources, extra_kwargs
|
||||
|
||||
|
||||
class AuthManager:
|
||||
"""Handles API key management"""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import sys
|
||||
import os
|
||||
from typing import Any, AsyncGenerator, Generator, AsyncIterator, Iterator, NewType, Tuple, Union, List, Dict, Type, IO, Optional
|
||||
from typing import Any, AsyncGenerator, Generator, AsyncIterator, Iterator, NewType, Tuple, Union, List, Dict, Type, IO, Optional, TypedDict
|
||||
|
||||
try:
|
||||
from PIL.Image import Image
|
||||
|
|
@ -8,11 +7,6 @@ except ImportError:
|
|||
class Image:
|
||||
pass
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from .providers.response import ResponseType
|
||||
|
||||
SHA256 = NewType('sha_256_hash', str)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue