mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
259 lines
9.5 KiB
Python
259 lines
9.5 KiB
Python
from __future__ import annotations
|
|
|
|
import random
|
|
import asyncio
|
|
import re
|
|
import json
|
|
from pathlib import Path
|
|
from aiohttp import ClientSession
|
|
from typing import AsyncIterator, Optional
|
|
|
|
from ..typing import AsyncResult, Messages
|
|
from ..image import ImageResponse
|
|
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
|
from ..cookies import get_cookies_dir
|
|
|
|
from .. import debug
|
|
|
|
|
|
class BlackboxCreateAgent(AsyncGeneratorProvider, ProviderModelMixin):
    """Provider for Blackbox AI's prompt-improvement (text) and image-generation endpoints.

    Text requests carry a "validated" token that is scraped from the site's
    JavaScript chunks (``/_next/static/chunks/NNNN-<hex>.js``) and cached in a
    JSON file inside the cookies directory.
    """

    url = "https://www.blackbox.ai"
    # Model name -> API endpoint used for that model.
    api_endpoints = {
        "llama-3.1-70b": "https://www.blackbox.ai/api/improve-prompt",
        "flux": "https://www.blackbox.ai/api/image-generator",
    }

    working = True
    supports_system_message = True
    supports_message_history = True

    default_model = 'llama-3.1-70b'
    chat_models = [default_model]
    image_models = ['flux']
    models = [*chat_models, *image_models]

    @classmethod
    def _get_cache_file(cls) -> Path:
        """Return the path of the JSON file that caches the validated token.

        The cookies directory is created if it does not exist.
        """
        cache_dir = Path(get_cookies_dir())
        # parents=True: the cookies directory may live under a parent that
        # does not exist yet; without it mkdir raises FileNotFoundError.
        cache_dir.mkdir(parents=True, exist_ok=True)
        return cache_dir / 'blackbox_create_agent.json'

    @classmethod
    def _load_cached_value(cls) -> str | None:
        """Return the cached validated token, or None if missing/unreadable.

        May return an empty string: ``_generate_text`` writes ``""`` to clear
        a token that the API rejected.
        """
        cache_file = cls._get_cache_file()
        if not cache_file.exists():
            return None
        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return data.get('validated_value')
        except Exception as e:
            # Best effort: a corrupt cache just forces a fresh scrape.
            debug.log(f"Error reading cache file: {e}")
            return None

    @classmethod
    def _save_cached_value(cls, value: str) -> None:
        """Persist *value* as the cached validated token (errors are only logged)."""
        cache_file = cls._get_cache_file()
        try:
            with open(cache_file, 'w', encoding='utf-8') as f:
                json.dump({'validated_value': value}, f)
        except Exception as e:
            debug.log(f"Error writing to cache file: {e}")

    @classmethod
    async def fetch_validated(cls, proxy: str = None) -> Optional[str]:
        """Return the validated token from the cache, or scrape it from the website.

        The home page is downloaded, every matching JS chunk is fetched, and the
        first 36-character hex/dash string assigned to ``j`` is accepted if its
        surrounding text passes ``is_valid_context``. A scraped value is cached.

        :param proxy: Optional proxy URL used for the scraping requests.
        :return: The validated value or None if retrieval fails.
        """
        cached_value = cls._load_cached_value()
        if cached_value:
            return cached_value

        js_file_pattern = r'static/chunks/\d{4}-[a-fA-F0-9]+\.js'
        v_pattern = r'j\s*=\s*[\'"]([0-9a-fA-F-]{36})[\'"]'

        def is_valid_context(text: str) -> bool:
            """Heuristic check: the surrounding text contains some ``<a-z>=``."""
            return any(char + '=' in text for char in 'abcdefghijklmnopqrstuvwxyz')

        async with ClientSession() as session:
            try:
                async with session.get(cls.url, proxy=proxy) as response:
                    if response.status != 200:
                        debug.log("Failed to download the page.")
                        return None
                    page_content = await response.text()

                # The page response is released before the chunks are downloaded.
                for js_file in re.findall(js_file_pattern, page_content):
                    js_url = f"{cls.url}/_next/{js_file}"
                    async with session.get(js_url, proxy=proxy) as js_response:
                        if js_response.status != 200:
                            continue
                        js_content = await js_response.text()

                    for match in re.finditer(v_pattern, js_content):
                        start = max(0, match.start() - 50)
                        end = min(len(js_content), match.end() + 50)
                        if is_valid_context(js_content[start:end]):
                            validated_value = match.group(1)
                            cls._save_cached_value(validated_value)
                            return validated_value
            except Exception as e:
                debug.log(f"Error while retrieving validated_value: {e}")

        # Neither the cache nor the website produced a usable token.
        return None

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        prompt: str = None,
        **kwargs
    ) -> AsyncIterator[str | ImageResponse]:
        """Yield text for chat models or an ``ImageResponse`` for image models.

        :param model: Model name; an empty value falls back to ``default_model``.
        :param messages: Conversation; for image models the last message's
            ``content`` is the prompt unless *prompt* is given.
        :param proxy: Optional proxy URL for the HTTP requests.
        :param prompt: Explicit image prompt (used only for image models).
        :raises ValueError: If *model* is not supported.
        """
        if not model:
            model = cls.default_model

        if model in cls.chat_models:
            async for text in cls._generate_text(model, messages, proxy=proxy, **kwargs):
                yield text
        elif model in cls.image_models:
            # Honour an explicit prompt instead of silently discarding it.
            if not prompt:
                prompt = messages[-1]['content']
            async for image in cls._generate_image(model, prompt, proxy=proxy, **kwargs):
                yield image
        else:
            raise ValueError(f"Model {model} not supported")

    @classmethod
    async def _generate_text(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        max_retries: int = 3,
        delay: int = 1,
        max_tokens: int = None,
        **kwargs
    ) -> AsyncIterator[str]:
        """Yield the ``prompt`` field returned by the text endpoint for *messages*.

        Each request is retried up to *max_retries* times with exponential
        backoff plus jitter. If every retry fails with the first token, the
        cached token is cleared and one more round is made with a freshly
        fetched token.

        :raises RuntimeError: If no token can be obtained, or all attempts with
            the second token fail (chained to the last underlying error).
        """
        headers = cls._get_headers()
        api_endpoint = cls.api_endpoints[model]
        no_result = object()  # sentinel: the API may legitimately return a falsy prompt

        for outer_attempt in range(2):  # second pass retries with a new key
            validated_value = await cls.fetch_validated(proxy=proxy)
            if not validated_value:
                raise RuntimeError("Failed to get validated value")

            data = {
                "messages": messages,
                "max_tokens": max_tokens,
                "validated": validated_value
            }

            result = no_result
            async with ClientSession(headers=headers) as session:
                for attempt in range(max_retries):
                    try:
                        async with session.post(api_endpoint, json=data, proxy=proxy) as response:
                            response.raise_for_status()
                            response_data = await response.json()
                        if response_data.get('status') != 200 or 'prompt' not in response_data:
                            raise KeyError("Invalid response format or missing 'prompt' key")
                        result = response_data['prompt']
                        break
                    except Exception as e:
                        if attempt < max_retries - 1:
                            wait_time = delay * (2 ** attempt) + random.uniform(0, 1)
                            debug.log(f"Attempt {attempt + 1} failed. Retrying in {wait_time:.2f} seconds...")
                            await asyncio.sleep(wait_time)
                        elif outer_attempt == 0:
                            # Assume the key is stale: clear it so the next pass re-scrapes.
                            cls._save_cached_value("")
                            debug.log("Invalid key, trying to get a new one...")
                        else:
                            raise RuntimeError(f"Error after all attempts: {e}") from e

            # Yield outside the try/except and after the session is closed, so
            # consumer-side errors are not mistaken for request failures.
            if result is not no_result:
                yield result
                return

    @classmethod
    async def _generate_image(
        cls,
        model: str,
        prompt: str,
        proxy: str = None,
        **kwargs
    ) -> AsyncIterator[ImageResponse]:
        """Yield one ``ImageResponse`` for *prompt* from the image endpoint.

        The endpoint's JSON is expected to contain a ``markdown`` field with
        ``![](<url>)``; the URL is extracted from it.

        :raises KeyError: If the response has no ``markdown`` key.
        :raises ValueError: If no image URL can be extracted from the markdown.
        """
        api_endpoint = cls.api_endpoints[model]

        async with ClientSession(headers=cls._get_headers()) as session:
            async with session.post(api_endpoint, json={"query": prompt}, proxy=proxy) as response:
                response.raise_for_status()
                response_data = await response.json()

        if 'markdown' not in response_data:
            raise KeyError("'markdown' key not found in response")

        # Extract URL from markdown format: ![](url)
        image_url = re.search(r'!\[\]\((.*?)\)', response_data['markdown'])
        if not image_url:
            raise ValueError("Could not extract image URL from markdown")
        yield ImageResponse(images=[image_url.group(1)], alt=prompt)

    @staticmethod
    def _get_headers() -> dict:
        """Return the HTTP headers sent with every API request."""
        return {
            'accept': '*/*',
            'accept-language': 'en-US,en;q=0.9',
            # NOTE(review): hard-coded bearer token committed to source; presumably a
            # public client token taken from the website — confirm, and consider
            # making it configurable.
            'authorization': 'Bearer 56c8eeff9971269d7a7e625ff88e8a83a34a556003a5c87c289ebe9a3d8a3d2c',
            'content-type': 'application/json',
            'origin': 'https://www.blackbox.ai',
            'referer': 'https://www.blackbox.ai',
            'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36'
        }

    @classmethod
    async def create_async(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        **kwargs
    ) -> AsyncResult:
        """
        Creates an async response for the provider.

        Args:
            model: The model to use; an empty value falls back to ``default_model``
            messages: The messages to process
            proxy: Optional proxy to use
            **kwargs: Additional arguments; ``prompt`` overrides the image prompt

        Returns:
            AsyncResult: The first item produced by the provider (text or
            ``ImageResponse``), or None if nothing was produced.

        Raises:
            ValueError: If the model is not supported.
        """
        if not model:
            model = cls.default_model

        # Pop 'prompt' so it is not passed twice to _generate_image (which
        # takes it positionally), a TypeError in the original code.
        prompt = kwargs.pop('prompt', None)

        if model in cls.chat_models:
            generator = cls._generate_text(model, messages, proxy=proxy, **kwargs)
        elif model in cls.image_models:
            if not prompt:
                prompt = messages[-1]['content']
            generator = cls._generate_image(model, prompt, proxy=proxy, **kwargs)
        else:
            raise ValueError(f"Model {model} not supported")

        try:
            async for item in generator:
                return item
        finally:
            # Finalize the generator deterministically instead of leaving it to GC.
            await generator.aclose()
        return None
|