feat: improve file handling and streamline provider implementations

- Added file upload usage example with bucket_id in docs/file.md
- Fixed ARTA provider by refactoring error handling with a new raise_error function
- Simplified aspect_ratio handling in ARTA with proper default
- Improved AllenAI provider by cleaning up image handling logic
- Fixed Blackbox provider's media handling to properly process images
- Updated file tools to handle URL downloads correctly
- Fixed bucket_id pattern matching in ToolHandler.process_bucket_tool
- Cleaned up imports in typing.py by removing unnecessary sys import
- Fixed inconsistent function parameters in g4f/tools/files.py
- Fixed return value of upload_and_process function to return bucket_id
This commit is contained in:
hlohaus 2025-04-08 19:00:44 +02:00
parent fa36dccf16
commit c083f85206
7 changed files with 82 additions and 104 deletions

View file

@ -75,16 +75,45 @@ def upload_and_process(files_or_urls, bucket_id=None):
else: else:
print(f"Unhandled SSE event: {line}") print(f"Unhandled SSE event: {line}")
response.close() response.close()
return bucket_id5
# Example with URLs # Example with URLs
urls = [{"url": "https://github.com/xtekky/gpt4free/issues"}] urls = [{"url": "https://github.com/xtekky/gpt4free/issues"}]
bucket_id = upload_and_process(urls) bucket_id = upload_and_process(urls)
#Example with files #Example with files
files = {'files': open('document.pdf', 'rb'), 'files': open('data.json', 'rb')} files = {'files': ('document.pdf', open('document.pdf', 'rb'))}
bucket_id = upload_and_process(files) bucket_id = upload_and_process(files)
``` ```
**Usage of Uploaded Files:**
```python
from g4f.client import Client
# Enable debug mode
import g4f.debug
g4f.debug.logging = True
client = Client()
# Upload example file
files = {'files': ('demo.docx', open('demo.docx', 'rb'))}
bucket_id = upload_and_process(files)
# Send request with file:
response = client.chat.completions.create(
[{"role": "user", "content": [
{"type": "text", "text": "Discribe this file."},
{"bucket_id": bucket_id}
]}],
)
print(response.choices[0].message.content)
```
**Example Output:**
```
This document is a demonstration of the DOCX Input plugin capabilities in the software ...
```
**Example Usage (JavaScript):** **Example Usage (JavaScript):**

View file

@ -5,7 +5,7 @@ import time
import json import json
import random import random
from pathlib import Path from pathlib import Path
from aiohttp import ClientSession from aiohttp import ClientSession, ClientResponse
import asyncio import asyncio
from ..typing import AsyncResult, Messages from ..typing import AsyncResult, Messages
@ -92,17 +92,8 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
# Step 1: Generate Authentication Token # Step 1: Generate Authentication Token
auth_payload = {"clientType": "CLIENT_TYPE_ANDROID"} auth_payload = {"clientType": "CLIENT_TYPE_ANDROID"}
async with session.post(cls.auth_url, json=auth_payload, proxy=proxy) as auth_response: async with session.post(cls.auth_url, json=auth_payload, proxy=proxy) as auth_response:
if auth_response.status >= 400: await raise_error(f"Failed to obtain authentication token", auth_response)
error_text = await auth_response.text() auth_data = await auth_response.json()
raise ResponseError(f"Failed to obtain authentication token. Status: {auth_response.status}, Response: {error_text}")
try:
auth_data = await auth_response.json()
except Exception as e:
error_text = await auth_response.text()
content_type = auth_response.headers.get('Content-Type', 'unknown')
raise ResponseError(f"Failed to parse auth response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
auth_token = auth_data.get("idToken") auth_token = auth_data.get("idToken")
#refresh_token = auth_data.get("refreshToken") #refresh_token = auth_data.get("refreshToken")
if not auth_token: if not auth_token:
@ -118,17 +109,8 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
"refresh_token": refresh_token, "refresh_token": refresh_token,
} }
async with session.post(cls.token_refresh_url, data=payload, proxy=proxy) as response: async with session.post(cls.token_refresh_url, data=payload, proxy=proxy) as response:
if response.status >= 400: await raise_error(f"Failed to refresh token", response)
error_text = await response.text() response_data = await response.json()
raise ResponseError(f"Failed to refresh token. Status: {response.status}, Response: {error_text}")
try:
response_data = await response.json()
except Exception as e:
error_text = await response.text()
content_type = response.headers.get('Content-Type', 'unknown')
raise ResponseError(f"Failed to parse token refresh response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
return response_data.get("id_token"), response_data.get("refresh_token") return response_data.get("id_token"), response_data.get("refresh_token")
@classmethod @classmethod
@ -156,7 +138,7 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
n: int = 1, n: int = 1,
guidance_scale: int = 7, guidance_scale: int = 7,
num_inference_steps: int = 30, num_inference_steps: int = 30,
aspect_ratio: str = "1:1", aspect_ratio: str = None,
seed: int = None, seed: int = None,
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
@ -179,7 +161,7 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
"images_num": str(n), "images_num": str(n),
"cfg_scale": str(guidance_scale), "cfg_scale": str(guidance_scale),
"steps": str(num_inference_steps), "steps": str(num_inference_steps),
"aspect_ratio": aspect_ratio, "aspect_ratio": "1:1" if aspect_ratio is None else aspect_ratio,
"seed": str(seed), "seed": str(seed),
} }
@ -188,45 +170,26 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
} }
async with session.post(cls.image_generation_url, data=image_payload, headers=headers, proxy=proxy) as image_response: async with session.post(cls.image_generation_url, data=image_payload, headers=headers, proxy=proxy) as image_response:
if image_response.status >= 400: await raise_error(f"Failed to initiate image generation", image_response)
error_text = await image_response.text() image_data = await image_response.json()
raise ResponseError(f"Failed to initiate image generation. Status: {image_response.status}, Response: {error_text}")
try:
image_data = await image_response.json()
except Exception as e:
error_text = await image_response.text()
content_type = image_response.headers.get('Content-Type', 'unknown')
raise ResponseError(f"Failed to parse response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
record_id = image_data.get("record_id") record_id = image_data.get("record_id")
if not record_id: if not record_id:
raise ResponseError(f"Failed to initiate image generation: {image_data}") raise ResponseError(f"Failed to initiate image generation: {image_data}")
# Step 3: Check Generation Status # Step 3: Check Generation Status
status_url = cls.status_check_url.format(record_id=record_id) status_url = cls.status_check_url.format(record_id=record_id)
counter = 4
start_time = time.time() start_time = time.time()
last_status = None last_status = None
while True: while True:
async with session.get(status_url, headers=headers, proxy=proxy) as status_response: async with session.get(status_url, headers=headers, proxy=proxy) as status_response:
if status_response.status >= 400: await raise_error(f"Failed to check image generation status", status_response)
error_text = await status_response.text() status_data = await status_response.json()
raise ResponseError(f"Failed to check image generation status. Status: {status_response.status}, Response: {error_text}")
try:
status_data = await status_response.json()
except Exception as e:
error_text = await status_response.text()
content_type = status_response.headers.get('Content-Type', 'unknown')
raise ResponseError(f"Failed to parse status response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
status = status_data.get("status") status = status_data.get("status")
if status == "DONE": if status == "DONE":
image_urls = [image["url"] for image in status_data.get("response", [])] image_urls = [image["url"] for image in status_data.get("response", [])]
duration = time.time() - start_time duration = time.time() - start_time
yield Reasoning(label="Generated", status=f"{n} image(s) in {duration:.2f}s") yield Reasoning(label="Generated", status=f"{n} image in {duration:.2f}s" if n == 1 else f"{n} images in {duration:.2f}s")
yield ImageResponse(urls=image_urls, alt=prompt) yield ImageResponse(urls=image_urls, alt=prompt)
return return
elif status in ("IN_QUEUE", "IN_PROGRESS"): elif status in ("IN_QUEUE", "IN_PROGRESS"):
@ -238,4 +201,11 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
yield Reasoning(label="Generating") yield Reasoning(label="Generating")
await asyncio.sleep(2) # Poll every 2 seconds await asyncio.sleep(2) # Poll every 2 seconds
else: else:
raise ResponseError(f"Image generation failed with status: {status}") raise ResponseError(f"Image generation failed with status: {status}")
async def raise_error(response: ClientResponse, message: str):
if response.ok:
return
error_text = await response.text()
content_type = response.headers.get('Content-Type', 'unknown')
raise ResponseError(f"{message}. Content-Type: {content_type}, Response: {error_text}")

View file

@ -83,17 +83,7 @@ class AllenAI(AsyncGeneratorProvider, ProviderModelMixin):
) -> AsyncResult: ) -> AsyncResult:
actual_model = cls.get_model(model) actual_model = cls.get_model(model)
# Use format_image_prompt for vision models when media is provided prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
if media is not None and len(media) > 0:
# For vision models, use format_image_prompt
if actual_model in cls.vision_models:
prompt = format_image_prompt(messages)
else:
# For non-vision models with images, still use the last user message
prompt = get_last_user_message(messages)
else:
# For text-only messages, use the standard format
prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
# Determine the correct host for the model # Determine the correct host for the model
if host is None: if host is None:
@ -157,18 +147,16 @@ class AllenAI(AsyncGeneratorProvider, ProviderModelMixin):
if media is not None and len(media) > 0: if media is not None and len(media) > 0:
conversation = Conversation(actual_model) conversation = Conversation(actual_model)
# Add image if provided # For each image in the media list (using merge_media to handle different formats)
if media is not None and len(media) > 0: for image, image_name in merge_media(media, messages):
# For each image in the media list (using merge_media to handle different formats) image_bytes = to_bytes(image)
for image, image_name in merge_media(media, messages): form_data.extend([
image_bytes = to_bytes(image) f'--{boundary}\r\n'
form_data.extend([ f'Content-Disposition: form-data; name="files"; filename="{image_name}"\r\n'
f'--{boundary}\r\n' f'Content-Type: {is_accepted_format(image_bytes)}\r\n\r\n'
f'Content-Disposition: form-data; name="files"; filename="{image_name}"\r\n' ])
f'Content-Type: {is_accepted_format(image_bytes)}\r\n\r\n' form_data.append(image_bytes.decode('latin1'))
]) form_data.append('\r\n')
form_data.append(image_bytes.decode('latin1'))
form_data.append('\r\n')
form_data.append(f'--{boundary}--\r\n') form_data.append(f'--{boundary}--\r\n')
data = "".join(form_data).encode('latin1') data = "".join(form_data).encode('latin1')
@ -182,11 +170,7 @@ class AllenAI(AsyncGeneratorProvider, ProviderModelMixin):
await raise_for_status(response) await raise_for_status(response)
current_parent = None current_parent = None
async for chunk in response.content: async for line in response.content:
if not chunk:
continue
decoded = chunk.decode(errors="ignore")
for line in decoded.splitlines():
line = line.strip() line = line.strip()
if not line: if not line:
continue continue

View file

@ -581,14 +581,15 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
} }
current_messages.append(current_msg) current_messages.append(current_msg)
if media is not None: media = list(merge_media(media, messages))
if media:
current_messages[-1]['data'] = { current_messages[-1]['data'] = {
"imagesData": [ "imagesData": [
{ {
"filePath": f"/{image_name}", "filePath": f"/{image_name}",
"contents": to_data_uri(image) "contents": to_data_uri(image)
} }
for image, image_name in merge_media(media, messages) for image, image_name in media
], ],
"fileText": "", "fileText": "",
"title": "" "title": ""

View file

@ -518,11 +518,12 @@ async def async_read_and_download_urls(bucket_dir: Path, delete_files: bool = Fa
if urls: if urls:
count = 0 count = 0
with open(os.path.join(bucket_dir, FILE_LIST), 'a') as f: with open(os.path.join(bucket_dir, FILE_LIST), 'a') as f:
async for filename in download_urls(bucket_dir, urls): for url in urls:
f.write(f"{filename}\n") async for filename in download_urls(bucket_dir, **url):
if event_stream: f.write(f"{filename}\n")
count += 1 if event_stream:
yield f'data: {json.dumps({"action": "download", "count": count})}\n\n' count += 1
yield f'data: {json.dumps({"action": "download", "count": count})}\n\n'
def stream_chunks(bucket_dir: Path, delete_files: bool = False, refine_chunks_with_spacy: bool = False, event_stream: bool = False) -> Iterator[str]: def stream_chunks(bucket_dir: Path, delete_files: bool = False, refine_chunks_with_spacy: bool = False, event_stream: bool = False) -> Iterator[str]:
size = 0 size = 0

View file

@ -80,7 +80,7 @@ class ToolHandler:
has_bucket = False has_bucket = False
for message in messages: for message in messages:
if "content" in message and isinstance(message["content"], str): if "content" in message and isinstance(message["content"], str):
new_message_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"]) new_message_content = re.sub(r'{"bucket_id":\s*"([^"]*)"}', on_bucket, message["content"])
if new_message_content != message["content"]: if new_message_content != message["content"]:
has_bucket = True has_bucket = True
message["content"] = new_message_content message["content"] = new_message_content
@ -97,29 +97,28 @@ class ToolHandler:
"""Process all tool calls and return updated messages and kwargs""" """Process all tool calls and return updated messages and kwargs"""
if not tool_calls: if not tool_calls:
return messages, {} return messages, {}
extra_kwargs = {} extra_kwargs = {}
messages = messages.copy() messages = messages.copy()
sources = None sources = None
for tool in tool_calls: for tool in tool_calls:
if tool.get("type") != "function": if tool.get("type") != "function":
continue continue
function_name = tool.get("function", {}).get("name") function_name = tool.get("function", {}).get("name")
if function_name == TOOL_NAMES["SEARCH"]: if function_name == TOOL_NAMES["SEARCH"]:
messages, sources = await ToolHandler.process_search_tool(messages, tool) messages, sources = await ToolHandler.process_search_tool(messages, tool)
elif function_name == TOOL_NAMES["CONTINUE"]: elif function_name == TOOL_NAMES["CONTINUE"]:
messages, kwargs = ToolHandler.process_continue_tool(messages, tool, provider) messages, kwargs = ToolHandler.process_continue_tool(messages, tool, provider)
extra_kwargs.update(kwargs) extra_kwargs.update(kwargs)
elif function_name == TOOL_NAMES["BUCKET"]: elif function_name == TOOL_NAMES["BUCKET"]:
messages = ToolHandler.process_bucket_tool(messages, tool) messages = ToolHandler.process_bucket_tool(messages, tool)
return messages, sources, extra_kwargs
return messages, sources, extra_kwargs
class AuthManager: class AuthManager:
"""Handles API key management""" """Handles API key management"""
@ -128,13 +127,13 @@ class AuthManager:
def get_api_key_file(cls) -> Path: def get_api_key_file(cls) -> Path:
"""Get the path to the API key file for a provider""" """Get the path to the API key file for a provider"""
return Path(get_cookies_dir()) / f"api_key_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json" return Path(get_cookies_dir()) / f"api_key_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
@staticmethod @staticmethod
def load_api_key(provider: Any) -> Optional[str]: def load_api_key(provider: Any) -> Optional[str]:
"""Load API key from config file if needed""" """Load API key from config file if needed"""
if not getattr(provider, "needs_auth", False): if not getattr(provider, "needs_auth", False):
return None return None
auth_file = AuthManager.get_api_key_file(provider) auth_file = AuthManager.get_api_key_file(provider)
try: try:
if auth_file.exists(): if auth_file.exists():

View file

@ -1,6 +1,5 @@
import sys
import os import os
from typing import Any, AsyncGenerator, Generator, AsyncIterator, Iterator, NewType, Tuple, Union, List, Dict, Type, IO, Optional from typing import Any, AsyncGenerator, Generator, AsyncIterator, Iterator, NewType, Tuple, Union, List, Dict, Type, IO, Optional, TypedDict
try: try:
from PIL.Image import Image from PIL.Image import Image
@ -8,11 +7,6 @@ except ImportError:
class Image: class Image:
pass pass
if sys.version_info >= (3, 8):
from typing import TypedDict
else:
from typing_extensions import TypedDict
from .providers.response import ResponseType from .providers.response import ResponseType
SHA256 = NewType('sha_256_hash', str) SHA256 = NewType('sha_256_hash', str)