mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
feat: improve file handling and streamline provider implementations
- Added file upload usage example with bucket_id in docs/file.md - Fixed ARTA provider by refactoring error handling with a new raise_error function - Simplified aspect_ratio handling in ARTA with proper default - Improved AllenAI provider by cleaning up image handling logic - Fixed Blackbox provider's media handling to properly process images - Updated file tools to handle URL downloads correctly - Fixed bucket_id pattern matching in ToolHandler.process_bucket_tool - Cleaned up imports in typing.py by removing unnecessary sys import - Fixed inconsistent function parameters in g4f/tools/files.py - Fixed return value of upload_and_process function to return bucket_id
This commit is contained in:
parent
fa36dccf16
commit
c083f85206
7 changed files with 82 additions and 104 deletions
31
docs/file.md
31
docs/file.md
|
|
@ -75,16 +75,45 @@ def upload_and_process(files_or_urls, bucket_id=None):
|
||||||
else:
|
else:
|
||||||
print(f"Unhandled SSE event: {line}")
|
print(f"Unhandled SSE event: {line}")
|
||||||
response.close()
|
response.close()
|
||||||
|
return bucket_id5
|
||||||
|
|
||||||
# Example with URLs
|
# Example with URLs
|
||||||
urls = [{"url": "https://github.com/xtekky/gpt4free/issues"}]
|
urls = [{"url": "https://github.com/xtekky/gpt4free/issues"}]
|
||||||
bucket_id = upload_and_process(urls)
|
bucket_id = upload_and_process(urls)
|
||||||
|
|
||||||
#Example with files
|
#Example with files
|
||||||
files = {'files': open('document.pdf', 'rb'), 'files': open('data.json', 'rb')}
|
files = {'files': ('document.pdf', open('document.pdf', 'rb'))}
|
||||||
bucket_id = upload_and_process(files)
|
bucket_id = upload_and_process(files)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Usage of Uploaded Files:**
|
||||||
|
```python
|
||||||
|
from g4f.client import Client
|
||||||
|
|
||||||
|
# Enable debug mode
|
||||||
|
import g4f.debug
|
||||||
|
g4f.debug.logging = True
|
||||||
|
|
||||||
|
client = Client()
|
||||||
|
|
||||||
|
# Upload example file
|
||||||
|
files = {'files': ('demo.docx', open('demo.docx', 'rb'))}
|
||||||
|
bucket_id = upload_and_process(files)
|
||||||
|
|
||||||
|
# Send request with file:
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
[{"role": "user", "content": [
|
||||||
|
{"type": "text", "text": "Discribe this file."},
|
||||||
|
{"bucket_id": bucket_id}
|
||||||
|
]}],
|
||||||
|
)
|
||||||
|
print(response.choices[0].message.content)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example Output:**
|
||||||
|
```
|
||||||
|
This document is a demonstration of the DOCX Input plugin capabilities in the software ...
|
||||||
|
```
|
||||||
|
|
||||||
**Example Usage (JavaScript):**
|
**Example Usage (JavaScript):**
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import time
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession, ClientResponse
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from ..typing import AsyncResult, Messages
|
from ..typing import AsyncResult, Messages
|
||||||
|
|
@ -92,17 +92,8 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
# Step 1: Generate Authentication Token
|
# Step 1: Generate Authentication Token
|
||||||
auth_payload = {"clientType": "CLIENT_TYPE_ANDROID"}
|
auth_payload = {"clientType": "CLIENT_TYPE_ANDROID"}
|
||||||
async with session.post(cls.auth_url, json=auth_payload, proxy=proxy) as auth_response:
|
async with session.post(cls.auth_url, json=auth_payload, proxy=proxy) as auth_response:
|
||||||
if auth_response.status >= 400:
|
await raise_error(f"Failed to obtain authentication token", auth_response)
|
||||||
error_text = await auth_response.text()
|
auth_data = await auth_response.json()
|
||||||
raise ResponseError(f"Failed to obtain authentication token. Status: {auth_response.status}, Response: {error_text}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
auth_data = await auth_response.json()
|
|
||||||
except Exception as e:
|
|
||||||
error_text = await auth_response.text()
|
|
||||||
content_type = auth_response.headers.get('Content-Type', 'unknown')
|
|
||||||
raise ResponseError(f"Failed to parse auth response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
|
|
||||||
|
|
||||||
auth_token = auth_data.get("idToken")
|
auth_token = auth_data.get("idToken")
|
||||||
#refresh_token = auth_data.get("refreshToken")
|
#refresh_token = auth_data.get("refreshToken")
|
||||||
if not auth_token:
|
if not auth_token:
|
||||||
|
|
@ -118,17 +109,8 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
"refresh_token": refresh_token,
|
"refresh_token": refresh_token,
|
||||||
}
|
}
|
||||||
async with session.post(cls.token_refresh_url, data=payload, proxy=proxy) as response:
|
async with session.post(cls.token_refresh_url, data=payload, proxy=proxy) as response:
|
||||||
if response.status >= 400:
|
await raise_error(f"Failed to refresh token", response)
|
||||||
error_text = await response.text()
|
response_data = await response.json()
|
||||||
raise ResponseError(f"Failed to refresh token. Status: {response.status}, Response: {error_text}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
response_data = await response.json()
|
|
||||||
except Exception as e:
|
|
||||||
error_text = await response.text()
|
|
||||||
content_type = response.headers.get('Content-Type', 'unknown')
|
|
||||||
raise ResponseError(f"Failed to parse token refresh response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
|
|
||||||
|
|
||||||
return response_data.get("id_token"), response_data.get("refresh_token")
|
return response_data.get("id_token"), response_data.get("refresh_token")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -156,7 +138,7 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
n: int = 1,
|
n: int = 1,
|
||||||
guidance_scale: int = 7,
|
guidance_scale: int = 7,
|
||||||
num_inference_steps: int = 30,
|
num_inference_steps: int = 30,
|
||||||
aspect_ratio: str = "1:1",
|
aspect_ratio: str = None,
|
||||||
seed: int = None,
|
seed: int = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> AsyncResult:
|
) -> AsyncResult:
|
||||||
|
|
@ -179,7 +161,7 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
"images_num": str(n),
|
"images_num": str(n),
|
||||||
"cfg_scale": str(guidance_scale),
|
"cfg_scale": str(guidance_scale),
|
||||||
"steps": str(num_inference_steps),
|
"steps": str(num_inference_steps),
|
||||||
"aspect_ratio": aspect_ratio,
|
"aspect_ratio": "1:1" if aspect_ratio is None else aspect_ratio,
|
||||||
"seed": str(seed),
|
"seed": str(seed),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -188,45 +170,26 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
}
|
}
|
||||||
|
|
||||||
async with session.post(cls.image_generation_url, data=image_payload, headers=headers, proxy=proxy) as image_response:
|
async with session.post(cls.image_generation_url, data=image_payload, headers=headers, proxy=proxy) as image_response:
|
||||||
if image_response.status >= 400:
|
await raise_error(f"Failed to initiate image generation", image_response)
|
||||||
error_text = await image_response.text()
|
image_data = await image_response.json()
|
||||||
raise ResponseError(f"Failed to initiate image generation. Status: {image_response.status}, Response: {error_text}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
image_data = await image_response.json()
|
|
||||||
except Exception as e:
|
|
||||||
error_text = await image_response.text()
|
|
||||||
content_type = image_response.headers.get('Content-Type', 'unknown')
|
|
||||||
raise ResponseError(f"Failed to parse response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
|
|
||||||
|
|
||||||
record_id = image_data.get("record_id")
|
record_id = image_data.get("record_id")
|
||||||
if not record_id:
|
if not record_id:
|
||||||
raise ResponseError(f"Failed to initiate image generation: {image_data}")
|
raise ResponseError(f"Failed to initiate image generation: {image_data}")
|
||||||
|
|
||||||
# Step 3: Check Generation Status
|
# Step 3: Check Generation Status
|
||||||
status_url = cls.status_check_url.format(record_id=record_id)
|
status_url = cls.status_check_url.format(record_id=record_id)
|
||||||
counter = 4
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
last_status = None
|
last_status = None
|
||||||
while True:
|
while True:
|
||||||
async with session.get(status_url, headers=headers, proxy=proxy) as status_response:
|
async with session.get(status_url, headers=headers, proxy=proxy) as status_response:
|
||||||
if status_response.status >= 400:
|
await raise_error(f"Failed to check image generation status", status_response)
|
||||||
error_text = await status_response.text()
|
status_data = await status_response.json()
|
||||||
raise ResponseError(f"Failed to check image generation status. Status: {status_response.status}, Response: {error_text}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
status_data = await status_response.json()
|
|
||||||
except Exception as e:
|
|
||||||
error_text = await status_response.text()
|
|
||||||
content_type = status_response.headers.get('Content-Type', 'unknown')
|
|
||||||
raise ResponseError(f"Failed to parse status response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
|
|
||||||
|
|
||||||
status = status_data.get("status")
|
status = status_data.get("status")
|
||||||
|
|
||||||
if status == "DONE":
|
if status == "DONE":
|
||||||
image_urls = [image["url"] for image in status_data.get("response", [])]
|
image_urls = [image["url"] for image in status_data.get("response", [])]
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
yield Reasoning(label="Generated", status=f"{n} image(s) in {duration:.2f}s")
|
yield Reasoning(label="Generated", status=f"{n} image in {duration:.2f}s" if n == 1 else f"{n} images in {duration:.2f}s")
|
||||||
yield ImageResponse(urls=image_urls, alt=prompt)
|
yield ImageResponse(urls=image_urls, alt=prompt)
|
||||||
return
|
return
|
||||||
elif status in ("IN_QUEUE", "IN_PROGRESS"):
|
elif status in ("IN_QUEUE", "IN_PROGRESS"):
|
||||||
|
|
@ -238,4 +201,11 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
yield Reasoning(label="Generating")
|
yield Reasoning(label="Generating")
|
||||||
await asyncio.sleep(2) # Poll every 2 seconds
|
await asyncio.sleep(2) # Poll every 2 seconds
|
||||||
else:
|
else:
|
||||||
raise ResponseError(f"Image generation failed with status: {status}")
|
raise ResponseError(f"Image generation failed with status: {status}")
|
||||||
|
|
||||||
|
async def raise_error(response: ClientResponse, message: str):
|
||||||
|
if response.ok:
|
||||||
|
return
|
||||||
|
error_text = await response.text()
|
||||||
|
content_type = response.headers.get('Content-Type', 'unknown')
|
||||||
|
raise ResponseError(f"{message}. Content-Type: {content_type}, Response: {error_text}")
|
||||||
|
|
@ -83,17 +83,7 @@ class AllenAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
) -> AsyncResult:
|
) -> AsyncResult:
|
||||||
actual_model = cls.get_model(model)
|
actual_model = cls.get_model(model)
|
||||||
|
|
||||||
# Use format_image_prompt for vision models when media is provided
|
prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
|
||||||
if media is not None and len(media) > 0:
|
|
||||||
# For vision models, use format_image_prompt
|
|
||||||
if actual_model in cls.vision_models:
|
|
||||||
prompt = format_image_prompt(messages)
|
|
||||||
else:
|
|
||||||
# For non-vision models with images, still use the last user message
|
|
||||||
prompt = get_last_user_message(messages)
|
|
||||||
else:
|
|
||||||
# For text-only messages, use the standard format
|
|
||||||
prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
|
|
||||||
|
|
||||||
# Determine the correct host for the model
|
# Determine the correct host for the model
|
||||||
if host is None:
|
if host is None:
|
||||||
|
|
@ -157,18 +147,16 @@ class AllenAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
if media is not None and len(media) > 0:
|
if media is not None and len(media) > 0:
|
||||||
conversation = Conversation(actual_model)
|
conversation = Conversation(actual_model)
|
||||||
|
|
||||||
# Add image if provided
|
# For each image in the media list (using merge_media to handle different formats)
|
||||||
if media is not None and len(media) > 0:
|
for image, image_name in merge_media(media, messages):
|
||||||
# For each image in the media list (using merge_media to handle different formats)
|
image_bytes = to_bytes(image)
|
||||||
for image, image_name in merge_media(media, messages):
|
form_data.extend([
|
||||||
image_bytes = to_bytes(image)
|
f'--{boundary}\r\n'
|
||||||
form_data.extend([
|
f'Content-Disposition: form-data; name="files"; filename="{image_name}"\r\n'
|
||||||
f'--{boundary}\r\n'
|
f'Content-Type: {is_accepted_format(image_bytes)}\r\n\r\n'
|
||||||
f'Content-Disposition: form-data; name="files"; filename="{image_name}"\r\n'
|
])
|
||||||
f'Content-Type: {is_accepted_format(image_bytes)}\r\n\r\n'
|
form_data.append(image_bytes.decode('latin1'))
|
||||||
])
|
form_data.append('\r\n')
|
||||||
form_data.append(image_bytes.decode('latin1'))
|
|
||||||
form_data.append('\r\n')
|
|
||||||
|
|
||||||
form_data.append(f'--{boundary}--\r\n')
|
form_data.append(f'--{boundary}--\r\n')
|
||||||
data = "".join(form_data).encode('latin1')
|
data = "".join(form_data).encode('latin1')
|
||||||
|
|
@ -182,11 +170,7 @@ class AllenAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
await raise_for_status(response)
|
await raise_for_status(response)
|
||||||
current_parent = None
|
current_parent = None
|
||||||
|
|
||||||
async for chunk in response.content:
|
async for line in response.content:
|
||||||
if not chunk:
|
|
||||||
continue
|
|
||||||
decoded = chunk.decode(errors="ignore")
|
|
||||||
for line in decoded.splitlines():
|
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if not line:
|
if not line:
|
||||||
continue
|
continue
|
||||||
|
|
|
||||||
|
|
@ -581,14 +581,15 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
}
|
}
|
||||||
current_messages.append(current_msg)
|
current_messages.append(current_msg)
|
||||||
|
|
||||||
if media is not None:
|
media = list(merge_media(media, messages))
|
||||||
|
if media:
|
||||||
current_messages[-1]['data'] = {
|
current_messages[-1]['data'] = {
|
||||||
"imagesData": [
|
"imagesData": [
|
||||||
{
|
{
|
||||||
"filePath": f"/{image_name}",
|
"filePath": f"/{image_name}",
|
||||||
"contents": to_data_uri(image)
|
"contents": to_data_uri(image)
|
||||||
}
|
}
|
||||||
for image, image_name in merge_media(media, messages)
|
for image, image_name in media
|
||||||
],
|
],
|
||||||
"fileText": "",
|
"fileText": "",
|
||||||
"title": ""
|
"title": ""
|
||||||
|
|
|
||||||
|
|
@ -518,11 +518,12 @@ async def async_read_and_download_urls(bucket_dir: Path, delete_files: bool = Fa
|
||||||
if urls:
|
if urls:
|
||||||
count = 0
|
count = 0
|
||||||
with open(os.path.join(bucket_dir, FILE_LIST), 'a') as f:
|
with open(os.path.join(bucket_dir, FILE_LIST), 'a') as f:
|
||||||
async for filename in download_urls(bucket_dir, urls):
|
for url in urls:
|
||||||
f.write(f"{filename}\n")
|
async for filename in download_urls(bucket_dir, **url):
|
||||||
if event_stream:
|
f.write(f"{filename}\n")
|
||||||
count += 1
|
if event_stream:
|
||||||
yield f'data: {json.dumps({"action": "download", "count": count})}\n\n'
|
count += 1
|
||||||
|
yield f'data: {json.dumps({"action": "download", "count": count})}\n\n'
|
||||||
|
|
||||||
def stream_chunks(bucket_dir: Path, delete_files: bool = False, refine_chunks_with_spacy: bool = False, event_stream: bool = False) -> Iterator[str]:
|
def stream_chunks(bucket_dir: Path, delete_files: bool = False, refine_chunks_with_spacy: bool = False, event_stream: bool = False) -> Iterator[str]:
|
||||||
size = 0
|
size = 0
|
||||||
|
|
|
||||||
|
|
@ -80,7 +80,7 @@ class ToolHandler:
|
||||||
has_bucket = False
|
has_bucket = False
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if "content" in message and isinstance(message["content"], str):
|
if "content" in message and isinstance(message["content"], str):
|
||||||
new_message_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"])
|
new_message_content = re.sub(r'{"bucket_id":\s*"([^"]*)"}', on_bucket, message["content"])
|
||||||
if new_message_content != message["content"]:
|
if new_message_content != message["content"]:
|
||||||
has_bucket = True
|
has_bucket = True
|
||||||
message["content"] = new_message_content
|
message["content"] = new_message_content
|
||||||
|
|
@ -97,29 +97,28 @@ class ToolHandler:
|
||||||
"""Process all tool calls and return updated messages and kwargs"""
|
"""Process all tool calls and return updated messages and kwargs"""
|
||||||
if not tool_calls:
|
if not tool_calls:
|
||||||
return messages, {}
|
return messages, {}
|
||||||
|
|
||||||
extra_kwargs = {}
|
extra_kwargs = {}
|
||||||
messages = messages.copy()
|
messages = messages.copy()
|
||||||
sources = None
|
sources = None
|
||||||
|
|
||||||
for tool in tool_calls:
|
for tool in tool_calls:
|
||||||
if tool.get("type") != "function":
|
if tool.get("type") != "function":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
function_name = tool.get("function", {}).get("name")
|
function_name = tool.get("function", {}).get("name")
|
||||||
|
|
||||||
if function_name == TOOL_NAMES["SEARCH"]:
|
if function_name == TOOL_NAMES["SEARCH"]:
|
||||||
messages, sources = await ToolHandler.process_search_tool(messages, tool)
|
messages, sources = await ToolHandler.process_search_tool(messages, tool)
|
||||||
|
|
||||||
elif function_name == TOOL_NAMES["CONTINUE"]:
|
elif function_name == TOOL_NAMES["CONTINUE"]:
|
||||||
messages, kwargs = ToolHandler.process_continue_tool(messages, tool, provider)
|
messages, kwargs = ToolHandler.process_continue_tool(messages, tool, provider)
|
||||||
extra_kwargs.update(kwargs)
|
extra_kwargs.update(kwargs)
|
||||||
|
|
||||||
elif function_name == TOOL_NAMES["BUCKET"]:
|
elif function_name == TOOL_NAMES["BUCKET"]:
|
||||||
messages = ToolHandler.process_bucket_tool(messages, tool)
|
messages = ToolHandler.process_bucket_tool(messages, tool)
|
||||||
|
|
||||||
return messages, sources, extra_kwargs
|
|
||||||
|
|
||||||
|
return messages, sources, extra_kwargs
|
||||||
|
|
||||||
class AuthManager:
|
class AuthManager:
|
||||||
"""Handles API key management"""
|
"""Handles API key management"""
|
||||||
|
|
@ -128,13 +127,13 @@ class AuthManager:
|
||||||
def get_api_key_file(cls) -> Path:
|
def get_api_key_file(cls) -> Path:
|
||||||
"""Get the path to the API key file for a provider"""
|
"""Get the path to the API key file for a provider"""
|
||||||
return Path(get_cookies_dir()) / f"api_key_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
|
return Path(get_cookies_dir()) / f"api_key_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_api_key(provider: Any) -> Optional[str]:
|
def load_api_key(provider: Any) -> Optional[str]:
|
||||||
"""Load API key from config file if needed"""
|
"""Load API key from config file if needed"""
|
||||||
if not getattr(provider, "needs_auth", False):
|
if not getattr(provider, "needs_auth", False):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
auth_file = AuthManager.get_api_key_file(provider)
|
auth_file = AuthManager.get_api_key_file(provider)
|
||||||
try:
|
try:
|
||||||
if auth_file.exists():
|
if auth_file.exists():
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import sys
|
|
||||||
import os
|
import os
|
||||||
from typing import Any, AsyncGenerator, Generator, AsyncIterator, Iterator, NewType, Tuple, Union, List, Dict, Type, IO, Optional
|
from typing import Any, AsyncGenerator, Generator, AsyncIterator, Iterator, NewType, Tuple, Union, List, Dict, Type, IO, Optional, TypedDict
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
|
|
@ -8,11 +7,6 @@ except ImportError:
|
||||||
class Image:
|
class Image:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if sys.version_info >= (3, 8):
|
|
||||||
from typing import TypedDict
|
|
||||||
else:
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from .providers.response import ResponseType
|
from .providers.response import ResponseType
|
||||||
|
|
||||||
SHA256 = NewType('sha_256_hash', str)
|
SHA256 = NewType('sha_256_hash', str)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue