feat: improve file handling and streamline provider implementations

- Added file upload usage example with bucket_id in docs/file.md
- Fixed ARTA provider by refactoring error handling with a new raise_error function
- Simplified aspect_ratio handling in ARTA with proper default
- Improved AllenAI provider by cleaning up image handling logic
- Fixed Blackbox provider's media handling to properly process images
- Updated file tools to handle URL downloads correctly
- Fixed bucket_id pattern matching in ToolHandler.process_bucket_tool
- Cleaned up imports in typing.py by removing unnecessary sys import
- Fixed inconsistent function parameters in g4f/tools/files.py
- Fixed return value of upload_and_process function to return bucket_id
This commit is contained in:
hlohaus 2025-04-08 19:00:44 +02:00
parent fa36dccf16
commit c083f85206
7 changed files with 82 additions and 104 deletions

View file

@ -75,16 +75,45 @@ def upload_and_process(files_or_urls, bucket_id=None):
else:
print(f"Unhandled SSE event: {line}")
response.close()
return bucket_id5
# Example with URLs
urls = [{"url": "https://github.com/xtekky/gpt4free/issues"}]
bucket_id = upload_and_process(urls)
#Example with files
files = {'files': open('document.pdf', 'rb'), 'files': open('data.json', 'rb')}
files = {'files': ('document.pdf', open('document.pdf', 'rb'))}
bucket_id = upload_and_process(files)
```
**Usage of Uploaded Files:**
```python
from g4f.client import Client
# Enable debug mode
import g4f.debug
g4f.debug.logging = True
client = Client()
# Upload example file
files = {'files': ('demo.docx', open('demo.docx', 'rb'))}
bucket_id = upload_and_process(files)
# Send request with file:
response = client.chat.completions.create(
[{"role": "user", "content": [
{"type": "text", "text": "Discribe this file."},
{"bucket_id": bucket_id}
]}],
)
print(response.choices[0].message.content)
```
**Example Output:**
```
This document is a demonstration of the DOCX Input plugin capabilities in the software ...
```
**Example Usage (JavaScript):**

View file

@ -5,7 +5,7 @@ import time
import json
import random
from pathlib import Path
from aiohttp import ClientSession
from aiohttp import ClientSession, ClientResponse
import asyncio
from ..typing import AsyncResult, Messages
@ -92,17 +92,8 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
# Step 1: Generate Authentication Token
auth_payload = {"clientType": "CLIENT_TYPE_ANDROID"}
async with session.post(cls.auth_url, json=auth_payload, proxy=proxy) as auth_response:
if auth_response.status >= 400:
error_text = await auth_response.text()
raise ResponseError(f"Failed to obtain authentication token. Status: {auth_response.status}, Response: {error_text}")
try:
auth_data = await auth_response.json()
except Exception as e:
error_text = await auth_response.text()
content_type = auth_response.headers.get('Content-Type', 'unknown')
raise ResponseError(f"Failed to parse auth response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
await raise_error(f"Failed to obtain authentication token", auth_response)
auth_data = await auth_response.json()
auth_token = auth_data.get("idToken")
#refresh_token = auth_data.get("refreshToken")
if not auth_token:
@ -118,17 +109,8 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
"refresh_token": refresh_token,
}
async with session.post(cls.token_refresh_url, data=payload, proxy=proxy) as response:
if response.status >= 400:
error_text = await response.text()
raise ResponseError(f"Failed to refresh token. Status: {response.status}, Response: {error_text}")
try:
response_data = await response.json()
except Exception as e:
error_text = await response.text()
content_type = response.headers.get('Content-Type', 'unknown')
raise ResponseError(f"Failed to parse token refresh response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
await raise_error(f"Failed to refresh token", response)
response_data = await response.json()
return response_data.get("id_token"), response_data.get("refresh_token")
@classmethod
@ -156,7 +138,7 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
n: int = 1,
guidance_scale: int = 7,
num_inference_steps: int = 30,
aspect_ratio: str = "1:1",
aspect_ratio: str = None,
seed: int = None,
**kwargs
) -> AsyncResult:
@ -179,7 +161,7 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
"images_num": str(n),
"cfg_scale": str(guidance_scale),
"steps": str(num_inference_steps),
"aspect_ratio": aspect_ratio,
"aspect_ratio": "1:1" if aspect_ratio is None else aspect_ratio,
"seed": str(seed),
}
@ -188,45 +170,26 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
}
async with session.post(cls.image_generation_url, data=image_payload, headers=headers, proxy=proxy) as image_response:
if image_response.status >= 400:
error_text = await image_response.text()
raise ResponseError(f"Failed to initiate image generation. Status: {image_response.status}, Response: {error_text}")
try:
image_data = await image_response.json()
except Exception as e:
error_text = await image_response.text()
content_type = image_response.headers.get('Content-Type', 'unknown')
raise ResponseError(f"Failed to parse response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
await raise_error(f"Failed to initiate image generation", image_response)
image_data = await image_response.json()
record_id = image_data.get("record_id")
if not record_id:
raise ResponseError(f"Failed to initiate image generation: {image_data}")
# Step 3: Check Generation Status
status_url = cls.status_check_url.format(record_id=record_id)
counter = 4
start_time = time.time()
last_status = None
while True:
async with session.get(status_url, headers=headers, proxy=proxy) as status_response:
if status_response.status >= 400:
error_text = await status_response.text()
raise ResponseError(f"Failed to check image generation status. Status: {status_response.status}, Response: {error_text}")
try:
status_data = await status_response.json()
except Exception as e:
error_text = await status_response.text()
content_type = status_response.headers.get('Content-Type', 'unknown')
raise ResponseError(f"Failed to parse status response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
await raise_error(f"Failed to check image generation status", status_response)
status_data = await status_response.json()
status = status_data.get("status")
if status == "DONE":
image_urls = [image["url"] for image in status_data.get("response", [])]
duration = time.time() - start_time
yield Reasoning(label="Generated", status=f"{n} image(s) in {duration:.2f}s")
yield Reasoning(label="Generated", status=f"{n} image in {duration:.2f}s" if n == 1 else f"{n} images in {duration:.2f}s")
yield ImageResponse(urls=image_urls, alt=prompt)
return
elif status in ("IN_QUEUE", "IN_PROGRESS"):
@ -238,4 +201,11 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
yield Reasoning(label="Generating")
await asyncio.sleep(2) # Poll every 2 seconds
else:
raise ResponseError(f"Image generation failed with status: {status}")
raise ResponseError(f"Image generation failed with status: {status}")
async def raise_error(response: ClientResponse, message: str):
if response.ok:
return
error_text = await response.text()
content_type = response.headers.get('Content-Type', 'unknown')
raise ResponseError(f"{message}. Content-Type: {content_type}, Response: {error_text}")

View file

@ -83,17 +83,7 @@ class AllenAI(AsyncGeneratorProvider, ProviderModelMixin):
) -> AsyncResult:
actual_model = cls.get_model(model)
# Use format_image_prompt for vision models when media is provided
if media is not None and len(media) > 0:
# For vision models, use format_image_prompt
if actual_model in cls.vision_models:
prompt = format_image_prompt(messages)
else:
# For non-vision models with images, still use the last user message
prompt = get_last_user_message(messages)
else:
# For text-only messages, use the standard format
prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
# Determine the correct host for the model
if host is None:
@ -157,18 +147,16 @@ class AllenAI(AsyncGeneratorProvider, ProviderModelMixin):
if media is not None and len(media) > 0:
conversation = Conversation(actual_model)
# Add image if provided
if media is not None and len(media) > 0:
# For each image in the media list (using merge_media to handle different formats)
for image, image_name in merge_media(media, messages):
image_bytes = to_bytes(image)
form_data.extend([
f'--{boundary}\r\n'
f'Content-Disposition: form-data; name="files"; filename="{image_name}"\r\n'
f'Content-Type: {is_accepted_format(image_bytes)}\r\n\r\n'
])
form_data.append(image_bytes.decode('latin1'))
form_data.append('\r\n')
# For each image in the media list (using merge_media to handle different formats)
for image, image_name in merge_media(media, messages):
image_bytes = to_bytes(image)
form_data.extend([
f'--{boundary}\r\n'
f'Content-Disposition: form-data; name="files"; filename="{image_name}"\r\n'
f'Content-Type: {is_accepted_format(image_bytes)}\r\n\r\n'
])
form_data.append(image_bytes.decode('latin1'))
form_data.append('\r\n')
form_data.append(f'--{boundary}--\r\n')
data = "".join(form_data).encode('latin1')
@ -182,11 +170,7 @@ class AllenAI(AsyncGeneratorProvider, ProviderModelMixin):
await raise_for_status(response)
current_parent = None
async for chunk in response.content:
if not chunk:
continue
decoded = chunk.decode(errors="ignore")
for line in decoded.splitlines():
async for line in response.content:
line = line.strip()
if not line:
continue

View file

@ -581,14 +581,15 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
}
current_messages.append(current_msg)
if media is not None:
media = list(merge_media(media, messages))
if media:
current_messages[-1]['data'] = {
"imagesData": [
{
"filePath": f"/{image_name}",
"contents": to_data_uri(image)
}
for image, image_name in merge_media(media, messages)
for image, image_name in media
],
"fileText": "",
"title": ""

View file

@ -518,11 +518,12 @@ async def async_read_and_download_urls(bucket_dir: Path, delete_files: bool = Fa
if urls:
count = 0
with open(os.path.join(bucket_dir, FILE_LIST), 'a') as f:
async for filename in download_urls(bucket_dir, urls):
f.write(f"{filename}\n")
if event_stream:
count += 1
yield f'data: {json.dumps({"action": "download", "count": count})}\n\n'
for url in urls:
async for filename in download_urls(bucket_dir, **url):
f.write(f"{filename}\n")
if event_stream:
count += 1
yield f'data: {json.dumps({"action": "download", "count": count})}\n\n'
def stream_chunks(bucket_dir: Path, delete_files: bool = False, refine_chunks_with_spacy: bool = False, event_stream: bool = False) -> Iterator[str]:
size = 0

View file

@ -80,7 +80,7 @@ class ToolHandler:
has_bucket = False
for message in messages:
if "content" in message and isinstance(message["content"], str):
new_message_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"])
new_message_content = re.sub(r'{"bucket_id":\s*"([^"]*)"}', on_bucket, message["content"])
if new_message_content != message["content"]:
has_bucket = True
message["content"] = new_message_content
@ -97,29 +97,28 @@ class ToolHandler:
"""Process all tool calls and return updated messages and kwargs"""
if not tool_calls:
return messages, {}
extra_kwargs = {}
messages = messages.copy()
sources = None
for tool in tool_calls:
if tool.get("type") != "function":
continue
function_name = tool.get("function", {}).get("name")
if function_name == TOOL_NAMES["SEARCH"]:
messages, sources = await ToolHandler.process_search_tool(messages, tool)
elif function_name == TOOL_NAMES["CONTINUE"]:
messages, kwargs = ToolHandler.process_continue_tool(messages, tool, provider)
extra_kwargs.update(kwargs)
elif function_name == TOOL_NAMES["BUCKET"]:
messages = ToolHandler.process_bucket_tool(messages, tool)
return messages, sources, extra_kwargs
return messages, sources, extra_kwargs
class AuthManager:
"""Handles API key management"""
@ -128,13 +127,13 @@ class AuthManager:
def get_api_key_file(cls) -> Path:
"""Get the path to the API key file for a provider"""
return Path(get_cookies_dir()) / f"api_key_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
@staticmethod
def load_api_key(provider: Any) -> Optional[str]:
"""Load API key from config file if needed"""
if not getattr(provider, "needs_auth", False):
return None
auth_file = AuthManager.get_api_key_file(provider)
try:
if auth_file.exists():

View file

@ -1,6 +1,5 @@
import sys
import os
from typing import Any, AsyncGenerator, Generator, AsyncIterator, Iterator, NewType, Tuple, Union, List, Dict, Type, IO, Optional
from typing import Any, AsyncGenerator, Generator, AsyncIterator, Iterator, NewType, Tuple, Union, List, Dict, Type, IO, Optional, TypedDict
try:
from PIL.Image import Image
@ -8,11 +7,6 @@ except ImportError:
class Image:
pass
if sys.version_info >= (3, 8):
from typing import TypedDict
else:
from typing_extensions import TypedDict
from .providers.response import ResponseType
SHA256 = NewType('sha_256_hash', str)