Mirror of https://github.com/xtekky/gpt4free.git (synced 2025-12-06 02:30:41 -08:00)

Add files support, scrape and refine your data
Remove Webdriver usages, add continue messages for other providers

This commit is contained in:
parent 90360ccfa6
commit 7893a0835e

33 changed files with 1155 additions and 559 deletions
@@ -5,7 +5,7 @@ import asyncio
from unittest.mock import MagicMock
from g4f.errors import MissingRequirementsError
try:
    from g4f.gui.server.backend import Backend_Api
    from g4f.gui.server.backend_api import Backend_Api
    has_requirements = True
except:
    has_requirements = False
@@ -14,7 +14,7 @@ from ..typing import AsyncResult, Messages, ImagesType
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..image import ImageResponse, to_data_uri
from ..cookies import get_cookies_dir
from ..web_search import get_search_message
from ..tools.web_search import get_search_message
from .helper import format_prompt

from .. import debug
@@ -5,8 +5,10 @@ import json

from ..typing import AsyncResult, Messages, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop
from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies, DEFAULT_HEADERS, has_nodriver, has_curl_cffi
from ..errors import ResponseStatusError
from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
from ..requests import DEFAULT_HEADERS, has_nodriver, has_curl_cffi
from ..providers.response import FinishReason
from ..errors import ResponseStatusError, ModelNotFoundError

class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
    label = "Cloudflare AI"

@@ -70,7 +72,10 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
                cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
            else:
                cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}}
        model = cls.get_model(model)
        try:
            model = cls.get_model(model)
        except ModelNotFoundError:
            pass
        data = {
            "messages": messages,
            "lora": None,

@@ -89,6 +94,7 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
                except ResponseStatusError:
                    cls._args = None
                    raise
                reason = None
                async for line in response.iter_lines():
                    if line.startswith(b'data: '):
                        if line == b'data: [DONE]':

@@ -97,5 +103,10 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
                            content = json.loads(line[6:].decode())
                            if content.get("response") and content.get("response") != '</s>':
                                yield content['response']
                                reason = "max_tokens"
                            elif content.get("response") == '':
                                reason = "stop"
                        except Exception:
                            continue
                            continue
            if reason is not None:
                yield FinishReason(reason)
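The Cloudflare provider now ends a stream with a FinishReason chunk: "max_tokens" while responses are still flowing, "stop" once an empty response arrives. The web UI uses this to decide whether to offer a Continue button. A minimal consumer sketch (hypothetical code, not part of the diff; it assumes g4f is installed and the provider is reachable):

# Hypothetical consumer of the FinishReason chunk yielded above (not part of the diff).
import asyncio
from g4f.Provider import Cloudflare
from g4f.providers.response import FinishReason

async def ask(messages):
    reason, text = None, ""
    async for chunk in Cloudflare.create_async_generator("", messages):
        if isinstance(chunk, FinishReason):
            reason = chunk.reason          # "stop" or "max_tokens"
        else:
            text += str(chunk)
    return text, reason

messages = [{"role": "user", "content": "Write a long story."}]
text, reason = asyncio.run(ask(messages))
if reason == "max_tokens":
    # The reply was cut off; ask the model to continue, mirroring the
    # "continue" action the GUI now offers for this finish reason.
    messages += [{"role": "assistant", "content": text},
                 {"role": "user", "content": "Please continue."}]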
@@ -15,7 +15,6 @@ except ImportError:

from ..helper import get_connector
from ...errors import MissingRequirementsError, RateLimitError
from ...webdriver import WebDriver, get_driver_cookies, get_browser

BING_URL = "https://www.bing.com"
TIMEOUT_LOGIN = 1200

@@ -31,39 +30,6 @@ BAD_IMAGES = [
    "https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg",
]

def wait_for_login(driver: WebDriver, timeout: int = TIMEOUT_LOGIN) -> None:
    """
    Waits for the user to log in within a given timeout period.

    Args:
        driver (WebDriver): Webdriver for browser automation.
        timeout (int): Maximum waiting time in seconds.

    Raises:
        RuntimeError: If the login process exceeds the timeout.
    """
    driver.get(f"{BING_URL}/")
    start_time = time.time()
    while not driver.get_cookie("_U"):
        if time.time() - start_time > timeout:
            raise RuntimeError("Timeout error")
        time.sleep(0.5)

def get_cookies_from_browser(proxy: str = None) -> dict[str, str]:
    """
    Retrieves cookies from the browser using webdriver.

    Args:
        proxy (str, optional): Proxy configuration.

    Returns:
        dict[str, str]: Retrieved cookies.
    """
    with get_browser(proxy=proxy) as driver:
        wait_for_login(driver)
        time.sleep(1)
        return get_driver_cookies(driver)

def create_session(cookies: Dict[str, str], proxy: str = None, connector: BaseConnector = None) -> ClientSession:
    """
    Creates a new client session with specified cookies and headers.
@@ -102,22 +102,15 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
            if "config" in model_data and "model_type" in model_data["config"]:
                model_type = model_data["config"]["model_type"]
                debug.log(f"Model type: {model_type}")
                if model_type in ("gpt2", "gpt_neo", "gemma", "gemma2"):
                    inputs = format_prompt(messages, do_continue=do_continue)
                elif model_type in ("mistral"):
                    inputs = format_prompt_mistral(messages, do_continue)
            elif "config" in model_data and "tokenizer_config" in model_data["config"] and "eos_token" in model_data["config"]["tokenizer_config"]:
                eos_token = model_data["config"]["tokenizer_config"]["eos_token"]
                if eos_token in ("<|endoftext|>", "<eos>", "</s>"):
                    inputs = format_prompt_custom(messages, eos_token, do_continue)
                elif eos_token == "<|im_end|>":
                    inputs = format_prompt_qwen(messages, do_continue)
                elif eos_token == "<|eot_id|>":
                    inputs = format_prompt_llama(messages, do_continue)
            inputs = get_inputs(messages, model_data, model_type, do_continue)
            debug.log(f"Inputs len: {len(inputs)}")
            if len(inputs) > 4096:
                if len(messages) > 6:
                    messages = messages[:3] + messages[-3:]
                else:
                    inputs = format_prompt(messages, do_continue=do_continue)
                else:
                    inputs = format_prompt(messages, do_continue=do_continue)
                    messages = [m for m in messages if m["role"] == "system"] + [messages[-1]]
                inputs = get_inputs(messages, model_data, model_type, do_continue)
                debug.log(f"New len: {len(inputs)}")
            if model_type == "gpt2" and max_new_tokens >= 1024:
                params["max_new_tokens"] = 512
            payload = {"inputs": inputs, "parameters": params, "stream": stream}

@@ -187,4 +180,23 @@ def format_prompt_custom(messages: Messages, end_token: str = "</s>", do_continu
    ]) + ("" if do_continue else "<|assistant|>\n")
    if do_continue:
        return prompt[:-len(end_token + "\n")]
    return prompt
    return prompt

def get_inputs(messages: Messages, model_data: dict, model_type: str, do_continue: bool = False) -> str:
    if model_type in ("gpt2", "gpt_neo", "gemma", "gemma2"):
        inputs = format_prompt(messages, do_continue=do_continue)
    elif model_type in ("mistral"):
        inputs = format_prompt_mistral(messages, do_continue)
    elif "config" in model_data and "tokenizer_config" in model_data["config"] and "eos_token" in model_data["config"]["tokenizer_config"]:
        eos_token = model_data["config"]["tokenizer_config"]["eos_token"]
        if eos_token in ("<|endoftext|>", "<eos>", "</s>"):
            inputs = format_prompt_custom(messages, eos_token, do_continue)
        elif eos_token == "<|im_end|>":
            inputs = format_prompt_qwen(messages, do_continue)
        elif eos_token == "<|eot_id|>":
            inputs = format_prompt_llama(messages, do_continue)
        else:
            inputs = format_prompt(messages, do_continue=do_continue)
    else:
        inputs = format_prompt(messages, do_continue=do_continue)
    return inputs
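The prompt-format dispatch that previously lived inline in the request path is now factored into get_inputs, so the oversize-prompt fallback above can rebuild the inputs after shortening the message list. A small illustration of the dispatch (hypothetical model_data values shaped like the Hugging Face model config JSON; the helpers are the ones defined above):

# Illustration only (not part of the diff): how get_inputs picks a prompt format.
messages = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Hi!"},
]

# A tokenizer that ends turns with <|im_end|> gets the Qwen-style chat format:
qwen_like = {"config": {"tokenizer_config": {"eos_token": "<|im_end|>"}}}
inputs = get_inputs(messages, qwen_like, model_type="qwen2")   # -> format_prompt_qwen(...)

# A plain gpt2-type model falls back to the generic prompt format:
gpt2_like = {"config": {"model_type": "gpt2"}}
inputs = get_inputs(messages, gpt2_like, model_type="gpt2")    # -> format_prompt(...)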
@@ -404,7 +404,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                data["conversation_id"] = conversation.conversation_id
                debug.log(f"OpenaiChat: Use conversation: {conversation.conversation_id}")
            if action != "continue":
                data["parent_message_id"] = conversation.parent_message_id
                data["parent_message_id"] = getattr(conversation, "parent_message_id", conversation.message_id)
                conversation.parent_message_id = None
            messages = messages if conversation_id is None else [messages[-1]]
            data["messages"] = cls.create_messages(messages, image_requests, ["search"] if web_search else None)

@@ -604,7 +604,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
            "api_key": cls._api_key,
            "proof_token": RequestConfig.proof_token,
            "cookies": RequestConfig.cookies,
            "headers": RequestConfig.headers
        })

    @classmethod

@@ -636,6 +635,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
            page = await browser.get(cls.url)
            user_agent = await page.evaluate("window.navigator.userAgent")
            await page.select("#prompt-textarea", 240)
            await page.evaluate("document.getElementById('prompt-textarea').innerText = 'Hello'")
            await page.evaluate("document.querySelector('[data-testid=\"send-button\"]').click()")
            while True:
                if cls._api_key is not None:
                    break
@@ -5,7 +5,6 @@ import time
from ...typing import CreateResult, Messages
from ..base_provider import AbstractProvider
from ..helper import format_prompt
from ...webdriver import WebDriver, WebDriverSession, element_send_text

models = {
    "meta-llama/Llama-2-7b-chat-hf": {"name": "Llama-2-7b"},

@@ -22,7 +21,7 @@ models = {

class Poe(AbstractProvider):
    url = "https://poe.com"
    working = True
    working = False
    needs_auth = True
    supports_stream = True

@@ -5,7 +5,6 @@ import time
from ...typing import CreateResult, Messages
from ..base_provider import AbstractProvider
from ..helper import format_prompt
from ...webdriver import WebDriver, WebDriverSession, element_send_text

models = {
    "theb-ai": "TheB.AI",

@@ -34,7 +33,7 @@ models = {
class Theb(AbstractProvider):
    label = "TheB.AI"
    url = "https://beta.theb.ai"
    working = True
    working = False
    supports_stream = True

    models = models.keys()
@@ -4,8 +4,6 @@ from aiohttp import ClientSession

from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider
from ...requests import get_args_from_browser
from ...webdriver import WebDriver

class Aura(AsyncGeneratorProvider):
    url = "https://openchat.team"

@@ -19,7 +17,7 @@ class Aura(AsyncGeneratorProvider):
        proxy: str = None,
        temperature: float = 0.5,
        max_tokens: int = 8192,
        webdriver: WebDriver = None,
        webdriver = None,
        **kwargs
    ) -> AsyncResult:
        args = get_args_from_browser(cls.url, webdriver, proxy)

@@ -5,7 +5,6 @@ import time, json
from ...typing import CreateResult, Messages
from ..base_provider import AbstractProvider
from ..helper import format_prompt
from ...webdriver import WebDriver, WebDriverSession, bypass_cloudflare

class MyShell(AbstractProvider):
    url = "https://app.myshell.ai/chat"

@@ -21,7 +20,7 @@ class MyShell(AbstractProvider):
        stream: bool,
        proxy: str = None,
        timeout: int = 120,
        webdriver: WebDriver = None,
        webdriver = None,
        **kwargs
    ) -> CreateResult:
        with WebDriverSession(webdriver, "", proxy=proxy) as driver:

@@ -12,7 +12,6 @@ except ImportError:
from ...typing import CreateResult, Messages
from ..base_provider import AbstractProvider
from ..helper import format_prompt
from ...webdriver import WebDriver, WebDriverSession, element_send_text

class PerplexityAi(AbstractProvider):
    url = "https://www.perplexity.ai"

@@ -28,7 +27,7 @@ class PerplexityAi(AbstractProvider):
        stream: bool,
        proxy: str = None,
        timeout: int = 120,
        webdriver: WebDriver = None,
        webdriver = None,
        virtual_display: bool = True,
        copilot: bool = False,
        **kwargs
@@ -6,7 +6,6 @@ from urllib.parse import quote
from ...typing import CreateResult, Messages
from ..base_provider import AbstractProvider
from ..helper import format_prompt
from ...webdriver import WebDriver, WebDriverSession

class Phind(AbstractProvider):
    url = "https://www.phind.com"

@@ -22,7 +21,7 @@ class Phind(AbstractProvider):
        stream: bool,
        proxy: str = None,
        timeout: int = 120,
        webdriver: WebDriver = None,
        webdriver = None,
        creative_mode: bool = None,
        **kwargs
    ) -> CreateResult:

@@ -4,7 +4,6 @@ import time, json, time

from ...typing import CreateResult, Messages
from ..base_provider import AbstractProvider
from ...webdriver import WebDriver, WebDriverSession

class TalkAi(AbstractProvider):
    url = "https://talkai.info"

@@ -19,7 +18,7 @@ class TalkAi(AbstractProvider):
        messages: Messages,
        stream: bool,
        proxy: str = None,
        webdriver: WebDriver = None,
        webdriver = None,
        **kwargs
    ) -> CreateResult:
        with WebDriverSession(webdriver, "", virtual_display=True, proxy=proxy) as driver:
@@ -41,11 +41,12 @@ from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthErr
from g4f.cookies import read_cookie_files, get_cookies_dir
from g4f.Provider import ProviderType, ProviderUtils, __providers__
from g4f.gui import get_gui_app
from g4f.tools.files import supports_filename, get_streaming
from .stubs import (
    ChatCompletionsConfig, ImageGenerationConfig,
    ProviderResponseModel, ModelResponseModel,
    ErrorResponseModel, ProviderResponseDetailModel,
    FileResponseModel, Annotated
    FileResponseModel, UploadResponseModel, Annotated
)

logger = logging.getLogger(__name__)

@@ -424,6 +425,40 @@ class Api:
            read_cookie_files()
            return response_data

        @self.app.get("/v1/files/{bucket_id}", responses={
            HTTP_200_OK: {"content": {
                "text/event-stream": {"schema": {"type": "string"}},
                "text/plain": {"schema": {"type": "string"}},
            }},
            HTTP_404_NOT_FOUND: {"model": ErrorResponseModel},
        })
        def read_files(request: Request, bucket_id: str, delete_files: bool = True, refine_chunks_with_spacy: bool = False):
            bucket_dir = os.path.join(get_cookies_dir(), bucket_id)
            event_stream = "text/event-stream" in request.headers.get("accept", "")
            if not os.path.isdir(bucket_dir):
                return ErrorResponse.from_message("Bucket dir not found", 404)
            return StreamingResponse(get_streaming(bucket_dir, delete_files, refine_chunks_with_spacy, event_stream), media_type="text/plain")

        @self.app.post("/v1/files/{bucket_id}", responses={
            HTTP_200_OK: {"model": UploadResponseModel}
        })
        def upload_files(bucket_id: str, files: List[UploadFile]):
            bucket_dir = os.path.join(get_cookies_dir(), bucket_id)
            os.makedirs(bucket_dir, exist_ok=True)
            filenames = []
            for file in files:
                try:
                    filename = os.path.basename(file.filename)
                    if file and supports_filename(filename):
                        with open(os.path.join(bucket_dir, filename), 'wb') as f:
                            shutil.copyfileobj(file.file, f)
                        filenames.append(filename)
                finally:
                    file.file.close()
            with open(os.path.join(bucket_dir, "files.txt"), 'w') as f:
                [f.write(f"{filename}\n") for filename in filenames]
            return {"bucket_id": bucket_id, "url": f"/v1/files/{bucket_id}", "files": filenames}

        @self.app.get("/v1/synthesize/{provider}", responses={
            HTTP_200_OK: {"content": {"audio/*": {}}},
            HTTP_404_NOT_FOUND: {"model": ErrorResponseModel},
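The two routes above give the API a full round trip for file buckets: upload raw documents, then stream back the scraped and optionally spaCy-refined text. A minimal client sketch (hypothetical code, not part of the commit; it assumes the API server is reachable on its default local address and that the requests package is installed):

# Hypothetical client for the new /v1/files endpoints (not part of the diff).
# Assumes the g4f API server is running locally; adjust base_url as needed.
import uuid
import requests

base_url = "http://localhost:1337"      # assumed default address of the API server
bucket_id = str(uuid.uuid4())           # buckets are addressed by a caller-chosen id

# 1) Upload one or more files into the bucket (POST /v1/files/{bucket_id}).
with open("notes.md", "rb") as f:
    upload = requests.post(f"{base_url}/v1/files/{bucket_id}", files=[("files", ("notes.md", f))])
upload.raise_for_status()
print(upload.json())                    # {"bucket_id": ..., "url": ..., "files": [...]}

# 2) Stream the extracted text back (GET /v1/files/{bucket_id}); pass
#    refine_chunks_with_spacy=true to get the spaCy-refined version instead.
with requests.get(f"{base_url}/v1/files/{bucket_id}", stream=True) as r:
    for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="")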
@@ -66,6 +66,10 @@ class ModelResponseModel(BaseModel):
    created: int
    owned_by: Optional[str]

class UploadResponseModel(BaseModel):
    bucket_id: str
    url: str

class ErrorResponseModel(BaseModel):
    error: ErrorResponseMessageModel
    model: Optional[str] = None
@ -7,7 +7,6 @@ import string
|
|||
import asyncio
|
||||
import aiohttp
|
||||
import base64
|
||||
import json
|
||||
from typing import Union, AsyncIterator, Iterator, Coroutine, Optional
|
||||
|
||||
from ..image import ImageResponse, copy_images, images_dir
|
||||
|
|
@ -17,13 +16,13 @@ from ..providers.response import ResponseType, FinishReason, BaseConversation, S
|
|||
from ..errors import NoImageResponseError
|
||||
from ..providers.retry_provider import IterListProvider
|
||||
from ..providers.asyncio import to_sync_generator, async_generator_to_list
|
||||
from ..web_search import get_search_message, do_search
|
||||
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
|
||||
from ..tools.run_tools import async_iter_run_tools, iter_run_tools
|
||||
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
|
||||
from .image_models import ImageModels
|
||||
from .types import IterResponse, ImageProvider, Client as BaseClient
|
||||
from .service import get_model_and_provider, convert_to_provider
|
||||
from .helper import find_stop, filter_json, filter_none, safe_aclose, to_async_iterator
|
||||
from .helper import find_stop, filter_json, filter_none, safe_aclose
|
||||
from .. import debug
|
||||
|
||||
ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
|
||||
|
|
@ -38,47 +37,6 @@ except NameError:
|
|||
except StopAsyncIteration:
|
||||
raise StopIteration
|
||||
|
||||
def validate_arguments(data: dict) -> dict:
|
||||
if "arguments" in data:
|
||||
if isinstance(data["arguments"], str):
|
||||
data["arguments"] = json.loads(data["arguments"])
|
||||
if not isinstance(data["arguments"], dict):
|
||||
raise ValueError("Tool function arguments must be a dictionary or a json string")
|
||||
else:
|
||||
return filter_none(**data["arguments"])
|
||||
else:
|
||||
return {}
|
||||
|
||||
async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls: Optional[list] = None, **kwargs):
|
||||
if tool_calls is not None:
|
||||
for tool in tool_calls:
|
||||
if tool.get("type") == "function":
|
||||
if tool.get("function", {}).get("name") == "search_tool":
|
||||
tool["function"]["arguments"] = validate_arguments(tool["function"])
|
||||
messages = messages.copy()
|
||||
messages[-1]["content"] = await do_search(
|
||||
messages[-1]["content"],
|
||||
**tool["function"]["arguments"]
|
||||
)
|
||||
response = async_iter_callback(model=model, messages=messages, **kwargs)
|
||||
if not hasattr(response, "__aiter__"):
|
||||
response = to_async_iterator(response)
|
||||
async for chunk in response:
|
||||
yield chunk
|
||||
|
||||
def iter_run_tools(iter_callback, model, messages, tool_calls: Optional[list] = None, **kwargs):
|
||||
if tool_calls is not None:
|
||||
for tool in tool_calls:
|
||||
if tool.get("type") == "function":
|
||||
if tool.get("function", {}).get("name") == "search_tool":
|
||||
tool["function"]["arguments"] = validate_arguments(tool["function"])
|
||||
messages[-1]["content"] = get_search_message(
|
||||
messages[-1]["content"],
|
||||
raise_search_exceptions=True,
|
||||
**tool["function"]["arguments"]
|
||||
)
|
||||
return iter_callback(model=model, messages=messages, **kwargs)
|
||||
|
||||
# Synchronous iter_response function
|
||||
def iter_response(
|
||||
response: Union[Iterator[Union[str, ResponseType]]],
|
||||
|
|
@ -131,7 +89,8 @@ def iter_response(
|
|||
break
|
||||
|
||||
idx += 1
|
||||
|
||||
if usage is None:
|
||||
usage = Usage(prompt_tokens=0, completion_tokens=idx, total_tokens=idx).get_dict()
|
||||
finish_reason = "stop" if finish_reason is None else finish_reason
|
||||
|
||||
if stream:
|
||||
|
|
|
|||
|
|
@@ -180,7 +180,7 @@ def read_cookie_files(dirPath: str = None):
                except json.JSONDecodeError:
                    # Error: not a json file!
                    continue
                if not isinstance(cookieFile, list):
                if not isinstance(cookieFile, list) or not isinstance(cookieFile[0], dict) or "domain" not in cookieFile[0]:
                    continue
                debug.log(f"Read cookie file: {path}")
                new_cookies = {}
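With the tightened check, a JSON file in the cookies directory is only treated as a cookie export if it parses to a non-empty list of objects whose first entry carries a "domain" key; other JSON files are skipped instead of being misread. For reference, a sketch of the accepted shape (placeholder values):

# Sketch of a cookie export that read_cookie_files accepts (placeholder values;
# real files typically come from browser cookie-export extensions).
cookies_json = [
    {
        "domain": ".example.com",     # required: checked before the file is used
        "name": "session",
        "value": "abc123",
        "path": "/",                  # additional export fields are tolerated
        "expirationDate": 1750000000,
    },
]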
@@ -1,9 +1,9 @@
from ..errors import MissingRequirementsError

try:
    from .server.app import app
    from .server.website import Website
    from .server.backend import Backend_Api
    from .server.backend_api import Backend_Api
    from .server.app import create_app
    import_error = None
except ImportError as e:
    import_error = e

@@ -11,6 +11,7 @@ except ImportError as e:
def get_gui_app():
    if import_error is not None:
        raise MissingRequirementsError(f'Install "gui" requirements | pip install -U g4f[gui]\n{import_error}')
    app = create_app()

    site = Website(app)
    for route in site.routes:

@@ -36,7 +37,7 @@ def run_gui(host: str = '0.0.0.0', port: int = 8080, debug: bool = False) -> Non
        'debug': debug
    }

    get_gui_app()
    app = get_gui_app()

    print(f"Running on port {config['port']}")
    app.run(**config)
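get_gui_app now builds the Flask app through create_app() instead of importing a module-level instance, so the GUI can be constructed on demand (the FastAPI server imports it for exactly that reason). A minimal sketch of standalone use (assumes the "gui" extra is installed; host and port mirror the run_gui defaults shown above):

# Sketch of the factory-style entry point (not part of the diff).
# Requires the gui extra: pip install -U g4f[gui]
from g4f.gui import get_gui_app

app = get_gui_app()   # builds the Flask app and registers the website and backend routes
app.run(host="0.0.0.0", port=8080, debug=False)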
@ -136,6 +136,11 @@
|
|||
<input type="checkbox" id="download_images" checked/>
|
||||
<label for="download_images" class="toogle" title="Download and save generated images to /generated_images"></label>
|
||||
</div>
|
||||
<div class="field">
|
||||
<span class="label">Refine files with spaCy</span>
|
||||
<input type="checkbox" id="refine" checked/>
|
||||
<label for="refine" class="toogle" title=""></label>
|
||||
</div>
|
||||
<div class="field box">
|
||||
<label for="message-input-height" class="label" title="">Input max. height</label>
|
||||
<input type="number" id="message-input-height" value="200"/>
|
||||
|
|
@ -258,7 +263,7 @@
|
|||
<i class="fa-solid fa-camera"></i>
|
||||
</label>
|
||||
<label class="file-label" for="file">
|
||||
<input type="file" id="file" name="file" accept="text/plain, text/html, text/xml, application/json, text/javascript, .har, .sh, .py, .php, .css, .yaml, .sql, .log, .csv, .twig, .md" required/>
|
||||
<input type="file" id="file" name="file" accept=".txt, .html, .xml, .json, .js, .har, .sh, .py, .php, .css, .yaml, .sql, .log, .csv, .twig, .md, .pdf, .docx, .odt, .epub, .xlsx, .zip" required multiple/>
|
||||
<i class="fa-solid fa-paperclip"></i>
|
||||
</label>
|
||||
<label class="micro-label" for="micro">
|
||||
|
|
|
|||
|
|
@ -423,8 +423,10 @@ body:not(.white) a:visited{
|
|||
.message .count .fa-print.clicked,
|
||||
.message .count .fa-rotate.clicked,
|
||||
.message .count .fa-volume-high.active,
|
||||
.message .count .fa-file-export.clicked,
|
||||
.message .continue_button.clicked,
|
||||
.message .regenerate_button.clicked {
|
||||
.message .regenerate_button.clicked,
|
||||
.message .options_button.clicked {
|
||||
color: var(--accent);
|
||||
}
|
||||
|
||||
|
|
@ -530,7 +532,7 @@ body:not(.white) a:visited{
|
|||
right: 0;
|
||||
}
|
||||
|
||||
.stop_generating button, .toolbar .regenerate button, button.regenerate_button, button.continue_button {
|
||||
.stop_generating button, .toolbar .regenerate button, button.regenerate_button, button.continue_button, button.options_button {
|
||||
backdrop-filter: blur(20px);
|
||||
-webkit-backdrop-filter: blur(20px);
|
||||
background-color: var(--blur-bg);
|
||||
|
|
@ -553,11 +555,11 @@ body:not(.white) a:visited{
|
|||
right: auto;
|
||||
}
|
||||
|
||||
.toolbar .regenerate span, .regenerate_button span, .continue_button span {
|
||||
.toolbar .regenerate span, .regenerate_button span, .continue_button span, .options_button div {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.regenerate_button span, .continue_button span {
|
||||
.regenerate_button span, .continue_button span, .options_button div {
|
||||
position: absolute;
|
||||
height: 20px;
|
||||
width: 100px;
|
||||
|
|
@ -568,6 +570,27 @@ body:not(.white) a:visited{
|
|||
padding: 6px;
|
||||
}
|
||||
|
||||
.options_button div {
|
||||
display: none;
|
||||
flex-direction: row;
|
||||
height: 36px;
|
||||
width: 120px;
|
||||
margin-right: 130px;
|
||||
z-index: 1005;
|
||||
padding: 10px;
|
||||
margin-top: -4px;
|
||||
}
|
||||
|
||||
.options_button div span{
|
||||
height: 20px;
|
||||
width: 22px;
|
||||
|
||||
}
|
||||
|
||||
.options_button:hover div {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.regenerate_button:hover span, .continue_button:hover span {
|
||||
display: block;
|
||||
transition: all 0.3s;
|
||||
|
|
@ -786,7 +809,7 @@ select {
|
|||
border-radius: 25px;
|
||||
}
|
||||
|
||||
.buttons button, button.regenerate_button, button.continue_button {
|
||||
.buttons button, button.regenerate_button, button.continue_button, button.options_button {
|
||||
border-radius: 8px;
|
||||
backdrop-filter: blur(20px);
|
||||
cursor: pointer;
|
||||
|
|
@ -796,7 +819,7 @@ select {
|
|||
padding: 8px;
|
||||
}
|
||||
|
||||
button.regenerate_button, button.continue_button {
|
||||
button.regenerate_button, button.continue_button, button.options_button {
|
||||
display: flex;
|
||||
margin-top: -8px;
|
||||
margin-left: 2px;
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ if (window.markdownit) {
|
|||
function filter_message(text) {
|
||||
return text.replaceAll(
|
||||
/<!-- generated images start -->[\s\S]+<!-- generated images end -->/gm, ""
|
||||
)
|
||||
).replace(/ \[aborted\]$/g, "").replace(/ \[error\]$/g, "");
|
||||
}
|
||||
|
||||
function fallback_clipboard (text) {
|
||||
|
|
@ -204,7 +204,8 @@ const register_message_buttons = async () => {
|
|||
el.dataset.click = "true";
|
||||
el.addEventListener("click", async () => {
|
||||
let message_el = get_message_el(el);
|
||||
const copyText = await get_message(window.conversation_id, message_el.dataset.index);
|
||||
let response = await fetch(message_el.dataset.object_url);
|
||||
let copyText = await response.text();
|
||||
try {
|
||||
if (!navigator.clipboard) {
|
||||
throw new Error("navigator.clipboard: Clipboard API unavailable.");
|
||||
|
|
@ -221,6 +222,24 @@ const register_message_buttons = async () => {
|
|||
}
|
||||
});
|
||||
|
||||
document.querySelectorAll(".message .fa-file-export").forEach(async (el) => {
|
||||
if (!("click" in el.dataset)) {
|
||||
el.dataset.click = "true";
|
||||
el.addEventListener("click", async () => {
|
||||
let message_el = get_message_el(el);
|
||||
const elem = window.document.createElement('a');
|
||||
let filename = `chat ${new Date().toLocaleString()}.md`.replaceAll(":", "-");
|
||||
elem.href = message_el.dataset.object_url;
|
||||
elem.download = filename;
|
||||
document.body.appendChild(elem);
|
||||
elem.click();
|
||||
document.body.removeChild(elem);
|
||||
el.classList.add("clicked");
|
||||
setTimeout(() => el.classList.remove("clicked"), 1000);
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
document.querySelectorAll(".message .fa-volume-high").forEach(async (el) => {
|
||||
if (!("click" in el.dataset)) {
|
||||
el.dataset.click = "true";
|
||||
|
|
@ -362,11 +381,6 @@ const handle_ask = async () => {
|
|||
</div>
|
||||
<div class="count">
|
||||
${count_words_and_tokens(message, get_selected_model()?.value)}
|
||||
<i class="fa-solid fa-volume-high"></i>
|
||||
<i class="fa-regular fa-clipboard"></i>
|
||||
<a><i class="fa-brands fa-whatsapp"></i></a>
|
||||
<i class="fa-solid fa-print"></i>
|
||||
<i class="fa-solid fa-rotate"></i>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -484,11 +498,15 @@ const prepare_messages = (messages, message_index = -1, do_continue = false) =>
|
|||
}
|
||||
}
|
||||
|
||||
messages.forEach((new_message) => {
|
||||
messages.forEach((new_message, i) => {
|
||||
// Copy message first
|
||||
new_message = { ...new_message };
|
||||
// Include last message, if do_continue
|
||||
if (i + 1 == messages.length && do_continue) {
|
||||
delete new_message.regenerate;
|
||||
}
|
||||
// Include only not regenerated messages
|
||||
if (new_message && !new_message.regenerate) {
|
||||
// Copy message first
|
||||
new_message = { ...new_message };
|
||||
// Remove generated images from history
|
||||
new_message.content = filter_message(new_message.content);
|
||||
// Remove internal fields
|
||||
|
|
@ -707,8 +725,8 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
|
|||
message_storage[message_id] = "";
|
||||
stop_generating.classList.remove("stop_generating-hidden");
|
||||
|
||||
if (message_index == -1 && !regenerate) {
|
||||
await scroll_to_bottom();
|
||||
if (message_index == -1) {
|
||||
await lazy_scroll_to_bottom();
|
||||
}
|
||||
|
||||
let count_total = message_box.querySelector('.count_total');
|
||||
|
|
@ -750,9 +768,10 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
|
|||
inner: content_el.querySelector('.content_inner'),
|
||||
count: content_el.querySelector('.count'),
|
||||
update_timeouts: [],
|
||||
message_index: message_index,
|
||||
}
|
||||
if (message_index == -1 && !regenerate) {
|
||||
await scroll_to_bottom();
|
||||
if (message_index == -1) {
|
||||
await lazy_scroll_to_bottom();
|
||||
}
|
||||
try {
|
||||
const input = imageInput && imageInput.files.length > 0 ? imageInput : cameraInput;
|
||||
|
|
@ -813,7 +832,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
|
|||
let cursorDiv = message_el.querySelector(".cursor");
|
||||
if (cursorDiv) cursorDiv.parentNode.removeChild(cursorDiv);
|
||||
if (message_index == -1) {
|
||||
await scroll_to_bottom();
|
||||
await lazy_scroll_to_bottom();
|
||||
}
|
||||
await safe_remove_cancel_button();
|
||||
await register_message_buttons();
|
||||
|
|
@ -826,6 +845,12 @@ async function scroll_to_bottom() {
|
|||
message_box.scrollTop = message_box.scrollHeight;
|
||||
}
|
||||
|
||||
async function lazy_scroll_to_bottom() {
|
||||
if (message_box.scrollHeight - message_box.scrollTop < 2 * message_box.clientHeight) {
|
||||
await scroll_to_bottom();
|
||||
}
|
||||
}
|
||||
|
||||
const clear_conversations = async () => {
|
||||
const elements = box_conversations.childNodes;
|
||||
let index = elements.length;
|
||||
|
|
@ -971,7 +996,23 @@ const load_conversation = async (conversation_id, scroll=true) => {
|
|||
} else {
|
||||
buffer = "";
|
||||
}
|
||||
buffer += item.content;
|
||||
buffer = buffer.replace(/ \[aborted\]$/g, "").replace(/ \[error\]$/g, "");
|
||||
let lines = buffer.trim().split("\n");
|
||||
let lastLine = lines[lines.length - 1];
|
||||
let newContent = item.content;
|
||||
if (newContent.startsWith("```\n")) {
|
||||
newContent = item.content.substring(4);
|
||||
}
|
||||
if (newContent.startsWith(lastLine)) {
|
||||
newContent = newContent.substring(lastLine.length);
|
||||
} else {
|
||||
let words = buffer.trim().split(" ");
|
||||
let lastWord = words[words.length - 1];
|
||||
if (newContent.startsWith(lastWord)) {
|
||||
newContent = newContent.substring(lastWord.length);
|
||||
}
|
||||
}
|
||||
buffer += newContent;
|
||||
last_model = item.provider?.model;
|
||||
providers.push(item.provider?.name);
|
||||
let next_i = parseInt(i) + 1;
|
||||
|
|
@ -993,28 +1034,74 @@ const load_conversation = async (conversation_id, scroll=true) => {
|
|||
synthesize_params = (new URLSearchParams(synthesize_params)).toString();
|
||||
let synthesize_url = `/backend-api/v2/synthesize/${synthesize_provider}?${synthesize_params}`;
|
||||
|
||||
const file = new File([buffer], 'message.md', {type: 'text/plain'});
|
||||
const objectUrl = URL.createObjectURL(file);
|
||||
|
||||
let add_buttons = [];
|
||||
// Always add regenerate button
|
||||
add_buttons.push(`<button class="regenerate_button">
|
||||
<span>Regenerate</span>
|
||||
<i class="fa-solid fa-rotate"></i>
|
||||
</button>`);
|
||||
// Add continue button if possible
|
||||
actions = ["variant"]
|
||||
if (item.finish && item.finish.actions) {
|
||||
item.finish.actions.forEach((action) => {
|
||||
if (action == "continue") {
|
||||
if (messages.length >= i - 1) {
|
||||
add_buttons.push(`<button class="continue_button">
|
||||
<span>Continue</span>
|
||||
<i class="fa-solid fa-wand-magic-sparkles"></i>
|
||||
</button>`);
|
||||
}
|
||||
actions = item.finish.actions
|
||||
}
|
||||
if (!("continue" in actions)) {
|
||||
let reason = "stop";
|
||||
// Read finish reason from conversation
|
||||
if (item.finish && item.finish.reason) {
|
||||
reason = item.finish.reason;
|
||||
}
|
||||
let lines = buffer.trim().split("\n");
|
||||
let lastLine = lines[lines.length - 1];
|
||||
// Has a stop or error token at the end
|
||||
if (lastLine.endsWith("[aborted]") || lastLine.endsWith("[error]")) {
|
||||
reason = "error";
|
||||
// Has an even number of start or end code tags
|
||||
} else if (buffer.split("```").length - 1 % 2 === 1) {
|
||||
reason = "error";
|
||||
// Has a end token at the end
|
||||
} else if (lastLine.endsWith("```") || lastLine.endsWith(".") || lastLine.endsWith("?") || lastLine.endsWith("!")
|
||||
|| lastLine.endsWith('"') || lastLine.endsWith("'") || lastLine.endsWith(")")
|
||||
|| lastLine.endsWith(">") || lastLine.endsWith("]") || lastLine.endsWith("}") ) {
|
||||
reason = "stop"
|
||||
} else {
|
||||
// Has an emoji at the end
|
||||
const regex = /\p{Emoji}$/u;
|
||||
if (regex.test(lastLine)) {
|
||||
reason = "stop"
|
||||
}
|
||||
});
|
||||
}
|
||||
if (reason == "max_tokens" || reason == "error") {
|
||||
actions.push("continue")
|
||||
}
|
||||
}
|
||||
|
||||
add_buttons.push(`<button class="options_button">
|
||||
<div>
|
||||
<span><i class="fa-brands fa-whatsapp"></i></span>
|
||||
<span><i class="fa-solid fa-volume-high"></i></i></span>
|
||||
<span><i class="fa-solid fa-print"></i></span>
|
||||
<span><i class="fa-solid fa-file-export"></i></span>
|
||||
<span><i class="fa-regular fa-clipboard"></i></span>
|
||||
</div>
|
||||
<i class="fa-solid fa-plus"></i>
|
||||
</button>`);
|
||||
|
||||
if (actions.includes("variant")) {
|
||||
add_buttons.push(`<button class="regenerate_button">
|
||||
<span>Regenerate</span>
|
||||
<i class="fa-solid fa-rotate"></i>
|
||||
</button>`);
|
||||
}
|
||||
if (actions.includes("continue")) {
|
||||
if (messages.length >= i - 1) {
|
||||
add_buttons.push(`<button class="continue_button">
|
||||
<span>Continue</span>
|
||||
<i class="fa-solid fa-wand-magic-sparkles"></i>
|
||||
</button>`);
|
||||
}
|
||||
}
|
||||
|
||||
elements.push(`
|
||||
<div class="message${item.regenerate ? " regenerate": ""}" data-index="${i}" data-synthesize_url="${synthesize_url}">
|
||||
<div class="message${item.regenerate ? " regenerate": ""}" data-index="${i}" data-object_url="${objectUrl}" data-synthesize_url="${synthesize_url}">
|
||||
<div class="${item.role}">
|
||||
${item.role == "assistant" ? gpt_image : user_image}
|
||||
<i class="fa-solid fa-xmark"></i>
|
||||
|
|
@ -1028,12 +1115,6 @@ const load_conversation = async (conversation_id, scroll=true) => {
|
|||
<div class="content_inner">${markdown_render(buffer)}</div>
|
||||
<div class="count">
|
||||
${count_words_and_tokens(buffer, next_provider?.model)}
|
||||
<span>
|
||||
<i class="fa-solid fa-volume-high"></i>
|
||||
<i class="fa-regular fa-clipboard"></i>
|
||||
<a><i class="fa-brands fa-whatsapp"></i></a>
|
||||
<i class="fa-solid fa-print"></i>
|
||||
</span>
|
||||
${add_buttons.join("")}
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -1444,11 +1525,8 @@ function update_message(content_map, message_id, content = null) {
|
|||
content_map.inner.innerHTML = html;
|
||||
content_map.count.innerText = count_words_and_tokens(message_storage[message_id], provider_storage[message_id]?.model);
|
||||
highlight(content_map.inner);
|
||||
if (!content_map.container.classList.contains("regenerate")) {
|
||||
if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 200) {
|
||||
window.scrollTo(0, 0);
|
||||
message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
|
||||
}
|
||||
if (content_map.message_index == -1) {
|
||||
lazy_scroll_to_bottom();
|
||||
}
|
||||
content_map.update_timeouts.forEach((timeoutId)=>clearTimeout(timeoutId));
|
||||
content_map.update_timeouts = [];
|
||||
|
|
@ -1711,19 +1789,76 @@ async function upload_cookies() {
|
|||
fileInput.value = "";
|
||||
}
|
||||
|
||||
function formatFileSize(bytes) {
|
||||
const units = ['B', 'KB', 'MB', 'GB'];
|
||||
let unitIndex = 0;
|
||||
while (bytes >= 1024 && unitIndex < units.length - 1) {
|
||||
bytes /= 1024;
|
||||
unitIndex++;
|
||||
}
|
||||
return `${bytes.toFixed(2)} ${units[unitIndex]}`;
|
||||
}
|
||||
|
||||
async function upload_files(fileInput) {
|
||||
const paperclip = document.querySelector(".user-input .fa-paperclip");
|
||||
const bucket_id = uuid();
|
||||
|
||||
const formData = new FormData();
|
||||
Array.from(fileInput.files).forEach(file => {
|
||||
formData.append('files[]', file);
|
||||
});
|
||||
paperclip.classList.add("blink");
|
||||
await fetch("/backend-api/v2/files/" + bucket_id, {
|
||||
method: 'POST',
|
||||
body: formData
|
||||
});
|
||||
let do_refine = document.getElementById("refine").checked;
|
||||
function connectToSSE(url) {
|
||||
const eventSource = new EventSource(url);
|
||||
eventSource.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
if (data.error) {
|
||||
inputCount.innerText = `Error: ${data.error.message}`;
|
||||
} else if (data.action == "load") {
|
||||
inputCount.innerText = `Read data: ${formatFileSize(data.size)}`;
|
||||
} else if (data.action == "refine") {
|
||||
inputCount.innerText = `Refine data: ${formatFileSize(data.size)}`;
|
||||
} else if (data.action == "done") {
|
||||
if (do_refine) {
|
||||
do_refine = false;
|
||||
connectToSSE(`/backend-api/v2/files/${bucket_id}?refine_chunks_with_spacy=true`);
|
||||
return;
|
||||
}
|
||||
inputCount.innerText = "Files are loaded successfully";
|
||||
messageInput.value += (messageInput.value ? "\n" : "") + JSON.stringify({bucket_id: bucket_id}) + "\n";
|
||||
paperclip.classList.remove("blink");
|
||||
fileInput.value = "";
|
||||
delete fileInput.dataset.text;
|
||||
}
|
||||
};
|
||||
eventSource.onerror = (event) => {
|
||||
eventSource.close();
|
||||
paperclip.classList.remove("blink");
|
||||
}
|
||||
}
|
||||
connectToSSE(`/backend-api/v2/files/${bucket_id}`);
|
||||
}
|
||||
|
||||
fileInput.addEventListener('change', async (event) => {
|
||||
if (fileInput.files.length) {
|
||||
type = fileInput.files[0].name.split('.').pop()
|
||||
if (type == "har") {
|
||||
return await upload_cookies();
|
||||
} else if (type != "json") {
|
||||
await upload_files(fileInput);
|
||||
}
|
||||
fileInput.dataset.type = type
|
||||
const reader = new FileReader();
|
||||
reader.addEventListener('load', async (event) => {
|
||||
fileInput.dataset.text = event.target.result;
|
||||
if (type == "json") {
|
||||
if (type == "json") {
|
||||
const reader = new FileReader();
|
||||
reader.addEventListener('load', async (event) => {
|
||||
fileInput.dataset.text = event.target.result;
|
||||
const data = JSON.parse(fileInput.dataset.text);
|
||||
if ("g4f" in data.options) {
|
||||
if (data.options && "g4f" in data.options) {
|
||||
let count = 0;
|
||||
Object.keys(data).forEach(key => {
|
||||
if (key != "options" && !localStorage.getItem(key)) {
|
||||
|
|
@ -1736,11 +1871,23 @@ fileInput.addEventListener('change', async (event) => {
|
|||
fileInput.value = "";
|
||||
inputCount.innerText = `${count} Conversations were imported successfully`;
|
||||
} else {
|
||||
await upload_cookies();
|
||||
is_cookie_file = false;
|
||||
if (Array.isArray(data)) {
|
||||
data.forEach((item) => {
|
||||
if (item.domain && item.name && item.value) {
|
||||
is_cookie_file = true;
|
||||
}
|
||||
});
|
||||
}
|
||||
if (is_cookie_file) {
|
||||
await upload_cookies();
|
||||
} else {
|
||||
await upload_files(fileInput);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
reader.readAsText(fileInput.files[0]);
|
||||
});
|
||||
reader.readAsText(fileInput.files[0]);
|
||||
}
|
||||
} else {
|
||||
delete fileInput.dataset.text;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from .gui_parser import gui_parser
|
||||
from ..cookies import read_cookie_files
|
||||
from g4f.gui import run_gui
|
||||
import g4f.cookies
|
||||
import g4f.debug
|
||||
|
||||
|
|
@ -8,7 +9,6 @@ def run_gui_args(args):
|
|||
g4f.debug.logging = True
|
||||
if not args.ignore_cookie_files:
|
||||
read_cookie_files()
|
||||
from g4f.gui import run_gui
|
||||
host = args.host
|
||||
port = args.port
|
||||
debug = args.debug
|
||||
|
|
|
|||
|
|
@ -7,17 +7,17 @@ from typing import Iterator
|
|||
from flask import send_from_directory
|
||||
from inspect import signature
|
||||
|
||||
from g4f import version, models
|
||||
from g4f import ChatCompletion, get_model_and_provider
|
||||
from g4f.errors import VersionNotFoundError
|
||||
from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir
|
||||
from g4f.Provider import ProviderUtils, __providers__
|
||||
from g4f.providers.base_provider import ProviderModelMixin
|
||||
from g4f.providers.retry_provider import IterListProvider
|
||||
from g4f.providers.response import BaseConversation, JsonConversation, FinishReason
|
||||
from g4f.providers.response import SynthesizeData, TitleGeneration, RequestLogin, Parameters
|
||||
from g4f.client.service import convert_to_provider
|
||||
from g4f import debug
|
||||
from ...errors import VersionNotFoundError
|
||||
from ...image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir
|
||||
from ...tools.run_tools import iter_run_tools
|
||||
from ...Provider import ProviderUtils, __providers__
|
||||
from ...providers.base_provider import ProviderModelMixin
|
||||
from ...providers.retry_provider import IterListProvider
|
||||
from ...providers.response import BaseConversation, JsonConversation, FinishReason
|
||||
from ...providers.response import SynthesizeData, TitleGeneration, RequestLogin, Parameters
|
||||
from ... import version, models
|
||||
from ... import ChatCompletion, get_model_and_provider
|
||||
from ... import debug
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
conversations: dict[dict[str, BaseConversation]] = {}
|
||||
|
|
@ -90,16 +90,28 @@ class Api:
|
|||
api_key = json_data.get("api_key")
|
||||
if api_key is not None:
|
||||
kwargs["api_key"] = api_key
|
||||
kwargs["tool_calls"] = [{
|
||||
"function": {
|
||||
"name": "bucket_tool"
|
||||
},
|
||||
"type": "function"
|
||||
}]
|
||||
do_web_search = json_data.get('web_search')
|
||||
if do_web_search and provider:
|
||||
provider_handler = convert_to_provider(provider)
|
||||
if hasattr(provider_handler, "get_parameters"):
|
||||
if "web_search" in provider_handler.get_parameters():
|
||||
kwargs['web_search'] = True
|
||||
do_web_search = False
|
||||
if do_web_search:
|
||||
from ...web_search import get_search_message
|
||||
messages[-1]["content"] = get_search_message(messages[-1]["content"])
|
||||
kwargs["tool_calls"].append({
|
||||
"function": {
|
||||
"name": "safe_search_tool"
|
||||
},
|
||||
"type": "function"
|
||||
})
|
||||
action = json_data.get('action')
|
||||
if action == "continue":
|
||||
kwargs["tool_calls"].append({
|
||||
"function": {
|
||||
"name": "continue_tool"
|
||||
},
|
||||
"type": "function"
|
||||
})
|
||||
conversation = json_data.get("conversation")
|
||||
if conversation is not None:
|
||||
kwargs["conversation"] = JsonConversation(**conversation)
|
||||
|
|
@ -139,7 +151,7 @@ class Api:
|
|||
logging=False
|
||||
)
|
||||
params = {
|
||||
**provider_handler.get_parameters(as_json=True),
|
||||
**(provider_handler.get_parameters(as_json=True) if hasattr(provider_handler, "get_parameters") else {}),
|
||||
"model": model,
|
||||
"messages": kwargs.get("messages"),
|
||||
"web_search": kwargs.get("web_search")
|
||||
|
|
@ -153,7 +165,7 @@ class Api:
|
|||
yield self._format_json("parameters", params)
|
||||
first = True
|
||||
try:
|
||||
result = ChatCompletion.create(**{**kwargs, "model": model, "provider": provider_handler})
|
||||
result = iter_run_tools(ChatCompletion.create, **{**kwargs, "model": model, "provider": provider_handler})
|
||||
for chunk in result:
|
||||
if first:
|
||||
first = False
|
||||
|
|
|
|||
|
|
@@ -1,9 +1,9 @@
import sys, os
from flask import Flask

if getattr(sys, 'frozen', False):
    template_folder = os.path.join(sys._MEIPASS, "client")
else:
    template_folder = "../client"

app = Flask(__name__, template_folder=template_folder, static_folder=f"{template_folder}/static")
def create_app() -> Flask:
    if getattr(sys, 'frozen', False):
        template_folder = os.path.join(sys._MEIPASS, "client")
    else:
        template_folder = "../client"
    return Flask(__name__, template_folder=template_folder, static_folder=f"{template_folder}/static")
@ -1,17 +1,22 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import flask
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
from flask import Flask, request, jsonify
|
||||
import shutil
|
||||
from flask import Flask, Response, request, jsonify
|
||||
from typing import Generator
|
||||
from pathlib import Path
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from g4f.image import is_allowed_extension, to_image
|
||||
from g4f.client.service import convert_to_provider
|
||||
from g4f.providers.asyncio import to_sync_generator
|
||||
from g4f.errors import ProviderNotFoundError
|
||||
from g4f.cookies import get_cookies_dir
|
||||
from ...image import is_allowed_extension, to_image
|
||||
from ...client.service import convert_to_provider
|
||||
from ...providers.asyncio import to_sync_generator
|
||||
from ...tools.files import supports_filename, get_streaming, get_bucket_dir, get_buckets
|
||||
from ...errors import ProviderNotFoundError
|
||||
from ...cookies import get_cookies_dir
|
||||
from .api import Api
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -96,6 +101,85 @@ class Backend_Api(Api):
|
|||
}
|
||||
}
|
||||
|
||||
@app.route('/backend-api/v2/buckets', methods=['GET'])
|
||||
def list_buckets():
|
||||
try:
|
||||
buckets = get_buckets()
|
||||
if buckets is None:
|
||||
return jsonify({"error": {"message": "Error accessing bucket directory"}}), 500
|
||||
sanitized_buckets = [secure_filename(b) for b in buckets]
|
||||
return jsonify(sanitized_buckets), 200
|
||||
except Exception as e:
|
||||
return jsonify({"error": {"message": str(e)}}), 500
|
||||
|
||||
@app.route('/backend-api/v2/files/<bucket_id>', methods=['GET', 'DELETE'])
|
||||
def manage_files(bucket_id: str):
|
||||
bucket_id = secure_filename(bucket_id)
|
||||
bucket_dir = get_bucket_dir(secure_filename(bucket_id))
|
||||
|
||||
if not os.path.isdir(bucket_dir):
|
||||
return jsonify({"error": {"message": "Bucket directory not found"}}), 404
|
||||
|
||||
if request.method == 'DELETE':
|
||||
try:
|
||||
shutil.rmtree(bucket_dir)
|
||||
return jsonify({"message": "Bucket deleted successfully"}), 200
|
||||
except OSError as e:
|
||||
return jsonify({"error": {"message": f"Error deleting bucket: {str(e)}"}}), 500
|
||||
except Exception as e:
|
||||
return jsonify({"error": {"message": str(e)}}), 500
|
||||
|
||||
delete_files = request.args.get('delete_files', True)
|
||||
refine_chunks_with_spacy = request.args.get('refine_chunks_with_spacy', False)
|
||||
event_stream = 'text/event-stream' in request.headers.get('Accept', '')
|
||||
mimetype = "text/event-stream" if event_stream else "text/plain";
|
||||
return Response(get_streaming(bucket_dir, delete_files, refine_chunks_with_spacy, event_stream), mimetype=mimetype)
|
||||
|
||||
@self.app.route('/backend-api/v2/files/<bucket_id>', methods=['POST'])
|
||||
def upload_files(bucket_id: str):
|
||||
bucket_id = secure_filename(bucket_id)
|
||||
bucket_dir = get_bucket_dir(bucket_id)
|
||||
os.makedirs(bucket_dir, exist_ok=True)
|
||||
filenames = []
|
||||
for file in request.files.getlist('files[]'):
|
||||
try:
|
||||
filename = secure_filename(file.filename)
|
||||
if supports_filename(filename):
|
||||
with open(os.path.join(bucket_dir, filename), 'wb') as f:
|
||||
shutil.copyfileobj(file.stream, f)
|
||||
filenames.append(filename)
|
||||
finally:
|
||||
file.stream.close()
|
||||
with open(os.path.join(bucket_dir, "files.txt"), 'w') as f:
|
||||
[f.write(f"{filename}\n") for filename in filenames]
|
||||
return {"bucket_id": bucket_id, "files": filenames}
|
||||
|
||||
@app.route('/backend-api/v2/files/<bucket_id>/<filename>', methods=['PUT'])
|
||||
def upload_file(bucket_id, filename):
|
||||
bucket_id = secure_filename(bucket_id)
|
||||
bucket_dir = get_bucket_dir(bucket_id)
|
||||
filename = secure_filename(filename)
|
||||
bucket_path = Path(bucket_dir)
|
||||
|
||||
if not supports_filename(filename):
|
||||
return jsonify({"error": {"message": f"File type not allowed"}}), 400
|
||||
|
||||
if not bucket_path.exists():
|
||||
bucket_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
file_path = bucket_path / filename
|
||||
file_data = request.get_data()
|
||||
if not file_data:
|
||||
return jsonify({"error": {"message": "No file data received"}}), 400
|
||||
|
||||
with open(str(file_path), 'wb') as f:
|
||||
f.write(file_data)
|
||||
|
||||
return jsonify({"message": f"File '{filename}' uploaded successfully to bucket '{bucket_id}'"}), 201
|
||||
except Exception as e:
|
||||
return jsonify({"error": {"message": f"Error uploading file: {str(e)}"}}), 500
|
||||
|
||||
def upload_cookies(self):
|
||||
file = None
|
||||
if "file" in request.files:
|
||||
|
|
@@ -1,3 +1,3 @@
from __future__ import annotations

from ...web_search import SearchResults, search, get_search_message
from ...tools.web_search import SearchResults, search, get_search_message
@ -32,8 +32,6 @@ except ImportError:
|
|||
|
||||
from .. import debug
|
||||
from .raise_for_status import raise_for_status
|
||||
from ..webdriver import WebDriver, WebDriverSession
|
||||
from ..webdriver import bypass_cloudflare, get_driver_cookies
|
||||
from ..errors import MissingRequirementsError
|
||||
from ..typing import Cookies
|
||||
from .defaults import DEFAULT_HEADERS, WEBVIEW_HAEDERS
|
||||
|
|
@ -66,69 +64,6 @@ async def get_args_from_webview(url: str) -> dict:
|
|||
window.destroy()
|
||||
return {"headers": headers, "cookies": cookies}
|
||||
|
||||
def get_args_from_browser(
|
||||
url: str,
|
||||
webdriver: WebDriver = None,
|
||||
proxy: str = None,
|
||||
timeout: int = 120,
|
||||
do_bypass_cloudflare: bool = True,
|
||||
virtual_display: bool = False
|
||||
) -> dict:
|
||||
"""
|
||||
Create a Session object using a WebDriver to handle cookies and headers.
|
||||
|
||||
Args:
|
||||
url (str): The URL to navigate to using the WebDriver.
|
||||
webdriver (WebDriver, optional): The WebDriver instance to use.
|
||||
proxy (str, optional): Proxy server to use for the Session.
|
||||
timeout (int, optional): Timeout in seconds for the WebDriver.
|
||||
|
||||
Returns:
|
||||
Session: A Session object configured with cookies and headers from the WebDriver.
|
||||
"""
|
||||
with WebDriverSession(webdriver, "", proxy=proxy, virtual_display=virtual_display) as driver:
|
||||
if do_bypass_cloudflare:
|
||||
bypass_cloudflare(driver, url, timeout)
|
||||
headers = {
|
||||
**DEFAULT_HEADERS,
|
||||
'referer': url,
|
||||
}
|
||||
if not hasattr(driver, "requests"):
|
||||
headers["user-agent"] = driver.execute_script("return navigator.userAgent")
|
||||
else:
|
||||
for request in driver.requests:
|
||||
if request.url.startswith(url):
|
||||
for key, value in request.headers.items():
|
||||
if key in (
|
||||
"accept-encoding",
|
||||
"accept-language",
|
||||
"user-agent",
|
||||
"sec-ch-ua",
|
||||
"sec-ch-ua-platform",
|
||||
"sec-ch-ua-arch",
|
||||
"sec-ch-ua-full-version",
|
||||
"sec-ch-ua-platform-version",
|
||||
"sec-ch-ua-bitness"
|
||||
):
|
||||
headers[key] = value
|
||||
break
|
||||
cookies = get_driver_cookies(driver)
|
||||
return {
|
||||
'cookies': cookies,
|
||||
'headers': headers,
|
||||
}
|
||||
|
||||
def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str = None, timeout: int = 120) -> Session:
|
||||
if not has_curl_cffi:
|
||||
raise MissingRequirementsError('Install "curl_cffi" package | pip install -U curl_cffi')
|
||||
args = get_args_from_browser(url, webdriver, proxy, timeout)
|
||||
return Session(
|
||||
**args,
|
||||
proxies={"https": proxy, "http": proxy},
|
||||
timeout=timeout,
|
||||
impersonate="chrome"
|
||||
)
|
||||
|
||||
def get_cookie_params_from_dict(cookies: Cookies, url: str = None, domain: str = None) -> list[CookieParam]:
|
||||
[CookieParam.from_json({
|
||||
"name": key,
|
||||
|
|
|
|||
|
|
@@ -25,7 +25,7 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
        return
    text = await response.text()
    if message is None:
        message = "HTML content" if response.headers.get("content-type").startswith("text/html") else text
        message = "HTML content" if response.headers.get("content-type").startswith("text/html") else text
    if message == "HTML content":
        if response.status == 520:
            message = "Unknown error (Cloudflare)"
g4f/tools/files.py (new file, 511 additions)
@ -0,0 +1,511 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional
|
||||
from aiohttp import ClientSession, ClientError, ClientResponse, ClientTimeout
|
||||
import urllib.parse
|
||||
import time
|
||||
import zipfile
|
||||
import asyncio
|
||||
import hashlib
|
||||
import base64
|
||||
|
||||
try:
|
||||
from werkzeug.utils import secure_filename
|
||||
except ImportError:
|
||||
secure_filename = os.path.basename
|
||||
|
||||
try:
|
||||
import PyPDF2
|
||||
from PyPDF2.errors import PdfReadError
|
||||
has_pypdf2 = True
|
||||
except ImportError:
|
||||
has_pypdf2 = False
|
||||
try:
|
||||
import pdfplumber
|
||||
has_pdfplumber = True
|
||||
except ImportError:
|
||||
has_pdfplumber = False
|
||||
try:
|
||||
from pdfminer.high_level import extract_text
|
||||
has_pdfminer = True
|
||||
except ImportError:
|
||||
has_pdfminer = False
|
||||
try:
|
||||
from docx import Document
|
||||
has_docx = True
|
||||
except ImportError:
|
||||
has_docx = False
|
||||
try:
|
||||
import docx2txt
|
||||
has_docx2txt = True
|
||||
except ImportError:
|
||||
has_docx2txt = False
|
||||
try:
|
||||
from odf.opendocument import load
|
||||
from odf.text import P
|
||||
has_odfpy = True
|
||||
except ImportError:
|
||||
has_odfpy = False
|
||||
try:
|
||||
import ebooklib
|
||||
from ebooklib import epub
|
||||
has_ebooklib = True
|
||||
except ImportError:
|
||||
has_ebooklib = False
|
||||
try:
|
||||
import pandas as pd
|
||||
has_openpyxl = True
|
||||
except ImportError:
|
||||
has_openpyxl = False
|
||||
try:
|
||||
import spacy
|
||||
has_spacy = True
|
||||
except:
|
||||
has_spacy = False
|
||||
try:
|
||||
from bs4 import BeautifulSoup
|
||||
has_beautifulsoup4 = True
|
||||
except ImportError:
|
||||
has_beautifulsoup4 = False
|
||||
|
||||
from .web_search import scrape_text
|
||||
from ..cookies import get_cookies_dir
|
||||
from ..requests.aiohttp import get_connector
|
||||
from ..errors import MissingRequirementsError
|
||||
from .. import debug
|
||||
|
||||
PLAIN_FILE_EXTENSIONS = ["txt", "xml", "json", "js", "har", "sh", "py", "php", "css", "yaml", "sql", "log", "csv", "twig", "md"]
|
||||
PLAIN_CACHE = "plain.cache"
|
||||
DOWNLOADS_FILE = "downloads.json"
|
||||
FILE_LIST = "files.txt"
|
||||
|
||||
def supports_filename(filename: str):
    if filename.endswith(".pdf"):
        if has_pypdf2:
            return True
        elif has_pdfplumber:
            return True
        elif has_pdfminer:
            return True
        raise MissingRequirementsError(f'Install "pypdf2" requirements | pip install -U g4f[files]')
    elif filename.endswith(".docx"):
        if has_docx:
            return True
        elif has_docx2txt:
            return True
        raise MissingRequirementsError(f'Install "docx" requirements | pip install -U g4f[files]')
    elif has_odfpy and filename.endswith(".odt"):
        return True
    elif has_ebooklib and filename.endswith(".epub"):
        return True
    elif has_openpyxl and filename.endswith(".xlsx"):
        return True
    elif filename.endswith(".html"):
        if not has_beautifulsoup4:
            raise MissingRequirementsError(f'Install "beautifulsoup4" requirements | pip install -U g4f[files]')
        return True
    elif filename.endswith(".zip"):
        return True
    elif filename.endswith("package-lock.json") and filename != FILE_LIST:
        return False
    else:
        extension = os.path.splitext(filename)[1][1:]
        if extension in PLAIN_FILE_EXTENSIONS:
            return True
    return False

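A minimal usage sketch for `supports_filename` (not part of the committed file; the sample filenames are illustrative): it returns a boolean for recognised extensions and raises `MissingRequirementsError` when a known format's optional parser is not installed.

```python
# Usage sketch (illustrative, not part of g4f/tools/files.py)
from g4f.tools.files import supports_filename
from g4f.errors import MissingRequirementsError

candidates = ["report.pdf", "notes.md", "data.xlsx", "archive.zip", "photo.raw"]
for name in candidates:
    try:
        print(name, "->", supports_filename(name))
    except MissingRequirementsError as e:
        # Raised when the extension is known but its optional dependency is missing
        print(name, "-> missing requirements:", e)
```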
def get_bucket_dir(bucket_id: str):
    bucket_dir = os.path.join(get_cookies_dir(), "buckets", bucket_id)
    return bucket_dir

def get_buckets():
    buckets_dir = os.path.join(get_cookies_dir(), "buckets")
    try:
        return [d for d in os.listdir(buckets_dir) if os.path.isdir(os.path.join(buckets_dir, d))]
    except OSError as e:
        return None

def spacy_refine_chunks(source_iterator):
    if not has_spacy:
        raise MissingRequirementsError(f'Install "spacy" requirements | pip install -U g4f[files]')

    nlp = spacy.load("en_core_web_sm")
    for page in source_iterator:
        doc = nlp(page)
        #for chunk in doc.noun_chunks:
        #    yield " ".join([token.lemma_ for token in chunk if not token.is_stop])
        # for token in doc:
        #     if not token.is_space:
        #         yield token.lemma_.lower()
        #         yield " "
        sentences = list(doc.sents)
        summary = sorted(sentences, key=lambda x: len(x.text), reverse=True)[:2]
        for sent in summary:
            yield sent.text

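A sketch of exercising the spaCy refinement pass on its own (not part of the commit; it assumes the `en_core_web_sm` model has been downloaded): each input chunk is reduced to its two longest sentences.

```python
# Usage sketch (illustrative, not part of g4f/tools/files.py)
# Requires: pip install spacy && python -m spacy download en_core_web_sm
from g4f.tools.files import spacy_refine_chunks

pages = [
    "Short intro. This much longer sentence carries most of the page's meaning. Bye.",
    "Another page. Its key sentence is also the longest one by a wide margin.",
]
# Yields the two longest sentences per input chunk
for refined in spacy_refine_chunks(iter(pages)):
    print(refined)
```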
def get_filenames(bucket_dir: Path):
    files = bucket_dir / FILE_LIST
    with files.open('r') as f:
        return [filename.strip() for filename in f.readlines()]

def stream_read_files(bucket_dir: Path, filenames: list) -> Iterator[str]:
    for filename in filenames:
        file_path: Path = bucket_dir / filename
        # Skip entries that are missing or empty. (The original check
        # "not file_path.exists() and 0 > file_path.lstat().st_size" could never
        # skip anything: st_size is never negative, and lstat() on a missing
        # path raises FileNotFoundError.)
        if not file_path.exists() or file_path.lstat().st_size <= 0:
            continue
        extension = os.path.splitext(filename)[1][1:]
        if filename.endswith(".zip"):
            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                zip_ref.extractall(bucket_dir)
                try:
                    yield from stream_read_files(bucket_dir, [f for f in zip_ref.namelist() if supports_filename(f)])
                except zipfile.BadZipFile:
                    pass
                finally:
                    for unlink in zip_ref.namelist()[::-1]:
                        filepath = os.path.join(bucket_dir, unlink)
                        if os.path.exists(filepath):
                            if os.path.isdir(filepath):
                                os.rmdir(filepath)
                            else:
                                os.unlink(filepath)
            continue
        yield f"```{filename}\n"
        if has_pypdf2 and filename.endswith(".pdf"):
            try:
                reader = PyPDF2.PdfReader(file_path)
                for page_num in range(len(reader.pages)):
                    page = reader.pages[page_num]
                    yield page.extract_text()
            except PdfReadError:
                continue
        if has_pdfplumber and filename.endswith(".pdf"):
            with pdfplumber.open(file_path) as pdf:
                for page in pdf.pages:
                    yield page.extract_text()
        if has_pdfminer and filename.endswith(".pdf"):
            yield extract_text(file_path)
        elif has_docx and filename.endswith(".docx"):
            doc = Document(file_path)
            for para in doc.paragraphs:
                yield para.text
        elif has_docx2txt and filename.endswith(".docx"):
            yield docx2txt.process(file_path)
        elif has_odfpy and filename.endswith(".odt"):
            textdoc = load(file_path)
            allparas = textdoc.getElementsByType(P)
            for p in allparas:
                yield p.firstChild.data if p.firstChild else ""
        elif has_ebooklib and filename.endswith(".epub"):
            book = epub.read_epub(file_path)
            for doc_item in book.get_items():
                if doc_item.get_type() == ebooklib.ITEM_DOCUMENT:
                    yield doc_item.get_content().decode(errors='ignore')
        elif has_openpyxl and filename.endswith(".xlsx"):
            df = pd.read_excel(file_path)
            for row in df.itertuples(index=False):
                yield " ".join(str(cell) for cell in row)
        elif has_beautifulsoup4 and filename.endswith(".html"):
            yield from scrape_text(file_path.read_text(errors="ignore"))
        elif extension in PLAIN_FILE_EXTENSIONS:
            yield file_path.read_text(errors="ignore")
        yield f"\n```\n\n"

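A direct use of the reader (not part of the commit; paths are illustrative): each supported file is emitted wrapped in a fenced block named after the file, which is the format the rest of the pipeline caches and refines.

```python
# Usage sketch (illustrative, not part of g4f/tools/files.py)
from pathlib import Path
from g4f.tools.files import stream_read_files

bucket_dir = Path("./example_bucket")        # hypothetical directory
bucket_dir.mkdir(parents=True, exist_ok=True)
(bucket_dir / "readme.md").write_text("Hello from a markdown file.\n")

# Yields "```readme.md\n", the file text, then the closing fence
for chunk in stream_read_files(bucket_dir, ["readme.md"]):
    print(chunk, end="")
```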
def cache_stream(stream: Iterator[str], bucket_dir: Path) -> Iterator[str]:
    cache_file = bucket_dir / PLAIN_CACHE
    tmp_file = bucket_dir / f"{PLAIN_CACHE}.{time.time()}.tmp"
    if cache_file.exists():
        for chunk in read_path_chunked(cache_file):
            yield chunk
        return
    with open(tmp_file, "w") as f:
        for chunk in stream:
            f.write(chunk)
            yield chunk
    tmp_file.rename(cache_file)

def is_complete(data: str):
    return data.endswith("\n```\n\n") and data.count("```") % 2 == 0

def read_path_chunked(path: Path):
    with path.open("r", encoding='utf-8') as f:
        current_chunk_size = 0
        buffer = ""
        for line in f:
            current_chunk_size += len(line.encode('utf-8'))
            buffer += line
            if current_chunk_size >= 4096:
                if is_complete(buffer) or current_chunk_size >= 8192:
                    yield buffer
                    buffer = ""
                    current_chunk_size = 0
        if current_chunk_size > 0:
            yield buffer

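The chunker buffers roughly 4 KB and prefers to cut where the last fenced file block has just closed (`is_complete`), falling back to a hard cut at about 8 KB. A small sketch (not part of the commit; the temporary file name is made up):

```python
# Usage sketch (illustrative, not part of g4f/tools/files.py)
from pathlib import Path
from g4f.tools.files import read_path_chunked

sample = Path("plain.cache.example")  # hypothetical cache file
sample.write_text(
    "```a.txt\n" + ("x" * 5000) + "\n```\n\n" + "```b.txt\nhello\n```\n\n",
    encoding="utf-8",
)

# Chunks are cut near 4 KB where a ``` block just closed; the remainder
# is flushed as a final, smaller chunk.
for i, chunk in enumerate(read_path_chunked(sample)):
    print(i, len(chunk), chunk.endswith("\n```\n\n"))
```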
def read_bucket(bucket_dir: Path):
    bucket_dir = Path(bucket_dir)
    cache_file = bucket_dir / PLAIN_CACHE
    spacy_file = bucket_dir / f"spacy_0001.cache"
    if not spacy_file.exists():
        yield cache_file.read_text()
    for idx in range(1, 1000):
        spacy_file = bucket_dir / f"spacy_{idx:04d}.cache"
        plain_file = bucket_dir / f"plain_{idx:04d}.cache"
        if spacy_file.exists():
            yield spacy_file.read_text()
        elif plain_file.exists():
            yield plain_file.read_text()
        else:
            break

def stream_read_parts_and_refine(bucket_dir: Path, delete_files: bool = False) -> Iterator[str]:
    cache_file = bucket_dir / PLAIN_CACHE
    space_file = Path(bucket_dir) / f"spacy_0001.cache"
    part_one = bucket_dir / f"plain_0001.cache"
    if not space_file.exists() and not part_one.exists() and cache_file.exists():
        split_file_by_size_and_newline(cache_file, bucket_dir)
    for idx in range(1, 1000):
        part = bucket_dir / f"plain_{idx:04d}.cache"
        tmp_file = Path(bucket_dir) / f"spacy_{idx:04d}.{time.time()}.tmp"
        cache_file = Path(bucket_dir) / f"spacy_{idx:04d}.cache"
        if cache_file.exists():
            with open(cache_file, "r") as f:
                yield f.read()
            continue
        if not part.exists():
            break
        with tmp_file.open("w") as f:
            for chunk in spacy_refine_chunks(read_path_chunked(part)):
                f.write(chunk)
                yield chunk
        tmp_file.rename(cache_file)
        if delete_files:
            part.unlink()

def split_file_by_size_and_newline(input_filename, output_dir, chunk_size_bytes=1024*1024): # 1MB
    """Splits a file into chunks of approximately chunk_size_bytes, splitting only at newline characters.

    Args:
        input_filename: Path to the input file.
        output_dir: Directory where the chunk files are written.
        chunk_size_bytes: Desired size of each chunk in bytes.
    """
    split_filename = os.path.splitext(os.path.basename(input_filename))
    output_prefix = os.path.join(output_dir, split_filename[0] + "_")

    with open(input_filename, 'r', encoding='utf-8') as infile:
        chunk_num = 1
        current_chunk = ""
        current_chunk_size = 0

        for line in infile:
            current_chunk += line
            current_chunk_size += len(line.encode('utf-8'))

            if current_chunk_size >= chunk_size_bytes:
                if is_complete(current_chunk) or current_chunk_size >= chunk_size_bytes * 2:
                    output_filename = f"{output_prefix}{chunk_num:04d}{split_filename[1]}"
                    with open(output_filename, 'w', encoding='utf-8') as outfile:
                        outfile.write(current_chunk)
                    current_chunk = ""
                    current_chunk_size = 0
                    chunk_num += 1

        # Write the last chunk
        if current_chunk:
            output_filename = f"{output_prefix}{chunk_num:04d}{split_filename[1]}"
            with open(output_filename, 'w', encoding='utf-8') as outfile:
                outfile.write(current_chunk)

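A sketch of the splitter in isolation (not part of the commit; directory and file names are illustrative): a `plain.cache` is cut into `plain_0001.cache`, `plain_0002.cache`, ... at block boundaries, which is what the spaCy refinement stage above consumes.

```python
# Usage sketch (illustrative, not part of g4f/tools/files.py)
import os
from g4f.tools.files import split_file_by_size_and_newline

bucket_dir = "./example_bucket"          # hypothetical directory
os.makedirs(bucket_dir, exist_ok=True)
cache_path = os.path.join(bucket_dir, "plain.cache")
with open(cache_path, "w", encoding="utf-8") as f:
    for i in range(2000):
        f.write(f"```doc{i}.txt\nline {i}\n```\n\n")

# Produces plain_0001.cache (and further parts for inputs above ~1 MB)
split_file_by_size_and_newline(cache_path, bucket_dir)
print(sorted(n for n in os.listdir(bucket_dir) if n.startswith("plain_")))
```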
async def get_filename(response: ClientResponse):
    """
    Attempts to extract a filename from an aiohttp response. Prioritizes Content-Disposition, then URL.

    Args:
        response: The aiohttp ClientResponse object.

    Returns:
        The filename as a string, or None if it cannot be determined.
    """
    content_disposition = response.headers.get('Content-Disposition')
    if content_disposition:
        try:
            filename = content_disposition.split('filename=')[1].strip('"')
            if filename:
                return secure_filename(filename)
        except IndexError:
            pass

    content_type = response.headers.get('Content-Type')
    url = str(response.url)
    if content_type and url:
        extension = await get_file_extension(response)
        if extension:
            parsed_url = urllib.parse.urlparse(url)
            sha256_hash = hashlib.sha256(url.encode()).digest()
            base64_encoded = base64.b32encode(sha256_hash).decode().lower()
            return f"{parsed_url.netloc} {parsed_url.path[1:].replace('/', '_')} {base64_encoded[:6]}{extension}"

    return None

async def get_file_extension(response: ClientResponse):
    """
    Attempts to determine the file extension from an aiohttp response. Improved to handle more types.

    Args:
        response: The aiohttp ClientResponse object.

    Returns:
        The file extension (e.g., ".html", ".json", ".pdf", ".zip", ".md", ".txt") as a string,
        or None if it cannot be determined.
    """
    content_type = response.headers.get('Content-Type')
    if content_type:
        if "html" in content_type.lower():
            return ".html"
        elif "json" in content_type.lower():
            return ".json"
        elif "pdf" in content_type.lower():
            return ".pdf"
        elif "zip" in content_type.lower():
            return ".zip"
        elif "text/plain" in content_type.lower():
            return ".txt"
        elif "markdown" in content_type.lower():
            return ".md"

    url = str(response.url)
    if url:
        return Path(url).suffix.lower()

    return None

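A minimal async sketch (not part of the commit; the URL is a placeholder) showing how `get_filename` derives a deterministic, filesystem-safe name from a response when no Content-Disposition header is present:

```python
# Usage sketch (illustrative, not part of g4f/tools/files.py)
import asyncio
from aiohttp import ClientSession
from g4f.tools.files import get_filename

async def main():
    async with ClientSession() as session:
        # Placeholder URL; any fetchable page behaves the same way.
        async with session.get("https://example.com/docs/page.html") as response:
            # Without Content-Disposition this yields something like
            # "example.com docs_page.html <hash>.html"
            print(await get_filename(response))

asyncio.run(main())
```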
def read_links(html: str, base: str) -> set[str]:
    soup = BeautifulSoup(html, "html.parser")
    for selector in [
            "main",
            ".main-content-wrapper",
            ".main-content",
            ".emt-container-inner",
            ".content-wrapper",
            "#content",
            "#mainContent",
        ]:
        select = soup.select_one(selector)
        if select:
            soup = select
            break
    urls = []
    for link in soup.select("a"):
        if "rel" not in link.attrs or "nofollow" not in link.attrs["rel"]:
            url = link.attrs.get("href")
            if url and url.startswith("https://"):
                urls.append(url.split("#")[0])
    return set([urllib.parse.urljoin(base, link) for link in urls])

async def download_urls(
    bucket_dir: Path,
    urls: list[str],
    max_depth: int = 2,
    loaded_urls: set[str] = set(),
    lock: asyncio.Lock = None,
    delay: int = 3,
    group_size: int = 5,
    timeout: int = 10,
    proxy: Optional[str] = None
) -> list[str]:
    if lock is None:
        lock = asyncio.Lock()
    async with ClientSession(
        connector=get_connector(proxy=proxy),
        timeout=ClientTimeout(timeout)
    ) as session:
        async def download_url(url: str) -> str:
            try:
                async with session.get(url) as response:
                    response.raise_for_status()
                    filename = await get_filename(response)
                    if not filename:
                        print(f"Failed to get filename for {url}")
                        return None
                    newfiles = [filename]
                    if filename.endswith(".html") and max_depth > 0:
                        new_urls = read_links(await response.text(), str(response.url))
                        async with lock:
                            new_urls = [new_url for new_url in new_urls if new_url not in loaded_urls]
                            [loaded_urls.add(url) for url in new_urls]
                        if new_urls:
                            for i in range(0, len(new_urls), group_size):
                                newfiles += await download_urls(bucket_dir, new_urls[i:i + group_size], max_depth - 1, loaded_urls, lock, delay + 1)
                                await asyncio.sleep(delay)
                    if supports_filename(filename) and filename != DOWNLOADS_FILE:
                        target = bucket_dir / filename
                        with target.open("wb") as f:
                            async for chunk in response.content.iter_chunked(4096):
                                f.write(chunk)
                    return newfiles
            except (ClientError, asyncio.TimeoutError) as e:
                debug.log(f"Download failed: {e.__class__.__name__}: {e}")
                return None
        files = set()
        for results in await asyncio.gather(*[download_url(url) for url in urls]):
            if results:
                [files.add(url) for url in results]
        return files

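The downloader follows in-page links up to `max_depth`, processing them in small groups and increasing the delay on each recursion level. A sketch of driving it directly (not part of the commit; URL and bucket path are placeholders):

```python
# Usage sketch (illustrative, not part of g4f/tools/files.py)
import asyncio
from pathlib import Path
from g4f.tools.files import download_urls

bucket_dir = Path("./example_bucket")   # hypothetical bucket directory
bucket_dir.mkdir(parents=True, exist_ok=True)

# Fetches the page, follows https links up to depth 2 (with throttling),
# and returns the set of filenames stored in the bucket.
filenames = asyncio.run(download_urls(bucket_dir, ["https://example.com/"]))
print(filenames)
```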
def get_streaming(bucket_dir: str, delete_files = False, refine_chunks_with_spacy = False, event_stream: bool = False) -> Iterator[str]:
    bucket_dir = Path(bucket_dir)
    bucket_dir.mkdir(parents=True, exist_ok=True)
    try:
        download_file = bucket_dir / DOWNLOADS_FILE
        if download_file.exists():
            urls = []
            with download_file.open('r') as f:
                data = json.load(f)
            download_file.unlink()
            if isinstance(data, list):
                for item in data:
                    if "url" in item:
                        urls.append(item["url"])
            if urls:
                filenames = asyncio.run(download_urls(bucket_dir, urls))
                with open(os.path.join(bucket_dir, FILE_LIST), 'w') as f:
                    [f.write(f"{filename}\n") for filename in filenames if filename]

        if refine_chunks_with_spacy:
            size = 0
            for chunk in stream_read_parts_and_refine(bucket_dir, delete_files):
                if event_stream:
                    size += len(chunk)
                    yield f'data: {json.dumps({"action": "refine", "size": size})}\n\n'
                else:
                    yield chunk
        else:
            streaming = stream_read_files(bucket_dir, get_filenames(bucket_dir))
            streaming = cache_stream(streaming, bucket_dir)
            size = 0
            for chunk in streaming:
                if event_stream:
                    size += len(chunk)
                    yield f'data: {json.dumps({"action": "load", "size": size})}\n\n'
                else:
                    yield chunk
        files_txt = os.path.join(bucket_dir, FILE_LIST)
        if delete_files and os.path.exists(files_txt):
            for filename in get_filenames(bucket_dir):
                if os.path.exists(os.path.join(bucket_dir, filename)):
                    os.remove(os.path.join(bucket_dir, filename))
            os.remove(files_txt)
            if event_stream:
                yield f'data: {json.dumps({"action": "delete_files"})}\n\n'
        if event_stream:
            yield f'data: {json.dumps({"action": "done"})}\n\n'
    except Exception as e:
        if event_stream:
            yield f'data: {json.dumps({"error": {"message": str(e)}})}\n\n'
        raise e
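An end-to-end sketch of the bucket flow (not part of the commit; the bucket id and file are placeholders): drop files into a bucket directory, list them in `files.txt`, then stream the scraped text. Passing `event_stream=True` yields SSE-style progress events instead of the raw text, and `refine_chunks_with_spacy=True` streams the spaCy summary.

```python
# Usage sketch (illustrative, not part of g4f/tools/files.py)
from pathlib import Path
from g4f.tools.files import get_bucket_dir, get_streaming, FILE_LIST

bucket_dir = Path(get_bucket_dir("demo-bucket"))   # hypothetical bucket id
bucket_dir.mkdir(parents=True, exist_ok=True)

(bucket_dir / "notes.md").write_text("# Notes\nSome plain text to ingest.\n")
(bucket_dir / FILE_LIST).write_text("notes.md\n")

# First run also writes plain.cache so later reads can come from the cache.
print("".join(get_streaming(bucket_dir)))
```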
87  g4f/tools/run_tools.py  Normal file
@ -0,0 +1,87 @@
from __future__ import annotations

import re
import json
import asyncio
from typing import Optional, Callable, AsyncIterator

from ..typing import Messages
from ..providers.helper import filter_none
from ..client.helper import to_async_iterator
from .web_search import do_search, get_search_message
from .files import read_bucket, get_bucket_dir
from .. import debug

def validate_arguments(data: dict) -> dict:
    if "arguments" in data:
        if isinstance(data["arguments"], str):
            data["arguments"] = json.loads(data["arguments"])
        if not isinstance(data["arguments"], dict):
            raise ValueError("Tool function arguments must be a dictionary or a json string")
        else:
            return filter_none(**data["arguments"])
    else:
        return {}

async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls: Optional[list] = None, **kwargs):
    if tool_calls is not None:
        for tool in tool_calls:
            if tool.get("type") == "function":
                if tool.get("function", {}).get("name") == "search_tool":
                    tool["function"]["arguments"] = validate_arguments(tool["function"])
                    messages = messages.copy()
                    messages[-1]["content"] = await do_search(
                        messages[-1]["content"],
                        **tool["function"]["arguments"]
                    )
                elif tool.get("function", {}).get("name") == "continue":
                    last_line = messages[-1]["content"].strip().splitlines()[-1]
                    content = f"Continue writing the story after this line start with a plus sign if you begin a new word.\n{last_line}"
                    messages.append({"role": "user", "content": content})
    response = async_iter_callback(model=model, messages=messages, **kwargs)
    if not hasattr(response, "__aiter__"):
        response = to_async_iterator(response)
    async for chunk in response:
        yield chunk

def iter_run_tools(
    iter_callback: Callable,
    model: str,
    messages: Messages,
    provider: Optional[str] = None,
    tool_calls: Optional[list] = None,
    **kwargs
) -> AsyncIterator:
    if tool_calls is not None:
        for tool in tool_calls:
            if tool.get("type") == "function":
                if tool.get("function", {}).get("name") == "search_tool":
                    tool["function"]["arguments"] = validate_arguments(tool["function"])
                    messages[-1]["content"] = get_search_message(
                        messages[-1]["content"],
                        raise_search_exceptions=True,
                        **tool["function"]["arguments"]
                    )
                elif tool.get("function", {}).get("name") == "safe_search_tool":
                    tool["function"]["arguments"] = validate_arguments(tool["function"])
                    try:
                        messages[-1]["content"] = asyncio.run(do_search(messages[-1]["content"], **tool["function"]["arguments"]))
                    except Exception as e:
                        debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
                        # Enable provider native web search
                        kwargs["web_search"] = True
                elif tool.get("function", {}).get("name") == "continue_tool":
                    if provider not in ("OpenaiAccount", "HuggingFace"):
                        last_line = messages[-1]["content"].strip().splitlines()[-1]
                        content = f"continue after this line:\n{last_line}"
                        messages.append({"role": "user", "content": content})
                    else:
                        # Enable provider native continue
                        if "action" not in kwargs:
                            kwargs["action"] = "continue"
                elif tool.get("function", {}).get("name") == "bucket_tool":
                    def on_bucket(match):
                        return "".join(read_bucket(get_bucket_dir(match.group(1))))
                    messages[-1]["content"] = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, messages[-1]["content"])
                    print(messages[-1])
    return iter_callback(model=model, messages=messages, provider=provider, **kwargs)
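A sketch of the `tool_calls` shape these helpers expect (not part of the commit; the callback and model name are stand-ins, and the bucket example assumes `demo-bucket` from the earlier sketch has already been loaded once so its `plain.cache` exists):

```python
# Usage sketch (illustrative, not part of g4f/tools/run_tools.py)
from g4f.tools.run_tools import iter_run_tools

def fake_provider(model, messages, provider=None, **kwargs):
    # Stand-in for a provider's create function; just echoes the final prompt.
    yield messages[-1]["content"][:80]

messages = [{"role": "user", "content": 'Summarize {"bucket_id":"demo-bucket"} please'}]
tool_calls = [
    {"type": "function", "function": {"name": "bucket_tool"}},
    # Other supported names: "search_tool", "safe_search_tool", "continue_tool"
]

for chunk in iter_run_tools(fake_provider, "any-model", messages, tool_calls=tool_calls):
    print(chunk)
```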
@ -1,6 +1,10 @@
from __future__ import annotations

from aiohttp import ClientSession, ClientTimeout, ClientError
import json
import hashlib
from pathlib import Path
from collections import Counter
try:
    from duckduckgo_search import DDGS
    from duckduckgo_search.exceptions import DuckDuckGoSearchException

@ -8,8 +12,15 @@ try:
    has_requirements = True
except ImportError:
    has_requirements = False
from .errors import MissingRequirementsError
from . import debug
try:
    import spacy
    has_spacy = True
except:
    has_spacy = False
from typing import Iterator
from ..cookies import get_cookies_dir
from ..errors import MissingRequirementsError
from .. import debug

import asyncio

@ -52,7 +63,7 @@ class SearchResultEntry():
    def set_text(self, text: str):
        self.text = text

def scrape_text(html: str, max_words: int = None) -> str:
def scrape_text(html: str, max_words: int = None) -> Iterator[str]:
    soup = BeautifulSoup(html, "html.parser")
    for selector in [
            "main",

@ -72,14 +83,10 @@ def scrape_text(html: str, max_words: int = None) -> str:
        select = soup.select_one(remove)
        if select:
            select.extract()
    clean_text = ""
    for paragraph in soup.select("p, h1, h2, h3, h4, h5, h6"):
        text = paragraph.get_text()
        for line in text.splitlines():
            words = []
            for word in line.replace("\t", " ").split(" "):
                if word:
                    words.append(word)

    for paragraph in soup.select("p, table, ul, h1, h2, h3, h4, h5, h6"):
        for line in paragraph.text.splitlines():
            words = [word for word in line.replace("\t", " ").split(" ") if word]
            count = len(words)
            if not count:
                continue

|
|||
max_words -= count
|
||||
if max_words <= 0:
|
||||
break
|
||||
if clean_text:
|
||||
clean_text += "\n"
|
||||
clean_text += " ".join(words)
|
||||
|
||||
return clean_text
|
||||
yield " ".join(words) + "\n"
|
||||
|
||||
async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = None) -> str:
|
||||
try:
|
||||
bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "fetch_and_scrape"
|
||||
bucket_dir.mkdir(parents=True, exist_ok=True)
|
||||
md5_hash = hashlib.md5(url.encode()).hexdigest()
|
||||
cache_file = bucket_dir / f"{url.split('/')[3]}.{md5_hash}.txt"
|
||||
if cache_file.exists():
|
||||
return cache_file.read_text()
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
html = await response.text()
|
||||
return scrape_text(html, max_words)
|
||||
text = "".join(scrape_text(html, max_words))
|
||||
with open(cache_file, "w") as f:
|
||||
f.write(text)
|
||||
return text
|
||||
except ClientError:
|
||||
return
|
||||
|
||||
|
|
@ -113,7 +125,7 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backen
            safesearch="moderate",
            timelimit="y",
            max_results=max_results,
            backend=backend, # Changed from 'api' to 'auto'
            backend=backend,
        ):
            results.append(SearchResultEntry(
                result["title"],

@ -149,8 +161,20 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backen

async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_INSTRUCTIONS, **kwargs) -> str:
    if query is None:
        query = prompt
    search_results = await search(query, **kwargs)
        query = spacy_get_keywords(prompt)
    json_bytes = json.dumps({"query": query, **kwargs}, sort_keys=True).encode()
    md5_hash = hashlib.md5(json_bytes).hexdigest()
    bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "web_search"
    bucket_dir.mkdir(parents=True, exist_ok=True)
    cache_file = bucket_dir / f"{query[:20]}.{md5_hash}.txt"
    if cache_file.exists():
        with open(cache_file, "r") as f:
            search_results = f.read()
    else:
        search_results = await search(query, **kwargs)
        with open(cache_file, "w") as f:
            f.write(str(search_results))

    new_prompt = f"""
{search_results}

@ -170,3 +194,37 @@ def get_search_message(prompt: str, raise_search_exceptions=False, **kwargs) ->
            raise e
        debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
        return prompt

def spacy_get_keywords(text: str):
    if not has_spacy:
        return text

    # Load the spaCy language model
    nlp = spacy.load("en_core_web_sm")

    # Process the query
    doc = nlp(text)

    # Extract keywords based on POS and named entities
    keywords = []
    for token in doc:
        # Filter for nouns, proper nouns, and adjectives
        if token.pos_ in {"NOUN", "PROPN", "ADJ"} and not token.is_stop:
            keywords.append(token.lemma_)

    # Add named entities as keywords
    for ent in doc.ents:
        keywords.append(ent.text)

    # Remove duplicates and print keywords
    keywords = list(set(keywords))
    #print("Keyword:", keywords)

    #keyword_freq = Counter(keywords)
    #keywords = keyword_freq.most_common()
    #print("Keyword Frequencies:", keywords)

    keywords = [chunk.text for chunk in doc.noun_chunks if not chunk.root.is_stop]
    #print("Phrases:", keywords)

    return keywords
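As committed, the function ends up returning the noun-chunk phrases (the earlier lemma/entity list is overwritten by the last assignment), or the raw text when spaCy is unavailable. A quick sketch (not part of the commit; requires the `en_core_web_sm` model, and the printed phrases are only indicative):

```python
# Usage sketch (illustrative, not part of g4f/tools/web_search.py)
from g4f.tools.web_search import spacy_get_keywords

query = "What are the best open source Python libraries for scraping PDF files?"
print(spacy_get_keywords(query))
# With spaCy installed: a list of noun phrases, e.g.
# ['the best open source Python libraries', 'PDF files']
# Without spaCy: the original query string is returned unchanged.
```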
257  g4f/webdriver.py
@ -1,257 +0,0 @@
from __future__ import annotations

try:
    from platformdirs import user_config_dir
    from undetected_chromedriver import Chrome, ChromeOptions, find_chrome_executable
    from selenium.webdriver.remote.webdriver import WebDriver
    from selenium.webdriver.remote.webelement import WebElement
    from selenium.webdriver.common.by import By
    from selenium.webdriver.support.ui import WebDriverWait
    from selenium.webdriver.support import expected_conditions as EC
    from selenium.webdriver.common.keys import Keys
    from selenium.common.exceptions import NoSuchElementException
    has_requirements = True
except ImportError:
    from typing import Type as WebDriver
    has_requirements = False

import time
from shutil import which
from os import path
from os import access, R_OK
from .typing import Cookies
from .errors import MissingRequirementsError
from . import debug

try:
    from pyvirtualdisplay import Display
    has_pyvirtualdisplay = True
except ImportError:
    has_pyvirtualdisplay = False

try:
    from undetected_chromedriver import Chrome as _Chrome, ChromeOptions
    from seleniumwire.webdriver import InspectRequestsMixin, DriverCommonMixin

    class Chrome(InspectRequestsMixin, DriverCommonMixin, _Chrome):
        def __init__(self, *args, options=None, seleniumwire_options={}, **kwargs):
            if options is None:
                options = ChromeOptions()
            config = self._setup_backend(seleniumwire_options)
            options.add_argument(f"--proxy-server={config['proxy']['httpProxy']}")
            options.add_argument("--proxy-bypass-list=<-loopback>")
            options.add_argument("--ignore-certificate-errors")
            super().__init__(*args, options=options, **kwargs)
    has_seleniumwire = True
except:
    has_seleniumwire = False

def get_browser(
    user_data_dir: str = None,
    headless: bool = False,
    proxy: str = None,
    options: ChromeOptions = None
) -> WebDriver:
    """
    Creates and returns a Chrome WebDriver with specified options.

    Args:
        user_data_dir (str, optional): Directory for user data. If None, uses default directory.
        headless (bool, optional): Whether to run the browser in headless mode. Defaults to False.
        proxy (str, optional): Proxy settings for the browser. Defaults to None.
        options (ChromeOptions, optional): ChromeOptions object with specific browser options. Defaults to None.

    Returns:
        WebDriver: An instance of WebDriver configured with the specified options.
    """
    if not has_requirements:
        raise MissingRequirementsError('Install Webdriver packages | pip install -U g4f[webdriver]')
    browser = find_chrome_executable()
    if browser is None:
        raise MissingRequirementsError('Install "Google Chrome" browser')
    if user_data_dir is None:
        user_data_dir = user_config_dir("g4f")
    if user_data_dir and debug.logging:
        print("Open browser with config dir:", user_data_dir)
    if not options:
        options = ChromeOptions()
    if proxy:
        options.add_argument(f'--proxy-server={proxy}')
    # Check for system driver in docker
    driver = which('chromedriver') or '/usr/bin/chromedriver'
    if not path.isfile(driver) or not access(driver, R_OK):
        driver = None
    return Chrome(
        options=options,
        user_data_dir=user_data_dir,
        driver_executable_path=driver,
        browser_executable_path=browser,
        headless=headless,
        patcher_force_close=True
    )

def get_driver_cookies(driver: WebDriver) -> Cookies:
    """
    Retrieves cookies from the specified WebDriver.

    Args:
        driver (WebDriver): The WebDriver instance from which to retrieve cookies.

    Returns:
        dict: A dictionary containing cookies with their names as keys and values as cookie values.
    """
    return {cookie["name"]: cookie["value"] for cookie in driver.get_cookies()}

def bypass_cloudflare(driver: WebDriver, url: str, timeout: int) -> None:
    """
    Attempts to bypass Cloudflare protection when accessing a URL using the provided WebDriver.

    Args:
        driver (WebDriver): The WebDriver to use for accessing the URL.
        url (str): The URL to access.
        timeout (int): Time in seconds to wait for the page to load.

    Raises:
        Exception: If there is an error while bypassing Cloudflare or loading the page.
    """
    driver.get(url)
    if driver.find_element(By.TAG_NAME, "body").get_attribute("class") == "no-js":
        if debug.logging:
            print("Cloudflare protection detected:", url)

        # Open website in a new tab
        element = driver.find_element(By.ID, "challenge-body-text")
        driver.execute_script(f"""
            arguments[0].addEventListener('click', () => {{
                window.open(arguments[1]);
            }});
        """, element, url)
        element.click()
        time.sleep(5)

        # Switch to the new tab and close the old tab
        original_window = driver.current_window_handle
        for window_handle in driver.window_handles:
            if window_handle != original_window:
                driver.close()
                driver.switch_to.window(window_handle)
                break

        # Click on the challenge button in the iframe
        try:
            driver.switch_to.frame(driver.find_element(By.CSS_SELECTOR, "#turnstile-wrapper iframe"))
            WebDriverWait(driver, 5).until(
                EC.presence_of_element_located((By.CSS_SELECTOR, "#challenge-stage input"))
            ).click()
        except NoSuchElementException:
            ...
        except Exception as e:
            if debug.logging:
                print(f"Error bypassing Cloudflare: {str(e).splitlines()[0]}")
            #driver.switch_to.default_content()
        driver.switch_to.window(window_handle)
        driver.execute_script("document.href = document.href;")
    WebDriverWait(driver, timeout).until(
        EC.presence_of_element_located((By.CSS_SELECTOR, "body:not(.no-js)"))
    )

class WebDriverSession:
    """
    Manages a Selenium WebDriver session, including handling of virtual displays and proxies.
    """

    def __init__(
        self,
        webdriver: WebDriver = None,
        user_data_dir: str = None,
        headless: bool = False,
        virtual_display: bool = False,
        proxy: str = None,
        options: ChromeOptions = None
    ):
        """
        Initializes a new instance of the WebDriverSession.

        Args:
            webdriver (WebDriver, optional): A WebDriver instance for the session. Defaults to None.
            user_data_dir (str, optional): Directory for user data. Defaults to None.
            headless (bool, optional): Whether to run the browser in headless mode. Defaults to False.
            virtual_display (bool, optional): Whether to use a virtual display. Defaults to False.
            proxy (str, optional): Proxy settings for the browser. Defaults to None.
            options (ChromeOptions, optional): ChromeOptions for the browser. Defaults to None.
        """
        self.webdriver = webdriver
        self.user_data_dir = user_data_dir
        self.headless = headless
        self.virtual_display = Display(size=(1920, 1080)) if has_pyvirtualdisplay and virtual_display else None
        self.proxy = proxy
        self.options = options
        self.default_driver = None

    def reopen(
        self,
        user_data_dir: str = None,
        headless: bool = False,
        virtual_display: bool = False
    ) -> WebDriver:
        """
        Reopens the WebDriver session with new settings.

        Args:
            user_data_dir (str, optional): Directory for user data. Defaults to current value.
            headless (bool, optional): Whether to run the browser in headless mode. Defaults to current value.
            virtual_display (bool, optional): Whether to use a virtual display. Defaults to current value.

        Returns:
            WebDriver: The reopened WebDriver instance.
        """
        user_data_dir = user_data_dir or self.user_data_dir
        if self.default_driver:
            self.default_driver.quit()
        if not virtual_display and self.virtual_display:
            self.virtual_display.stop()
            self.virtual_display = None
        self.default_driver = get_browser(user_data_dir, headless, self.proxy)
        return self.default_driver

    def __enter__(self) -> WebDriver:
        """
        Context management method for entering a session. Initializes and returns a WebDriver instance.

        Returns:
            WebDriver: An instance of WebDriver for this session.
        """
        if self.webdriver:
            return self.webdriver
        if self.virtual_display:
            self.virtual_display.start()
        self.default_driver = get_browser(self.user_data_dir, self.headless, self.proxy, self.options)
        return self.default_driver

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Context management method for exiting a session. Closes and quits the WebDriver.

        Args:
            exc_type: Exception type.
            exc_val: Exception value.
            exc_tb: Exception traceback.

        Note:
            Closes the WebDriver and stops the virtual display if used.
        """
        if self.default_driver:
            try:
                self.default_driver.close()
            except Exception as e:
                if debug.logging:
                    print(f"Error closing WebDriver: {str(e).splitlines()[0]}")
            finally:
                self.default_driver.quit()
        if self.virtual_display:
            self.virtual_display.stop()

def element_send_text(element: WebElement, text: str) -> None:
    script = "arguments[0].innerText = arguments[1];"
    element.parent.execute_script(script, element, text)
    element.send_keys(Keys.ENTER)
10  setup.py
@ -84,6 +84,16 @@ EXTRA_REQUIRE = {
    ],
    "local": [
        "gpt4all"
    ],
    "files": [
        "spacy",
        "filesplit",
        "beautifulsoup4",
        "pypdf2",
        "docx",
        "odfpy",
        "ebooklib",
        "openpyxl",
    ]
}