Add files support, scrape and refine your data

Remove Webdriver usages
Add continue messages for other providers
This commit is contained in:
Heiner Lohaus 2025-01-01 04:20:02 +01:00
parent 90360ccfa6
commit 7893a0835e
33 changed files with 1155 additions and 559 deletions
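The headline feature is a file "bucket" API: upload documents into a bucket, then stream their contents back as plain text or spaCy-refined chunks. A minimal usage sketch against the new endpoints added below (host, port, bucket id, and filename are illustrative assumptions, not from this commit):

    # Sketch: exercising the new /v1/files endpoints (see the api.py diff below).
    import requests

    base = "http://localhost:1337/v1/files"  # assumed local g4f API server
    bucket_id = "my-bucket"                  # any id; the server creates the directory

    # Upload one or more files (multipart field name "files", per the route)
    with open("notes.md", "rb") as f:
        resp = requests.post(f"{base}/{bucket_id}", files=[("files", f)])
    print(resp.json())  # {"bucket_id": ..., "url": ..., "files": [...]}

    # Stream the extracted (optionally refined) text back
    resp = requests.get(f"{base}/{bucket_id}",
                        params={"refine_chunks_with_spacy": "true"}, stream=True)
    for chunk in resp.iter_content(chunk_size=None):
        print(chunk.decode(), end="")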

View file

@ -5,7 +5,7 @@ import asyncio
from unittest.mock import MagicMock
from g4f.errors import MissingRequirementsError
try:
from g4f.gui.server.backend import Backend_Api
from g4f.gui.server.backend_api import Backend_Api
has_requirements = True
except:
has_requirements = False

View file

@ -14,7 +14,7 @@ from ..typing import AsyncResult, Messages, ImagesType
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..image import ImageResponse, to_data_uri
from ..cookies import get_cookies_dir
from ..web_search import get_search_message
from ..tools.web_search import get_search_message
from .helper import format_prompt
from .. import debug

View file

@ -5,8 +5,10 @@ import json
from ..typing import AsyncResult, Messages, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop
from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies, DEFAULT_HEADERS, has_nodriver, has_curl_cffi
from ..errors import ResponseStatusError
from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
from ..requests import DEFAULT_HEADERS, has_nodriver, has_curl_cffi
from ..providers.response import FinishReason
from ..errors import ResponseStatusError, ModelNotFoundError
class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
label = "Cloudflare AI"
@ -70,7 +72,10 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
else:
cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}}
model = cls.get_model(model)
try:
model = cls.get_model(model)
except ModelNotFoundError:
pass
data = {
"messages": messages,
"lora": None,
@ -89,6 +94,7 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
except ResponseStatusError:
cls._args = None
raise
reason = None
async for line in response.iter_lines():
if line.startswith(b'data: '):
if line == b'data: [DONE]':
@ -97,5 +103,10 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
content = json.loads(line[6:].decode())
if content.get("response") and content.get("response") != '</s>':
yield content['response']
reason = "max_tokens"
elif content.get("response") == '':
reason = "stop"
except Exception:
continue
continue
if reason is not None:
yield FinishReason(reason)
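Note: the provider now also reports why generation stopped. A sketch of consuming it (model name and message are illustrative; the reason attribute follows the FinishReason class used above):

    # Sketch: text chunks stream first, then a trailing FinishReason.
    import asyncio
    from g4f.Provider import Cloudflare
    from g4f.providers.response import FinishReason

    async def main():
        messages = [{"role": "user", "content": "Hi"}]
        async for chunk in Cloudflare.create_async_generator("llama-3.1-8b", messages):
            if isinstance(chunk, FinishReason):
                print("\nfinish:", chunk.reason)  # "stop" or "max_tokens"
            else:
                print(chunk, end="")

    asyncio.run(main())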

View file

@ -15,7 +15,6 @@ except ImportError:
from ..helper import get_connector
from ...errors import MissingRequirementsError, RateLimitError
from ...webdriver import WebDriver, get_driver_cookies, get_browser
BING_URL = "https://www.bing.com"
TIMEOUT_LOGIN = 1200
@ -31,39 +30,6 @@ BAD_IMAGES = [
"https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg",
]
def wait_for_login(driver: WebDriver, timeout: int = TIMEOUT_LOGIN) -> None:
"""
Waits for the user to log in within a given timeout period.
Args:
driver (WebDriver): Webdriver for browser automation.
timeout (int): Maximum waiting time in seconds.
Raises:
RuntimeError: If the login process exceeds the timeout.
"""
driver.get(f"{BING_URL}/")
start_time = time.time()
while not driver.get_cookie("_U"):
if time.time() - start_time > timeout:
raise RuntimeError("Timeout error")
time.sleep(0.5)
def get_cookies_from_browser(proxy: str = None) -> dict[str, str]:
"""
Retrieves cookies from the browser using webdriver.
Args:
proxy (str, optional): Proxy configuration.
Returns:
dict[str, str]: Retrieved cookies.
"""
with get_browser(proxy=proxy) as driver:
wait_for_login(driver)
time.sleep(1)
return get_driver_cookies(driver)
def create_session(cookies: Dict[str, str], proxy: str = None, connector: BaseConnector = None) -> ClientSession:
"""
Creates a new client session with specified cookies and headers.

View file

@ -102,22 +102,15 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
if "config" in model_data and "model_type" in model_data["config"]:
model_type = model_data["config"]["model_type"]
debug.log(f"Model type: {model_type}")
if model_type in ("gpt2", "gpt_neo", "gemma", "gemma2"):
inputs = format_prompt(messages, do_continue=do_continue)
elif model_type in ("mistral"):
inputs = format_prompt_mistral(messages, do_continue)
elif "config" in model_data and "tokenizer_config" in model_data["config"] and "eos_token" in model_data["config"]["tokenizer_config"]:
eos_token = model_data["config"]["tokenizer_config"]["eos_token"]
if eos_token in ("<|endoftext|>", "<eos>", "</s>"):
inputs = format_prompt_custom(messages, eos_token, do_continue)
elif eos_token == "<|im_end|>":
inputs = format_prompt_qwen(messages, do_continue)
elif eos_token == "<|eot_id|>":
inputs = format_prompt_llama(messages, do_continue)
inputs = get_inputs(messages, model_data, model_type, do_continue)
debug.log(f"Inputs len: {len(inputs)}")
if len(inputs) > 4096:
if len(messages) > 6:
messages = messages[:3] + messages[-3:]
else:
inputs = format_prompt(messages, do_continue=do_continue)
else:
inputs = format_prompt(messages, do_continue=do_continue)
messages = [m for m in messages if m["role"] == "system"] + [messages[-1]]
inputs = get_inputs(messages, model_data, model_type, do_continue)
debug.log(f"New len: {len(inputs)}")
if model_type == "gpt2" and max_new_tokens >= 1024:
params["max_new_tokens"] = 512
payload = {"inputs": inputs, "parameters": params, "stream": stream}
@ -187,4 +180,23 @@ def format_prompt_custom(messages: Messages, end_token: str = "</s>", do_continu
]) + ("" if do_continue else "<|assistant|>\n")
if do_continue:
return prompt[:-len(end_token + "\n")]
return prompt
return prompt
def get_inputs(messages: Messages, model_data: dict, model_type: str, do_continue: bool = False) -> str:
if model_type in ("gpt2", "gpt_neo", "gemma", "gemma2"):
inputs = format_prompt(messages, do_continue=do_continue)
elif model_type in ("mistral"):
inputs = format_prompt_mistral(messages, do_continue)
elif "config" in model_data and "tokenizer_config" in model_data["config"] and "eos_token" in model_data["config"]["tokenizer_config"]:
eos_token = model_data["config"]["tokenizer_config"]["eos_token"]
if eos_token in ("<|endoftext|>", "<eos>", "</s>"):
inputs = format_prompt_custom(messages, eos_token, do_continue)
elif eos_token == "<|im_end|>":
inputs = format_prompt_qwen(messages, do_continue)
elif eos_token == "<|eot_id|>":
inputs = format_prompt_llama(messages, do_continue)
else:
inputs = format_prompt(messages, do_continue=do_continue)
else:
inputs = format_prompt(messages, do_continue=do_continue)
return inputs
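A sketch of how the extracted helper resolves a prompt template (the model_data dict is a made-up minimal example):

    # Sketch: the eos_token in the tokenizer config selects the prompt format.
    model_data = {"config": {"tokenizer_config": {"eos_token": "<|im_end|>"}}}
    messages = [{"role": "user", "content": "Hi"}]
    inputs = get_inputs(messages, model_data, model_type="qwen2")
    # -> routed to format_prompt_qwen via the "<|im_end|>" eos_token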

View file

@ -404,7 +404,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
data["conversation_id"] = conversation.conversation_id
debug.log(f"OpenaiChat: Use conversation: {conversation.conversation_id}")
if action != "continue":
data["parent_message_id"] = conversation.parent_message_id
data["parent_message_id"] = getattr(conversation, "parent_message_id", conversation.message_id)
conversation.parent_message_id = None
messages = messages if conversation_id is None else [messages[-1]]
data["messages"] = cls.create_messages(messages, image_requests, ["search"] if web_search else None)
@ -604,7 +604,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
"api_key": cls._api_key,
"proof_token": RequestConfig.proof_token,
"cookies": RequestConfig.cookies,
"headers": RequestConfig.headers
})
@classmethod
@ -636,6 +635,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
page = await browser.get(cls.url)
user_agent = await page.evaluate("window.navigator.userAgent")
await page.select("#prompt-textarea", 240)
await page.evaluate("document.getElementById('prompt-textarea').innerText = 'Hello'")
await page.evaluate("document.querySelector('[data-testid=\"send-button\"]').click()")
while True:
if cls._api_key is not None:
break

View file

@ -5,7 +5,6 @@ import time
from ...typing import CreateResult, Messages
from ..base_provider import AbstractProvider
from ..helper import format_prompt
from ...webdriver import WebDriver, WebDriverSession, element_send_text
models = {
"meta-llama/Llama-2-7b-chat-hf": {"name": "Llama-2-7b"},
@ -22,7 +21,7 @@ models = {
class Poe(AbstractProvider):
url = "https://poe.com"
working = True
working = False
needs_auth = True
supports_stream = True

View file

@ -5,7 +5,6 @@ import time
from ...typing import CreateResult, Messages
from ..base_provider import AbstractProvider
from ..helper import format_prompt
from ...webdriver import WebDriver, WebDriverSession, element_send_text
models = {
"theb-ai": "TheB.AI",
@ -34,7 +33,7 @@ models = {
class Theb(AbstractProvider):
label = "TheB.AI"
url = "https://beta.theb.ai"
working = True
working = False
supports_stream = True
models = models.keys()

View file

@ -4,8 +4,6 @@ from aiohttp import ClientSession
from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider
from ...requests import get_args_from_browser
from ...webdriver import WebDriver
class Aura(AsyncGeneratorProvider):
url = "https://openchat.team"
@ -19,7 +17,7 @@ class Aura(AsyncGeneratorProvider):
proxy: str = None,
temperature: float = 0.5,
max_tokens: int = 8192,
webdriver: WebDriver = None,
webdriver = None,
**kwargs
) -> AsyncResult:
args = get_args_from_browser(cls.url, webdriver, proxy)

View file

@ -5,7 +5,6 @@ import time, json
from ...typing import CreateResult, Messages
from ..base_provider import AbstractProvider
from ..helper import format_prompt
from ...webdriver import WebDriver, WebDriverSession, bypass_cloudflare
class MyShell(AbstractProvider):
url = "https://app.myshell.ai/chat"
@ -21,7 +20,7 @@ class MyShell(AbstractProvider):
stream: bool,
proxy: str = None,
timeout: int = 120,
webdriver: WebDriver = None,
webdriver = None,
**kwargs
) -> CreateResult:
with WebDriverSession(webdriver, "", proxy=proxy) as driver:

View file

@ -12,7 +12,6 @@ except ImportError:
from ...typing import CreateResult, Messages
from ..base_provider import AbstractProvider
from ..helper import format_prompt
from ...webdriver import WebDriver, WebDriverSession, element_send_text
class PerplexityAi(AbstractProvider):
url = "https://www.perplexity.ai"
@ -28,7 +27,7 @@ class PerplexityAi(AbstractProvider):
stream: bool,
proxy: str = None,
timeout: int = 120,
webdriver: WebDriver = None,
webdriver = None,
virtual_display: bool = True,
copilot: bool = False,
**kwargs

View file

@ -6,7 +6,6 @@ from urllib.parse import quote
from ...typing import CreateResult, Messages
from ..base_provider import AbstractProvider
from ..helper import format_prompt
from ...webdriver import WebDriver, WebDriverSession
class Phind(AbstractProvider):
url = "https://www.phind.com"
@ -22,7 +21,7 @@ class Phind(AbstractProvider):
stream: bool,
proxy: str = None,
timeout: int = 120,
webdriver: WebDriver = None,
webdriver = None,
creative_mode: bool = None,
**kwargs
) -> CreateResult:

View file

@ -4,7 +4,6 @@ import time, json, time
from ...typing import CreateResult, Messages
from ..base_provider import AbstractProvider
from ...webdriver import WebDriver, WebDriverSession
class TalkAi(AbstractProvider):
url = "https://talkai.info"
@ -19,7 +18,7 @@ class TalkAi(AbstractProvider):
messages: Messages,
stream: bool,
proxy: str = None,
webdriver: WebDriver = None,
webdriver = None,
**kwargs
) -> CreateResult:
with WebDriverSession(webdriver, "", virtual_display=True, proxy=proxy) as driver:

View file

@ -41,11 +41,12 @@ from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthErr
from g4f.cookies import read_cookie_files, get_cookies_dir
from g4f.Provider import ProviderType, ProviderUtils, __providers__
from g4f.gui import get_gui_app
from g4f.tools.files import supports_filename, get_streaming
from .stubs import (
ChatCompletionsConfig, ImageGenerationConfig,
ProviderResponseModel, ModelResponseModel,
ErrorResponseModel, ProviderResponseDetailModel,
FileResponseModel, Annotated
FileResponseModel, UploadResponseModel, Annotated
)
logger = logging.getLogger(__name__)
@ -424,6 +425,40 @@ class Api:
read_cookie_files()
return response_data
@self.app.get("/v1/files/{bucket_id}", responses={
HTTP_200_OK: {"content": {
"text/event-stream": {"schema": {"type": "string"}},
"text/plain": {"schema": {"type": "string"}},
}},
HTTP_404_NOT_FOUND: {"model": ErrorResponseModel},
})
def read_files(request: Request, bucket_id: str, delete_files: bool = True, refine_chunks_with_spacy: bool = False):
bucket_dir = os.path.join(get_cookies_dir(), bucket_id)
event_stream = "text/event-stream" in request.headers.get("accept", "")
if not os.path.isdir(bucket_dir):
return ErrorResponse.from_message("Bucket dir not found", 404)
return StreamingResponse(get_streaming(bucket_dir, delete_files, refine_chunks_with_spacy, event_stream), media_type="text/plain")
@self.app.post("/v1/files/{bucket_id}", responses={
HTTP_200_OK: {"model": UploadResponseModel}
})
def upload_files(bucket_id: str, files: List[UploadFile]):
bucket_dir = os.path.join(get_cookies_dir(), bucket_id)
os.makedirs(bucket_dir, exist_ok=True)
filenames = []
for file in files:
try:
filename = os.path.basename(file.filename)
if file and supports_filename(filename):
with open(os.path.join(bucket_dir, filename), 'wb') as f:
shutil.copyfileobj(file.file, f)
filenames.append(filename)
finally:
file.file.close()
with open(os.path.join(bucket_dir, "files.txt"), 'w') as f:
[f.write(f"{filename}\n") for filename in filenames]
return {"bucket_id": bucket_id, "url": f"/v1/files/{bucket_id}", "files": filenames}
@self.app.get("/v1/synthesize/{provider}", responses={
HTTP_200_OK: {"content": {"audio/*": {}}},
HTTP_404_NOT_FOUND: {"model": ErrorResponseModel},

View file

@ -66,6 +66,10 @@ class ModelResponseModel(BaseModel):
created: int
owned_by: Optional[str]
class UploadResponseModel(BaseModel):
bucket_id: str
url: str
class ErrorResponseModel(BaseModel):
error: ErrorResponseMessageModel
model: Optional[str] = None

View file

@ -7,7 +7,6 @@ import string
import asyncio
import aiohttp
import base64
import json
from typing import Union, AsyncIterator, Iterator, Coroutine, Optional
from ..image import ImageResponse, copy_images, images_dir
@ -17,13 +16,13 @@ from ..providers.response import ResponseType, FinishReason, BaseConversation, S
from ..errors import NoImageResponseError
from ..providers.retry_provider import IterListProvider
from ..providers.asyncio import to_sync_generator, async_generator_to_list
from ..web_search import get_search_message, do_search
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
from ..tools.run_tools import async_iter_run_tools, iter_run_tools
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .image_models import ImageModels
from .types import IterResponse, ImageProvider, Client as BaseClient
from .service import get_model_and_provider, convert_to_provider
from .helper import find_stop, filter_json, filter_none, safe_aclose, to_async_iterator
from .helper import find_stop, filter_json, filter_none, safe_aclose
from .. import debug
ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
@ -38,47 +37,6 @@ except NameError:
except StopAsyncIteration:
raise StopIteration
def validate_arguments(data: dict) -> dict:
if "arguments" in data:
if isinstance(data["arguments"], str):
data["arguments"] = json.loads(data["arguments"])
if not isinstance(data["arguments"], dict):
raise ValueError("Tool function arguments must be a dictionary or a json string")
else:
return filter_none(**data["arguments"])
else:
return {}
async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls: Optional[list] = None, **kwargs):
if tool_calls is not None:
for tool in tool_calls:
if tool.get("type") == "function":
if tool.get("function", {}).get("name") == "search_tool":
tool["function"]["arguments"] = validate_arguments(tool["function"])
messages = messages.copy()
messages[-1]["content"] = await do_search(
messages[-1]["content"],
**tool["function"]["arguments"]
)
response = async_iter_callback(model=model, messages=messages, **kwargs)
if not hasattr(response, "__aiter__"):
response = to_async_iterator(response)
async for chunk in response:
yield chunk
def iter_run_tools(iter_callback, model, messages, tool_calls: Optional[list] = None, **kwargs):
if tool_calls is not None:
for tool in tool_calls:
if tool.get("type") == "function":
if tool.get("function", {}).get("name") == "search_tool":
tool["function"]["arguments"] = validate_arguments(tool["function"])
messages[-1]["content"] = get_search_message(
messages[-1]["content"],
raise_search_exceptions=True,
**tool["function"]["arguments"]
)
return iter_callback(model=model, messages=messages, **kwargs)
# Synchronous iter_response function
def iter_response(
response: Union[Iterator[Union[str, ResponseType]]],
@ -131,7 +89,8 @@ def iter_response(
break
idx += 1
if usage is None:
usage = Usage(prompt_tokens=0, completion_tokens=idx, total_tokens=idx).get_dict()
finish_reason = "stop" if finish_reason is None else finish_reason
if stream:

View file

@ -180,7 +180,7 @@ def read_cookie_files(dirPath: str = None):
except json.JSONDecodeError:
# Error: not a json file!
continue
if not isinstance(cookieFile, list):
if not isinstance(cookieFile, list) or not isinstance(cookieFile[0], dict) or "domain" not in cookieFile[0]:
continue
debug.log(f"Read cookie file: {path}")
new_cookies = {}
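# The stricter check above accepts browser-export style cookie files only, i.e.
# a JSON list of objects that at least carry a domain (illustrative values):
# [{"domain": ".example.com", "name": "session", "value": "abc123"}]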

View file

@ -1,9 +1,9 @@
from ..errors import MissingRequirementsError
try:
from .server.app import app
from .server.website import Website
from .server.backend import Backend_Api
from .server.backend_api import Backend_Api
from .server.app import create_app
import_error = None
except ImportError as e:
import_error = e
@ -11,6 +11,7 @@ except ImportError as e:
def get_gui_app():
if import_error is not None:
raise MissingRequirementsError(f'Install "gui" requirements | pip install -U g4f[gui]\n{import_error}')
app = create_app()
site = Website(app)
for route in site.routes:
@ -36,7 +37,7 @@ def run_gui(host: str = '0.0.0.0', port: int = 8080, debug: bool = False) -> Non
'debug': debug
}
get_gui_app()
app = get_gui_app()
print(f"Running on port {config['port']}")
app.run(**config)

View file

@ -136,6 +136,11 @@
<input type="checkbox" id="download_images" checked/>
<label for="download_images" class="toogle" title="Download and save generated images to /generated_images"></label>
</div>
<div class="field">
<span class="label">Refine files with spaCy</span>
<input type="checkbox" id="refine" checked/>
<label for="refine" class="toogle" title=""></label>
</div>
<div class="field box">
<label for="message-input-height" class="label" title="">Input max. height</label>
<input type="number" id="message-input-height" value="200"/>
@ -258,7 +263,7 @@
<i class="fa-solid fa-camera"></i>
</label>
<label class="file-label" for="file">
<input type="file" id="file" name="file" accept="text/plain, text/html, text/xml, application/json, text/javascript, .har, .sh, .py, .php, .css, .yaml, .sql, .log, .csv, .twig, .md" required/>
<input type="file" id="file" name="file" accept=".txt, .html, .xml, .json, .js, .har, .sh, .py, .php, .css, .yaml, .sql, .log, .csv, .twig, .md, .pdf, .docx, .odt, .epub, .xlsx, .zip" required multiple/>
<i class="fa-solid fa-paperclip"></i>
</label>
<label class="micro-label" for="micro">

View file

@ -423,8 +423,10 @@ body:not(.white) a:visited{
.message .count .fa-print.clicked,
.message .count .fa-rotate.clicked,
.message .count .fa-volume-high.active,
.message .count .fa-file-export.clicked,
.message .continue_button.clicked,
.message .regenerate_button.clicked {
.message .regenerate_button.clicked,
.message .options_button.clicked {
color: var(--accent);
}
@ -530,7 +532,7 @@ body:not(.white) a:visited{
right: 0;
}
.stop_generating button, .toolbar .regenerate button, button.regenerate_button, button.continue_button {
.stop_generating button, .toolbar .regenerate button, button.regenerate_button, button.continue_button, button.options_button {
backdrop-filter: blur(20px);
-webkit-backdrop-filter: blur(20px);
background-color: var(--blur-bg);
@ -553,11 +555,11 @@ body:not(.white) a:visited{
right: auto;
}
.toolbar .regenerate span, .regenerate_button span, .continue_button span {
.toolbar .regenerate span, .regenerate_button span, .continue_button span, .options_button div {
display: none;
}
.regenerate_button span, .continue_button span {
.regenerate_button span, .continue_button span, .options_button div {
position: absolute;
height: 20px;
width: 100px;
@ -568,6 +570,27 @@ body:not(.white) a:visited{
padding: 6px;
}
.options_button div {
display: none;
flex-direction: row;
height: 36px;
width: 120px;
margin-right: 130px;
z-index: 1005;
padding: 10px;
margin-top: -4px;
}
.options_button div span{
height: 20px;
width: 22px;
}
.options_button:hover div {
display: flex;
}
.regenerate_button:hover span, .continue_button:hover span {
display: block;
transition: all 0.3s;
@ -786,7 +809,7 @@ select {
border-radius: 25px;
}
.buttons button, button.regenerate_button, button.continue_button {
.buttons button, button.regenerate_button, button.continue_button, button.options_button {
border-radius: 8px;
backdrop-filter: blur(20px);
cursor: pointer;
@ -796,7 +819,7 @@ select {
padding: 8px;
}
button.regenerate_button, button.continue_button {
button.regenerate_button, button.continue_button, button.options_button {
display: flex;
margin-top: -8px;
margin-left: 2px;

View file

@ -68,7 +68,7 @@ if (window.markdownit) {
function filter_message(text) {
return text.replaceAll(
/<!-- generated images start -->[\s\S]+<!-- generated images end -->/gm, ""
)
).replace(/ \[aborted\]$/g, "").replace(/ \[error\]$/g, "");
}
function fallback_clipboard (text) {
@ -204,7 +204,8 @@ const register_message_buttons = async () => {
el.dataset.click = "true";
el.addEventListener("click", async () => {
let message_el = get_message_el(el);
const copyText = await get_message(window.conversation_id, message_el.dataset.index);
let response = await fetch(message_el.dataset.object_url);
let copyText = await response.text();
try {
if (!navigator.clipboard) {
throw new Error("navigator.clipboard: Clipboard API unavailable.");
@ -221,6 +222,24 @@ const register_message_buttons = async () => {
}
});
document.querySelectorAll(".message .fa-file-export").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
let message_el = get_message_el(el);
const elem = window.document.createElement('a');
let filename = `chat ${new Date().toLocaleString()}.md`.replaceAll(":", "-");
elem.href = message_el.dataset.object_url;
elem.download = filename;
document.body.appendChild(elem);
elem.click();
document.body.removeChild(elem);
el.classList.add("clicked");
setTimeout(() => el.classList.remove("clicked"), 1000);
})
}
});
document.querySelectorAll(".message .fa-volume-high").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
@ -362,11 +381,6 @@ const handle_ask = async () => {
</div>
<div class="count">
${count_words_and_tokens(message, get_selected_model()?.value)}
<i class="fa-solid fa-volume-high"></i>
<i class="fa-regular fa-clipboard"></i>
<a><i class="fa-brands fa-whatsapp"></i></a>
<i class="fa-solid fa-print"></i>
<i class="fa-solid fa-rotate"></i>
</div>
</div>
</div>
@ -484,11 +498,15 @@ const prepare_messages = (messages, message_index = -1, do_continue = false) =>
}
}
messages.forEach((new_message) => {
messages.forEach((new_message, i) => {
// Copy message first
new_message = { ...new_message };
// Include last message, if do_continue
if (i + 1 == messages.length && do_continue) {
delete new_message.regenerate;
}
// Include only not regenerated messages
if (new_message && !new_message.regenerate) {
// Copy message first
new_message = { ...new_message };
// Remove generated images from history
new_message.content = filter_message(new_message.content);
// Remove internal fields
@ -707,8 +725,8 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
message_storage[message_id] = "";
stop_generating.classList.remove("stop_generating-hidden");
if (message_index == -1 && !regenerate) {
await scroll_to_bottom();
if (message_index == -1) {
await lazy_scroll_to_bottom();
}
let count_total = message_box.querySelector('.count_total');
@ -750,9 +768,10 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
inner: content_el.querySelector('.content_inner'),
count: content_el.querySelector('.count'),
update_timeouts: [],
message_index: message_index,
}
if (message_index == -1 && !regenerate) {
await scroll_to_bottom();
if (message_index == -1) {
await lazy_scroll_to_bottom();
}
try {
const input = imageInput && imageInput.files.length > 0 ? imageInput : cameraInput;
@ -813,7 +832,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
let cursorDiv = message_el.querySelector(".cursor");
if (cursorDiv) cursorDiv.parentNode.removeChild(cursorDiv);
if (message_index == -1) {
await scroll_to_bottom();
await lazy_scroll_to_bottom();
}
await safe_remove_cancel_button();
await register_message_buttons();
@ -826,6 +845,12 @@ async function scroll_to_bottom() {
message_box.scrollTop = message_box.scrollHeight;
}
async function lazy_scroll_to_bottom() {
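// Autoscroll only when the user is already within two viewport heights of the bottom.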
if (message_box.scrollHeight - message_box.scrollTop < 2 * message_box.clientHeight) {
await scroll_to_bottom();
}
}
const clear_conversations = async () => {
const elements = box_conversations.childNodes;
let index = elements.length;
@ -971,7 +996,23 @@ const load_conversation = async (conversation_id, scroll=true) => {
} else {
buffer = "";
}
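// For continued messages: strip [aborted]/[error] markers, then drop the part of the
// new content that overlaps the tail of the existing buffer before appending.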
buffer += item.content;
buffer = buffer.replace(/ \[aborted\]$/g, "").replace(/ \[error\]$/g, "");
let lines = buffer.trim().split("\n");
let lastLine = lines[lines.length - 1];
let newContent = item.content;
if (newContent.startsWith("```\n")) {
newContent = item.content.substring(4);
}
if (newContent.startsWith(lastLine)) {
newContent = newContent.substring(lastLine.length);
} else {
let words = buffer.trim().split(" ");
let lastWord = words[words.length - 1];
if (newContent.startsWith(lastWord)) {
newContent = newContent.substring(lastWord.length);
}
}
buffer += newContent;
last_model = item.provider?.model;
providers.push(item.provider?.name);
let next_i = parseInt(i) + 1;
@ -993,28 +1034,74 @@ const load_conversation = async (conversation_id, scroll=true) => {
synthesize_params = (new URLSearchParams(synthesize_params)).toString();
let synthesize_url = `/backend-api/v2/synthesize/${synthesize_provider}?${synthesize_params}`;
const file = new File([buffer], 'message.md', {type: 'text/plain'});
const objectUrl = URL.createObjectURL(file);
let add_buttons = [];
// Always add regenerate button
add_buttons.push(`<button class="regenerate_button">
<span>Regenerate</span>
<i class="fa-solid fa-rotate"></i>
</button>`);
// Add continue button if possible
actions = ["variant"]
if (item.finish && item.finish.actions) {
item.finish.actions.forEach((action) => {
if (action == "continue") {
if (messages.length >= i - 1) {
add_buttons.push(`<button class="continue_button">
<span>Continue</span>
<i class="fa-solid fa-wand-magic-sparkles"></i>
</button>`);
}
actions = item.finish.actions
}
if (!("continue" in actions)) {
let reason = "stop";
// Read finish reason from conversation
if (item.finish && item.finish.reason) {
reason = item.finish.reason;
}
let lines = buffer.trim().split("\n");
let lastLine = lines[lines.length - 1];
// Has a stop or error token at the end
if (lastLine.endsWith("[aborted]") || lastLine.endsWith("[error]")) {
reason = "error";
// Has an odd number of code fence markers (an unclosed code block)
} else if ((buffer.split("```").length - 1) % 2 === 1) {
reason = "error";
// Has an end token at the end
} else if (lastLine.endsWith("```") || lastLine.endsWith(".") || lastLine.endsWith("?") || lastLine.endsWith("!")
|| lastLine.endsWith('"') || lastLine.endsWith("'") || lastLine.endsWith(")")
|| lastLine.endsWith(">") || lastLine.endsWith("]") || lastLine.endsWith("}") ) {
reason = "stop"
} else {
// Has an emoji at the end
const regex = /\p{Emoji}$/u;
if (regex.test(lastLine)) {
reason = "stop"
}
});
}
if (reason == "max_tokens" || reason == "error") {
actions.push("continue")
}
}
add_buttons.push(`<button class="options_button">
<div>
<span><i class="fa-brands fa-whatsapp"></i></span>
<span><i class="fa-solid fa-volume-high"></i></i></span>
<span><i class="fa-solid fa-print"></i></span>
<span><i class="fa-solid fa-file-export"></i></span>
<span><i class="fa-regular fa-clipboard"></i></span>
</div>
<i class="fa-solid fa-plus"></i>
</button>`);
if (actions.includes("variant")) {
add_buttons.push(`<button class="regenerate_button">
<span>Regenerate</span>
<i class="fa-solid fa-rotate"></i>
</button>`);
}
if (actions.includes("continue")) {
if (messages.length >= i - 1) {
add_buttons.push(`<button class="continue_button">
<span>Continue</span>
<i class="fa-solid fa-wand-magic-sparkles"></i>
</button>`);
}
}
elements.push(`
<div class="message${item.regenerate ? " regenerate": ""}" data-index="${i}" data-synthesize_url="${synthesize_url}">
<div class="message${item.regenerate ? " regenerate": ""}" data-index="${i}" data-object_url="${objectUrl}" data-synthesize_url="${synthesize_url}">
<div class="${item.role}">
${item.role == "assistant" ? gpt_image : user_image}
<i class="fa-solid fa-xmark"></i>
@ -1028,12 +1115,6 @@ const load_conversation = async (conversation_id, scroll=true) => {
<div class="content_inner">${markdown_render(buffer)}</div>
<div class="count">
${count_words_and_tokens(buffer, next_provider?.model)}
<span>
<i class="fa-solid fa-volume-high"></i>
<i class="fa-regular fa-clipboard"></i>
<a><i class="fa-brands fa-whatsapp"></i></a>
<i class="fa-solid fa-print"></i>
</span>
${add_buttons.join("")}
</div>
</div>
@ -1444,11 +1525,8 @@ function update_message(content_map, message_id, content = null) {
content_map.inner.innerHTML = html;
content_map.count.innerText = count_words_and_tokens(message_storage[message_id], provider_storage[message_id]?.model);
highlight(content_map.inner);
if (!content_map.container.classList.contains("regenerate")) {
if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 200) {
window.scrollTo(0, 0);
message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
}
if (content_map.message_index == -1) {
lazy_scroll_to_bottom();
}
content_map.update_timeouts.forEach((timeoutId)=>clearTimeout(timeoutId));
content_map.update_timeouts = [];
@ -1711,19 +1789,76 @@ async function upload_cookies() {
fileInput.value = "";
}
function formatFileSize(bytes) {
const units = ['B', 'KB', 'MB', 'GB'];
let unitIndex = 0;
while (bytes >= 1024 && unitIndex < units.length - 1) {
bytes /= 1024;
unitIndex++;
}
return `${bytes.toFixed(2)} ${units[unitIndex]}`;
}
async function upload_files(fileInput) {
const paperclip = document.querySelector(".user-input .fa-paperclip");
const bucket_id = uuid();
const formData = new FormData();
Array.from(fileInput.files).forEach(file => {
formData.append('files[]', file);
});
paperclip.classList.add("blink");
await fetch("/backend-api/v2/files/" + bucket_id, {
method: 'POST',
body: formData
});
let do_refine = document.getElementById("refine").checked;
function connectToSSE(url) {
const eventSource = new EventSource(url);
eventSource.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.error) {
inputCount.innerText = `Error: ${data.error.message}`;
} else if (data.action == "load") {
inputCount.innerText = `Read data: ${formatFileSize(data.size)}`;
} else if (data.action == "refine") {
inputCount.innerText = `Refine data: ${formatFileSize(data.size)}`;
} else if (data.action == "done") {
if (do_refine) {
do_refine = false;
connectToSSE(`/backend-api/v2/files/${bucket_id}?refine_chunks_with_spacy=true`);
return;
}
inputCount.innerText = "Files are loaded successfully";
messageInput.value += (messageInput.value ? "\n" : "") + JSON.stringify({bucket_id: bucket_id}) + "\n";
paperclip.classList.remove("blink");
fileInput.value = "";
delete fileInput.dataset.text;
}
};
eventSource.onerror = (event) => {
eventSource.close();
paperclip.classList.remove("blink");
}
}
connectToSSE(`/backend-api/v2/files/${bucket_id}`);
}
fileInput.addEventListener('change', async (event) => {
if (fileInput.files.length) {
type = fileInput.files[0].name.split('.').pop()
if (type == "har") {
return await upload_cookies();
} else if (type != "json") {
await upload_files(fileInput);
}
fileInput.dataset.type = type
const reader = new FileReader();
reader.addEventListener('load', async (event) => {
fileInput.dataset.text = event.target.result;
if (type == "json") {
if (type == "json") {
const reader = new FileReader();
reader.addEventListener('load', async (event) => {
fileInput.dataset.text = event.target.result;
const data = JSON.parse(fileInput.dataset.text);
if ("g4f" in data.options) {
if (data.options && "g4f" in data.options) {
let count = 0;
Object.keys(data).forEach(key => {
if (key != "options" && !localStorage.getItem(key)) {
@ -1736,11 +1871,23 @@ fileInput.addEventListener('change', async (event) => {
fileInput.value = "";
inputCount.innerText = `${count} Conversations were imported successfully`;
} else {
await upload_cookies();
is_cookie_file = false;
if (Array.isArray(data)) {
data.forEach((item) => {
if (item.domain && item.name && item.value) {
is_cookie_file = true;
}
});
}
if (is_cookie_file) {
await upload_cookies();
} else {
await upload_files(fileInput);
}
}
}
});
reader.readAsText(fileInput.files[0]);
});
reader.readAsText(fileInput.files[0]);
}
} else {
delete fileInput.dataset.text;
}

View file

@ -1,5 +1,6 @@
from .gui_parser import gui_parser
from ..cookies import read_cookie_files
from g4f.gui import run_gui
import g4f.cookies
import g4f.debug
@ -8,7 +9,6 @@ def run_gui_args(args):
g4f.debug.logging = True
if not args.ignore_cookie_files:
read_cookie_files()
from g4f.gui import run_gui
host = args.host
port = args.port
debug = args.debug

View file

@ -7,17 +7,17 @@ from typing import Iterator
from flask import send_from_directory
from inspect import signature
from g4f import version, models
from g4f import ChatCompletion, get_model_and_provider
from g4f.errors import VersionNotFoundError
from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir
from g4f.Provider import ProviderUtils, __providers__
from g4f.providers.base_provider import ProviderModelMixin
from g4f.providers.retry_provider import IterListProvider
from g4f.providers.response import BaseConversation, JsonConversation, FinishReason
from g4f.providers.response import SynthesizeData, TitleGeneration, RequestLogin, Parameters
from g4f.client.service import convert_to_provider
from g4f import debug
from ...errors import VersionNotFoundError
from ...image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir
from ...tools.run_tools import iter_run_tools
from ...Provider import ProviderUtils, __providers__
from ...providers.base_provider import ProviderModelMixin
from ...providers.retry_provider import IterListProvider
from ...providers.response import BaseConversation, JsonConversation, FinishReason
from ...providers.response import SynthesizeData, TitleGeneration, RequestLogin, Parameters
from ... import version, models
from ... import ChatCompletion, get_model_and_provider
from ... import debug
logger = logging.getLogger(__name__)
conversations: dict[dict[str, BaseConversation]] = {}
@ -90,16 +90,28 @@ class Api:
api_key = json_data.get("api_key")
if api_key is not None:
kwargs["api_key"] = api_key
kwargs["tool_calls"] = [{
"function": {
"name": "bucket_tool"
},
"type": "function"
}]
do_web_search = json_data.get('web_search')
if do_web_search and provider:
provider_handler = convert_to_provider(provider)
if hasattr(provider_handler, "get_parameters"):
if "web_search" in provider_handler.get_parameters():
kwargs['web_search'] = True
do_web_search = False
if do_web_search:
from ...web_search import get_search_message
messages[-1]["content"] = get_search_message(messages[-1]["content"])
kwargs["tool_calls"].append({
"function": {
"name": "safe_search_tool"
},
"type": "function"
})
action = json_data.get('action')
if action == "continue":
kwargs["tool_calls"].append({
"function": {
"name": "continue_tool"
},
"type": "function"
})
conversation = json_data.get("conversation")
if conversation is not None:
kwargs["conversation"] = JsonConversation(**conversation)
@ -139,7 +151,7 @@ class Api:
logging=False
)
params = {
**provider_handler.get_parameters(as_json=True),
**(provider_handler.get_parameters(as_json=True) if hasattr(provider_handler, "get_parameters") else {}),
"model": model,
"messages": kwargs.get("messages"),
"web_search": kwargs.get("web_search")
@ -153,7 +165,7 @@ class Api:
yield self._format_json("parameters", params)
first = True
try:
result = ChatCompletion.create(**{**kwargs, "model": model, "provider": provider_handler})
result = iter_run_tools(ChatCompletion.create, **{**kwargs, "model": model, "provider": provider_handler})
for chunk in result:
if first:
first = False

View file

@ -1,9 +1,9 @@
import sys, os
from flask import Flask
if getattr(sys, 'frozen', False):
template_folder = os.path.join(sys._MEIPASS, "client")
else:
template_folder = "../client"
app = Flask(__name__, template_folder=template_folder, static_folder=f"{template_folder}/static")
def create_app() -> Flask:
if getattr(sys, 'frozen', False):
template_folder = os.path.join(sys._MEIPASS, "client")
else:
template_folder = "../client"
return Flask(__name__, template_folder=template_folder, static_folder=f"{template_folder}/static")

View file

@ -1,17 +1,22 @@
from __future__ import annotations
import json
import flask
import os
import logging
import asyncio
from flask import Flask, request, jsonify
import shutil
from flask import Flask, Response, request, jsonify
from typing import Generator
from pathlib import Path
from werkzeug.utils import secure_filename
from g4f.image import is_allowed_extension, to_image
from g4f.client.service import convert_to_provider
from g4f.providers.asyncio import to_sync_generator
from g4f.errors import ProviderNotFoundError
from g4f.cookies import get_cookies_dir
from ...image import is_allowed_extension, to_image
from ...client.service import convert_to_provider
from ...providers.asyncio import to_sync_generator
from ...tools.files import supports_filename, get_streaming, get_bucket_dir, get_buckets
from ...errors import ProviderNotFoundError
from ...cookies import get_cookies_dir
from .api import Api
logger = logging.getLogger(__name__)
@ -96,6 +101,85 @@ class Backend_Api(Api):
}
}
@app.route('/backend-api/v2/buckets', methods=['GET'])
def list_buckets():
try:
buckets = get_buckets()
if buckets is None:
return jsonify({"error": {"message": "Error accessing bucket directory"}}), 500
sanitized_buckets = [secure_filename(b) for b in buckets]
return jsonify(sanitized_buckets), 200
except Exception as e:
return jsonify({"error": {"message": str(e)}}), 500
@app.route('/backend-api/v2/files/<bucket_id>', methods=['GET', 'DELETE'])
def manage_files(bucket_id: str):
bucket_id = secure_filename(bucket_id)
bucket_dir = get_bucket_dir(bucket_id)
if not os.path.isdir(bucket_dir):
return jsonify({"error": {"message": "Bucket directory not found"}}), 404
if request.method == 'DELETE':
try:
shutil.rmtree(bucket_dir)
return jsonify({"message": "Bucket deleted successfully"}), 200
except OSError as e:
return jsonify({"error": {"message": f"Error deleting bucket: {str(e)}"}}), 500
except Exception as e:
return jsonify({"error": {"message": str(e)}}), 500
delete_files = request.args.get('delete_files', True)
refine_chunks_with_spacy = request.args.get('refine_chunks_with_spacy', False)
event_stream = 'text/event-stream' in request.headers.get('Accept', '')
mimetype = "text/event-stream" if event_stream else "text/plain";
return Response(get_streaming(bucket_dir, delete_files, refine_chunks_with_spacy, event_stream), mimetype=mimetype)
@self.app.route('/backend-api/v2/files/<bucket_id>', methods=['POST'])
def upload_files(bucket_id: str):
bucket_id = secure_filename(bucket_id)
bucket_dir = get_bucket_dir(bucket_id)
os.makedirs(bucket_dir, exist_ok=True)
filenames = []
for file in request.files.getlist('files[]'):
try:
filename = secure_filename(file.filename)
if supports_filename(filename):
with open(os.path.join(bucket_dir, filename), 'wb') as f:
shutil.copyfileobj(file.stream, f)
filenames.append(filename)
finally:
file.stream.close()
with open(os.path.join(bucket_dir, "files.txt"), 'w') as f:
[f.write(f"{filename}\n") for filename in filenames]
return {"bucket_id": bucket_id, "files": filenames}
@app.route('/backend-api/v2/files/<bucket_id>/<filename>', methods=['PUT'])
def upload_file(bucket_id, filename):
bucket_id = secure_filename(bucket_id)
bucket_dir = get_bucket_dir(bucket_id)
filename = secure_filename(filename)
bucket_path = Path(bucket_dir)
if not supports_filename(filename):
return jsonify({"error": {"message": f"File type not allowed"}}), 400
if not bucket_path.exists():
bucket_path.mkdir(parents=True, exist_ok=True)
try:
file_path = bucket_path / filename
file_data = request.get_data()
if not file_data:
return jsonify({"error": {"message": "No file data received"}}), 400
with open(str(file_path), 'wb') as f:
f.write(file_data)
return jsonify({"message": f"File '{filename}' uploaded successfully to bucket '{bucket_id}'"}), 201
except Exception as e:
return jsonify({"error": {"message": f"Error uploading file: {str(e)}"}}), 500
def upload_cookies(self):
file = None
if "file" in request.files:

View file

@ -1,3 +1,3 @@
from __future__ import annotations
from ...web_search import SearchResults, search, get_search_message
from ...tools.web_search import SearchResults, search, get_search_message

View file

@ -32,8 +32,6 @@ except ImportError:
from .. import debug
from .raise_for_status import raise_for_status
from ..webdriver import WebDriver, WebDriverSession
from ..webdriver import bypass_cloudflare, get_driver_cookies
from ..errors import MissingRequirementsError
from ..typing import Cookies
from .defaults import DEFAULT_HEADERS, WEBVIEW_HAEDERS
@ -66,69 +64,6 @@ async def get_args_from_webview(url: str) -> dict:
window.destroy()
return {"headers": headers, "cookies": cookies}
def get_args_from_browser(
url: str,
webdriver: WebDriver = None,
proxy: str = None,
timeout: int = 120,
do_bypass_cloudflare: bool = True,
virtual_display: bool = False
) -> dict:
"""
Create a Session object using a WebDriver to handle cookies and headers.
Args:
url (str): The URL to navigate to using the WebDriver.
webdriver (WebDriver, optional): The WebDriver instance to use.
proxy (str, optional): Proxy server to use for the Session.
timeout (int, optional): Timeout in seconds for the WebDriver.
Returns:
Session: A Session object configured with cookies and headers from the WebDriver.
"""
with WebDriverSession(webdriver, "", proxy=proxy, virtual_display=virtual_display) as driver:
if do_bypass_cloudflare:
bypass_cloudflare(driver, url, timeout)
headers = {
**DEFAULT_HEADERS,
'referer': url,
}
if not hasattr(driver, "requests"):
headers["user-agent"] = driver.execute_script("return navigator.userAgent")
else:
for request in driver.requests:
if request.url.startswith(url):
for key, value in request.headers.items():
if key in (
"accept-encoding",
"accept-language",
"user-agent",
"sec-ch-ua",
"sec-ch-ua-platform",
"sec-ch-ua-arch",
"sec-ch-ua-full-version",
"sec-ch-ua-platform-version",
"sec-ch-ua-bitness"
):
headers[key] = value
break
cookies = get_driver_cookies(driver)
return {
'cookies': cookies,
'headers': headers,
}
def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str = None, timeout: int = 120) -> Session:
if not has_curl_cffi:
raise MissingRequirementsError('Install "curl_cffi" package | pip install -U curl_cffi')
args = get_args_from_browser(url, webdriver, proxy, timeout)
return Session(
**args,
proxies={"https": proxy, "http": proxy},
timeout=timeout,
impersonate="chrome"
)
def get_cookie_params_from_dict(cookies: Cookies, url: str = None, domain: str = None) -> list[CookieParam]:
[CookieParam.from_json({
"name": key,

View file

@ -25,7 +25,7 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
return
text = await response.text()
if message is None:
message = "HTML content" if response.headers.get("content-type").startswith("text/html") else text
message = "HTML content" if response.headers.get("content-type").startswith("text/html") else text
if message == "HTML content":
if response.status == 520:
message = "Unknown error (Cloudflare)"

511
g4f/tools/files.py Normal file
View file

@ -0,0 +1,511 @@
from __future__ import annotations
import os
import json
from pathlib import Path
from typing import Iterator, Optional
from aiohttp import ClientSession, ClientError, ClientResponse, ClientTimeout
import urllib.parse
import time
import zipfile
import asyncio
import hashlib
import base64
try:
from werkzeug.utils import secure_filename
except ImportError:
secure_filename = os.path.basename
try:
import PyPDF2
from PyPDF2.errors import PdfReadError
has_pypdf2 = True
except ImportError:
has_pypdf2 = False
try:
import pdfplumber
has_pdfplumber = True
except ImportError:
has_pdfplumber = False
try:
from pdfminer.high_level import extract_text
has_pdfminer = True
except ImportError:
has_pdfminer = False
try:
from docx import Document
has_docx = True
except ImportError:
has_docx = False
try:
import docx2txt
has_docx2txt = True
except ImportError:
has_docx2txt = False
try:
from odf.opendocument import load
from odf.text import P
has_odfpy = True
except ImportError:
has_odfpy = False
try:
import ebooklib
from ebooklib import epub
has_ebooklib = True
except ImportError:
has_ebooklib = False
try:
import pandas as pd
has_openpyxl = True
except ImportError:
has_openpyxl = False
try:
import spacy
has_spacy = True
except:
has_spacy = False
try:
from bs4 import BeautifulSoup
has_beautifulsoup4 = True
except ImportError:
has_beautifulsoup4 = False
from .web_search import scrape_text
from ..cookies import get_cookies_dir
from ..requests.aiohttp import get_connector
from ..errors import MissingRequirementsError
from .. import debug
PLAIN_FILE_EXTENSIONS = ["txt", "xml", "json", "js", "har", "sh", "py", "php", "css", "yaml", "sql", "log", "csv", "twig", "md"]
PLAIN_CACHE = "plain.cache"
DOWNLOADS_FILE = "downloads.json"
FILE_LIST = "files.txt"
def supports_filename(filename: str):
if filename.endswith(".pdf"):
if has_pypdf2:
return True
elif has_pdfplumber:
return True
elif has_pdfminer:
return True
raise MissingRequirementsError(f'Install "pypdf2" requirements | pip install -U g4f[files]')
elif filename.endswith(".docx"):
if has_docx:
return True
elif has_docx2txt:
return True
raise MissingRequirementsError(f'Install "docx" requirements | pip install -U g4f[files]')
elif has_odfpy and filename.endswith(".odt"):
return True
elif has_ebooklib and filename.endswith(".epub"):
return True
elif has_openpyxl and filename.endswith(".xlsx"):
return True
elif filename.endswith(".html"):
if not has_beautifulsoup4:
raise MissingRequirementsError(f'Install "beautifulsoup4" requirements | pip install -U g4f[files]')
return True
elif filename.endswith(".zip"):
return True
elif filename.endswith("package-lock.json") and filename != FILE_LIST:
return False
else:
extension = os.path.splitext(filename)[1][1:]
if extension in PLAIN_FILE_EXTENSIONS:
return True
return False
def get_bucket_dir(bucket_id: str):
bucket_dir = os.path.join(get_cookies_dir(), "buckets", bucket_id)
return bucket_dir
def get_buckets():
buckets_dir = os.path.join(get_cookies_dir(), "buckets")
try:
return [d for d in os.listdir(buckets_dir) if os.path.isdir(os.path.join(buckets_dir, d))]
except OSError as e:
return None
def spacy_refine_chunks(source_iterator):
if not has_spacy:
raise MissingRequirementsError(f'Install "spacy" requirements | pip install -U g4f[files]')
nlp = spacy.load("en_core_web_sm")
for page in source_iterator:
doc = nlp(page)
#for chunk in doc.noun_chunks:
# yield " ".join([token.lemma_ for token in chunk if not token.is_stop])
# for token in doc:
# if not token.is_space:
# yield token.lemma_.lower()
# yield " "
sentences = list(doc.sents)
summary = sorted(sentences, key=lambda x: len(x.text), reverse=True)[:2]
for sent in summary:
yield sent.text
def get_filenames(bucket_dir: Path):
files = bucket_dir / FILE_LIST
with files.open('r') as f:
return [filename.strip() for filename in f.readlines()]
def stream_read_files(bucket_dir: Path, filenames: list) -> Iterator[str]:
for filename in filenames:
file_path: Path = bucket_dir / filename
if not file_path.exists() or file_path.lstat().st_size == 0:  # skip missing or empty files
continue
extension = os.path.splitext(filename)[1][1:]
if filename.endswith(".zip"):
with zipfile.ZipFile(file_path, 'r') as zip_ref:
zip_ref.extractall(bucket_dir)
try:
yield from stream_read_files(bucket_dir, [f for f in zip_ref.namelist() if supports_filename(f)])
except zipfile.BadZipFile:
pass
finally:
for unlink in zip_ref.namelist()[::-1]:
filepath = os.path.join(bucket_dir, unlink)
if os.path.exists(filepath):
if os.path.isdir(filepath):
os.rmdir(filepath)
else:
os.unlink(filepath)
continue
yield f"```{filename}\n"
if has_pypdf2 and filename.endswith(".pdf"):
try:
reader = PyPDF2.PdfReader(file_path)
for page_num in range(len(reader.pages)):
page = reader.pages[page_num]
yield page.extract_text()
except PdfReadError:
continue
if has_pdfplumber and filename.endswith(".pdf"):
with pdfplumber.open(file_path) as pdf:
for page in pdf.pages:
yield page.extract_text()
if has_pdfminer and filename.endswith(".pdf"):
yield extract_text(file_path)
elif has_docx and filename.endswith(".docx"):
doc = Document(file_path)
for para in doc.paragraphs:
yield para.text
elif has_docx2txt and filename.endswith(".docx"):
yield docx2txt.process(file_path)
elif has_odfpy and filename.endswith(".odt"):
textdoc = load(file_path)
allparas = textdoc.getElementsByType(P)
for p in allparas:
yield p.firstChild.data if p.firstChild else ""
elif has_ebooklib and filename.endswith(".epub"):
book = epub.read_epub(file_path)
for doc_item in book.get_items():
if doc_item.get_type() == ebooklib.ITEM_DOCUMENT:
yield doc_item.get_content().decode(errors='ignore')
elif has_openpyxl and filename.endswith(".xlsx"):
df = pd.read_excel(file_path)
for row in df.itertuples(index=False):
yield " ".join(str(cell) for cell in row)
elif has_beautifulsoup4 and filename.endswith(".html"):
yield from scrape_text(file_path.read_text(errors="ignore"))
elif extension in PLAIN_FILE_EXTENSIONS:
yield file_path.read_text(errors="ignore")
yield f"\n```\n\n"
def cache_stream(stream: Iterator[str], bucket_dir: Path) -> Iterator[str]:
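# Serve plain.cache if it already exists; otherwise tee the stream into a tmp file
# and rename it into place once fully written.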
cache_file = bucket_dir / PLAIN_CACHE
tmp_file = bucket_dir / f"{PLAIN_CACHE}.{time.time()}.tmp"
if cache_file.exists():
for chunk in read_path_chunked(cache_file):
yield chunk
return
with open(tmp_file, "w") as f:
for chunk in stream:
f.write(chunk)
yield chunk
tmp_file.rename(cache_file)
def is_complete(data: str):
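# A chunk is complete when it ends at a closed code fence, so splits never cut a
# fenced file block in half.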
return data.endswith("\n```\n\n") and data.count("```") % 2 == 0
def read_path_chunked(path: Path):
with path.open("r", encoding='utf-8') as f:
current_chunk_size = 0
buffer = ""
for line in f:
current_chunk_size += len(line.encode('utf-8'))
buffer += line
if current_chunk_size >= 4096:
if is_complete(buffer) or current_chunk_size >= 8192:
yield buffer
buffer = ""
current_chunk_size = 0
if current_chunk_size > 0:
yield buffer
def read_bucket(bucket_dir: Path):
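# Yield the whole plain cache when no refined parts exist; otherwise yield the
# numbered spacy/plain part files in order.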
bucket_dir = Path(bucket_dir)
cache_file = bucket_dir / PLAIN_CACHE
spacy_file = bucket_dir / f"spacy_0001.cache"
if not spacy_file.exists():
yield cache_file.read_text()
for idx in range(1, 1000):
spacy_file = bucket_dir / f"spacy_{idx:04d}.cache"
plain_file = bucket_dir / f"plain_{idx:04d}.cache"
if spacy_file.exists():
yield spacy_file.read_text()
elif plain_file.exists():
yield plain_file.read_text()
else:
break
def stream_read_parts_and_refine(bucket_dir: Path, delete_files: bool = False) -> Iterator[str]:
cache_file = bucket_dir / PLAIN_CACHE
space_file = Path(bucket_dir) / f"spacy_0001.cache"
part_one = bucket_dir / f"plain_0001.cache"
if not space_file.exists() and not part_one.exists() and cache_file.exists():
split_file_by_size_and_newline(cache_file, bucket_dir)
for idx in range(1, 1000):
part = bucket_dir / f"plain_{idx:04d}.cache"
tmp_file = Path(bucket_dir) / f"spacy_{idx:04d}.{time.time()}.tmp"
cache_file = Path(bucket_dir) / f"spacy_{idx:04d}.cache"
if cache_file.exists():
with open(cache_file, "r") as f:
yield f.read()
continue
if not part.exists():
break
with tmp_file.open("w") as f:
for chunk in spacy_refine_chunks(read_path_chunked(part)):
f.write(chunk)
yield chunk
tmp_file.rename(cache_file)
if delete_files:
part.unlink()
def split_file_by_size_and_newline(input_filename, output_dir, chunk_size_bytes=1024*1024): # 1MB
"""Splits a file into chunks of approximately chunk_size_bytes, splitting only at newline characters.
Args:
input_filename: Path to the input file.
output_dir: Directory to write the numbered chunk files into.
chunk_size_bytes: Desired size of each chunk in bytes.
"""
split_filename = os.path.splitext(os.path.basename(input_filename))
output_prefix = os.path.join(output_dir, split_filename[0] + "_")
with open(input_filename, 'r', encoding='utf-8') as infile:
chunk_num = 1
current_chunk = ""
current_chunk_size = 0
for line in infile:
current_chunk += line
current_chunk_size += len(line.encode('utf-8'))
if current_chunk_size >= chunk_size_bytes:
if is_complete(current_chunk) or current_chunk_size >= chunk_size_bytes * 2:
output_filename = f"{output_prefix}{chunk_num:04d}{split_filename[1]}"
with open(output_filename, 'w', encoding='utf-8') as outfile:
outfile.write(current_chunk)
current_chunk = ""
current_chunk_size = 0
chunk_num += 1
# Write the last chunk
if current_chunk:
output_filename = f"{output_prefix}{chunk_num:04d}{split_filename[1]}"
with open(output_filename, 'w', encoding='utf-8') as outfile:
outfile.write(current_chunk)
async def get_filename(response: ClientResponse):
"""
Attempts to extract a filename from an aiohttp response. Prioritizes Content-Disposition, then URL.
Args:
response: The aiohttp ClientResponse object.
Returns:
The filename as a string, or None if it cannot be determined.
"""
content_disposition = response.headers.get('Content-Disposition')
if content_disposition:
try:
filename = content_disposition.split('filename=')[1].strip('"')
if filename:
return secure_filename(filename)
except IndexError:
pass
content_type = response.headers.get('Content-Type')
url = str(response.url)
if content_type and url:
extension = await get_file_extension(response)
if extension:
parsed_url = urllib.parse.urlparse(url)
sha256_hash = hashlib.sha256(url.encode()).digest()
base64_encoded = base64.b32encode(sha256_hash).decode().lower()
return f"{parsed_url.netloc} {parsed_url.path[1:].replace('/', '_')} {base64_encoded[:6]}{extension}"
return None
async def get_file_extension(response: ClientResponse):
"""
Attempts to determine the file extension from an aiohttp response. Improved to handle more types.
Args:
response: The aiohttp ClientResponse object.
Returns:
The file extension (e.g., ".html", ".json", ".pdf", ".zip", ".md", ".txt") as a string,
or None if it cannot be determined.
"""
content_type = response.headers.get('Content-Type')
if content_type:
if "html" in content_type.lower():
return ".html"
elif "json" in content_type.lower():
return ".json"
elif "pdf" in content_type.lower():
return ".pdf"
elif "zip" in content_type.lower():
return ".zip"
elif "text/plain" in content_type.lower():
return ".txt"
elif "markdown" in content_type.lower():
return ".md"
url = str(response.url)
if url:
return Path(url).suffix.lower()
return None
def read_links(html: str, base: str) -> set[str]:
soup = BeautifulSoup(html, "html.parser")
for selector in [
"main",
".main-content-wrapper",
".main-content",
".emt-container-inner",
".content-wrapper",
"#content",
"#mainContent",
]:
select = soup.select_one(selector)
if select:
soup = select
break
urls = []
for link in soup.select("a"):
if "rel" not in link.attrs or "nofollow" not in link.attrs["rel"]:
url = link.attrs.get("href")
if url and url.startswith("https://"):
urls.append(url.split("#")[0])
return set([urllib.parse.urljoin(base, link) for link in urls])
async def download_urls(
bucket_dir: Path,
urls: list[str],
max_depth: int = 2,
loaded_urls: Optional[set[str]] = None,
lock: asyncio.Lock = None,
delay: int = 3,
group_size: int = 5,
timeout: int = 10,
proxy: Optional[str] = None
) -> set[str]:
# Avoid a shared mutable default: create a fresh set per top-level call
if loaded_urls is None:
loaded_urls = set()
if lock is None:
lock = asyncio.Lock()
async with ClientSession(
connector=get_connector(proxy=proxy),
timeout=ClientTimeout(timeout)
) as session:
async def download_url(url: str) -> Optional[list[str]]:
try:
async with session.get(url) as response:
response.raise_for_status()
filename = await get_filename(response)
if not filename:
print(f"Failed to get filename for {url}")
return None
newfiles = [filename]
if filename.endswith(".html") and max_depth > 0:
new_urls = read_links(await response.text(), str(response.url))
async with lock:
new_urls = [new_url for new_url in new_urls if new_url not in loaded_urls]
loaded_urls.update(new_urls)
if new_urls:
for i in range(0, len(new_urls), group_size):
newfiles += await download_urls(bucket_dir, new_urls[i:i + group_size], max_depth - 1, loaded_urls, lock, delay + 1, group_size, timeout, proxy)
await asyncio.sleep(delay)
if supports_filename(filename) and filename != DOWNLOADS_FILE:
target = bucket_dir / filename
with target.open("wb") as f:
async for chunk in response.content.iter_chunked(4096):
f.write(chunk)
return newfiles
except (ClientError, asyncio.TimeoutError) as e:
debug.log(f"Download failed: {e.__class__.__name__}: {e}")
return None
files = set()
for results in await asyncio.gather(*[download_url(url) for url in urls]):
if results:
files.update(results)
return files
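A hypothetical top-level call (bucket path and URL are placeholders); download_urls returns the set of saved filenames, including pages discovered up to max_depth levels deep:
from pathlib import Path
files = asyncio.run(download_urls(Path("./bucket"), ["https://example.com/docs/"], max_depth=1))
print(sorted(files))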
def get_streaming(bucket_dir: str, delete_files = False, refine_chunks_with_spacy = False, event_stream: bool = False) -> Iterator[str]:
bucket_dir = Path(bucket_dir)
bucket_dir.mkdir(parents=True, exist_ok=True)
try:
download_file = bucket_dir / DOWNLOADS_FILE
if download_file.exists():
urls = []
with download_file.open('r') as f:
data = json.load(f)
download_file.unlink()
if isinstance(data, list):
for item in data:
if "url" in item:
urls.append(item["url"])
if urls:
filenames = asyncio.run(download_urls(bucket_dir, urls))
with open(os.path.join(bucket_dir, FILE_LIST), 'w') as f:
for filename in filenames:
if filename:
f.write(f"{filename}\n")
if refine_chunks_with_spacy:
size = 0
for chunk in stream_read_parts_and_refine(bucket_dir, delete_files):
if event_stream:
size += len(chunk)
yield f'data: {json.dumps({"action": "refine", "size": size})}\n\n'
else:
yield chunk
else:
streaming = stream_read_files(bucket_dir, get_filenames(bucket_dir))
streaming = cache_stream(streaming, bucket_dir)
size = 0
for chunk in streaming:
if event_stream:
size += len(chunk)
yield f'data: {json.dumps({"action": "load", "size": size})}\n\n'
else:
yield chunk
files_txt = os.path.join(bucket_dir, FILE_LIST)
if delete_files and os.path.exists(files_txt):
for filename in get_filenames(bucket_dir):
if os.path.exists(os.path.join(bucket_dir, filename)):
os.remove(os.path.join(bucket_dir, filename))
os.remove(files_txt)
if event_stream:
yield f'data: {json.dumps({"action": "delete_files"})}\n\n'
if event_stream:
yield f'data: {json.dumps({"action": "done"})}\n\n'
except Exception as e:
if event_stream:
yield f'data: {json.dumps({"error": {"message": str(e)}})}\n\n'
raise
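Consumed with event_stream=True, the generator yields server-sent-event lines instead of raw text; a sketch of a hypothetical caller:
for line in get_streaming("./bucket", event_stream=True):
print(line, end="")
# data: {"action": "load", "size": 4096}
# ...
# data: {"action": "done"}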

87
g4f/tools/run_tools.py Normal file
View file

@ -0,0 +1,87 @@
from __future__ import annotations
import re
import json
import asyncio
from typing import Optional, Callable, AsyncIterator
from ..typing import Messages
from ..providers.helper import filter_none
from ..client.helper import to_async_iterator
from .web_search import do_search, get_search_message
from .files import read_bucket, get_bucket_dir
from .. import debug
def validate_arguments(data: dict) -> dict:
if "arguments" in data:
if isinstance(data["arguments"], str):
data["arguments"] = json.loads(data["arguments"])
if not isinstance(data["arguments"], dict):
raise ValueError("Tool function arguments must be a dictionary or a JSON string")
return filter_none(**data["arguments"])
else:
return {}
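For example, a JSON-encoded arguments string is parsed, then filter_none drops the None values (a sketch; the literal values are placeholders):
validate_arguments({"arguments": '{"query": "g4f", "max_results": null}'})
# -> {'query': 'g4f'}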
async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls: Optional[list] = None, **kwargs):
if tool_calls is not None:
for tool in tool_calls:
if tool.get("type") == "function":
if tool.get("function", {}).get("name") == "search_tool":
tool["function"]["arguments"] = validate_arguments(tool["function"])
messages = messages.copy()
messages[-1]["content"] = await do_search(
messages[-1]["content"],
**tool["function"]["arguments"]
)
elif tool.get("function", {}).get("name") == "continue":
last_line = messages[-1]["content"].strip().splitlines()[-1]
content = f"Continue writing the story after this line start with a plus sign if you begin a new word.\n{last_line}"
messages.append({"role": "user", "content": content})
response = async_iter_callback(model=model, messages=messages, **kwargs)
if not hasattr(response, "__aiter__"):
response = to_async_iterator(response)
async for chunk in response:
yield chunk
def iter_run_tools(
iter_callback: Callable,
model: str,
messages: Messages,
provider: Optional[str] = None,
tool_calls: Optional[list] = None,
**kwargs
) -> AsyncIterator:
if tool_calls is not None:
for tool in tool_calls:
if tool.get("type") == "function":
if tool.get("function", {}).get("name") == "search_tool":
tool["function"]["arguments"] = validate_arguments(tool["function"])
messages[-1]["content"] = get_search_message(
messages[-1]["content"],
raise_search_exceptions=True,
**tool["function"]["arguments"]
)
elif tool.get("function", {}).get("name") == "safe_search_tool":
tool["function"]["arguments"] = validate_arguments(tool["function"])
try:
messages[-1]["content"] = asyncio.run(do_search(messages[-1]["content"], **tool["function"]["arguments"]))
except Exception as e:
debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
# Enable provider native web search
kwargs["web_search"] = True
elif tool.get("function", {}).get("name") == "continue_tool":
if provider not in ("OpenaiAccount", "HuggingFace"):
last_line = messages[-1]["content"].strip().splitlines()[-1]
content = f"continue after this line:\n{last_line}"
messages.append({"role": "user", "content": content})
else:
# Enable provider native continue
if "action" not in kwargs:
kwargs["action"] = "continue"
elif tool.get("function", {}).get("name") == "bucket_tool":
def on_bucket(match):
return "".join(read_bucket(get_bucket_dir(match.group(1))))
messages[-1]["content"] = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, messages[-1]["content"])
return iter_callback(model=model, messages=messages, provider=provider, **kwargs)
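A hypothetical invocation wiring a search_tool call through iter_run_tools (the callback, model name, messages, and arguments are placeholders):
tool_calls = [{
"type": "function",
"function": {"name": "search_tool", "arguments": {"query": "latest g4f release", "max_results": 3}},
}]
response = iter_run_tools(my_callback, "gpt-4o-mini", messages, tool_calls=tool_calls)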

g4f/tools/web_search.py
View file

@ -1,6 +1,10 @@
from __future__ import annotations
from aiohttp import ClientSession, ClientTimeout, ClientError
import json
import hashlib
from pathlib import Path
try:
from duckduckgo_search import DDGS
from duckduckgo_search.exceptions import DuckDuckGoSearchException
@ -8,8 +12,15 @@ try:
has_requirements = True
except ImportError:
has_requirements = False
from .errors import MissingRequirementsError
from . import debug
try:
import spacy
has_spacy = True
except ImportError:
has_spacy = False
from typing import Iterator
from ..cookies import get_cookies_dir
from ..errors import MissingRequirementsError
from .. import debug
import asyncio
@ -52,7 +63,7 @@ class SearchResultEntry():
def set_text(self, text: str):
self.text = text
def scrape_text(html: str, max_words: int = None) -> str:
def scrape_text(html: str, max_words: int = None) -> Iterator[str]:
soup = BeautifulSoup(html, "html.parser")
for selector in [
"main",
@ -72,14 +83,10 @@ def scrape_text(html: str, max_words: int = None) -> str:
select = soup.select_one(remove)
if select:
select.extract()
clean_text = ""
for paragraph in soup.select("p, h1, h2, h3, h4, h5, h6"):
text = paragraph.get_text()
for line in text.splitlines():
words = []
for word in line.replace("\t", " ").split(" "):
if word:
words.append(word)
for paragraph in soup.select("p, table, ul, h1, h2, h3, h4, h5, h6"):
for line in paragraph.text.splitlines():
words = [word for word in line.replace("\t", " ").split(" ") if word]
count = len(words)
if not count:
continue
@ -87,18 +94,23 @@ def scrape_text(html: str, max_words: int = None) -> str:
max_words -= count
if max_words <= 0:
break
if clean_text:
clean_text += "\n"
clean_text += " ".join(words)
return clean_text
yield " ".join(words) + "\n"
async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = None) -> str:
try:
bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "fetch_and_scrape"
bucket_dir.mkdir(parents=True, exist_ok=True)
md5_hash = hashlib.md5(url.encode()).hexdigest()
cache_file = bucket_dir / f"{url.split('/')[3]}.{md5_hash}.txt"
if cache_file.exists():
return cache_file.read_text()
async with session.get(url) as response:
if response.status == 200:
html = await response.text()
return scrape_text(html, max_words)
text = "".join(scrape_text(html, max_words))
with open(cache_file, "w") as f:
f.write(text)
return text
except ClientError:
return
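A standalone sketch of the cache path computed above (hostname plus an md5 of the full URL; the URL is a placeholder):
import hashlib
url = "https://example.com/docs/guide"
md5_hash = hashlib.md5(url.encode()).hexdigest()
print(f"{url.split('/')[2]}.{md5_hash}.txt")  # example.com.<32 hex chars>.txt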
@ -113,7 +125,7 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backen
safesearch="moderate",
timelimit="y",
max_results=max_results,
backend=backend, # Changed from 'api' to 'auto'
backend=backend,
):
results.append(SearchResultEntry(
result["title"],
@ -149,8 +161,20 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backen
async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_INSTRUCTIONS, **kwargs) -> str:
if query is None:
query = prompt
search_results = await search(query, **kwargs)
query = spacy_get_keywords(prompt)
json_bytes = json.dumps({"query": query, **kwargs}, sort_keys=True).encode()
md5_hash = hashlib.md5(json_bytes).hexdigest()
bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "web_search"
bucket_dir.mkdir(parents=True, exist_ok=True)
cache_file = bucket_dir / f"{query[:20]}.{md5_hash}.txt"
if cache_file.exists():
with open(cache_file, "r") as f:
search_results = f.read()
else:
search_results = await search(query, **kwargs)
with open(cache_file, "w") as f:
f.write(str(search_results))
new_prompt = f"""
{search_results}
@ -170,3 +194,37 @@ def get_search_message(prompt: str, raise_search_exceptions=False, **kwargs) ->
raise e
debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
return prompt
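do_search keys its cache on the query plus the search kwargs; a standalone sketch of the key derivation used above (values are placeholders):
import hashlib, json
query, kwargs = "g4f file support", {"max_results": 5}
md5_hash = hashlib.md5(json.dumps({"query": query, **kwargs}, sort_keys=True).encode()).hexdigest()
print(f"{query[:20]}.{md5_hash}.txt")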
def spacy_get_keywords(text: str):
if not has_spacy:
return text
# Load the spaCy language model and process the query
nlp = spacy.load("en_core_web_sm")
doc = nlp(text)
# Keep noun phrases whose head word is not a stop word
keywords = [chunk.text for chunk in doc.noun_chunks if not chunk.root.is_stop]
# Return a plain string: the result is used as a search query and in a cache filename
return " ".join(keywords)
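A quick check (assumes the en_core_web_sm model has been downloaded; exact output varies with the spaCy version):
print(spacy_get_keywords("How do I split a large text file into chunks in Python?"))
# e.g. "a large text file chunks Python"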

g4f/webdriver.py
View file

@ -1,257 +0,0 @@
from __future__ import annotations
try:
from platformdirs import user_config_dir
from undetected_chromedriver import Chrome, ChromeOptions, find_chrome_executable
from selenium.webdriver.remote.webdriver import WebDriver
from selenium.webdriver.remote.webelement import WebElement
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.common.keys import Keys
from selenium.common.exceptions import NoSuchElementException
has_requirements = True
except ImportError:
from typing import Type as WebDriver
has_requirements = False
import time
from shutil import which
from os import path
from os import access, R_OK
from .typing import Cookies
from .errors import MissingRequirementsError
from . import debug
try:
from pyvirtualdisplay import Display
has_pyvirtualdisplay = True
except ImportError:
has_pyvirtualdisplay = False
try:
from undetected_chromedriver import Chrome as _Chrome, ChromeOptions
from seleniumwire.webdriver import InspectRequestsMixin, DriverCommonMixin
class Chrome(InspectRequestsMixin, DriverCommonMixin, _Chrome):
def __init__(self, *args, options=None, seleniumwire_options={}, **kwargs):
if options is None:
options = ChromeOptions()
config = self._setup_backend(seleniumwire_options)
options.add_argument(f"--proxy-server={config['proxy']['httpProxy']}")
options.add_argument("--proxy-bypass-list=<-loopback>")
options.add_argument("--ignore-certificate-errors")
super().__init__(*args, options=options, **kwargs)
has_seleniumwire = True
except:
has_seleniumwire = False
def get_browser(
user_data_dir: str = None,
headless: bool = False,
proxy: str = None,
options: ChromeOptions = None
) -> WebDriver:
"""
Creates and returns a Chrome WebDriver with specified options.
Args:
user_data_dir (str, optional): Directory for user data. If None, uses default directory.
headless (bool, optional): Whether to run the browser in headless mode. Defaults to False.
proxy (str, optional): Proxy settings for the browser. Defaults to None.
options (ChromeOptions, optional): ChromeOptions object with specific browser options. Defaults to None.
Returns:
WebDriver: An instance of WebDriver configured with the specified options.
"""
if not has_requirements:
raise MissingRequirementsError('Install Webdriver packages | pip install -U g4f[webdriver]')
browser = find_chrome_executable()
if browser is None:
raise MissingRequirementsError('Install "Google Chrome" browser')
if user_data_dir is None:
user_data_dir = user_config_dir("g4f")
if user_data_dir and debug.logging:
print("Open browser with config dir:", user_data_dir)
if not options:
options = ChromeOptions()
if proxy:
options.add_argument(f'--proxy-server={proxy}')
# Check for system driver in docker
driver = which('chromedriver') or '/usr/bin/chromedriver'
if not path.isfile(driver) or not access(driver, R_OK):
driver = None
return Chrome(
options=options,
user_data_dir=user_data_dir,
driver_executable_path=driver,
browser_executable_path=browser,
headless=headless,
patcher_force_close=True
)
def get_driver_cookies(driver: WebDriver) -> Cookies:
"""
Retrieves cookies from the specified WebDriver.
Args:
driver (WebDriver): The WebDriver instance from which to retrieve cookies.
Returns:
dict: A dictionary containing cookies with their names as keys and values as cookie values.
"""
return {cookie["name"]: cookie["value"] for cookie in driver.get_cookies()}
def bypass_cloudflare(driver: WebDriver, url: str, timeout: int) -> None:
"""
Attempts to bypass Cloudflare protection when accessing a URL using the provided WebDriver.
Args:
driver (WebDriver): The WebDriver to use for accessing the URL.
url (str): The URL to access.
timeout (int): Time in seconds to wait for the page to load.
Raises:
Exception: If there is an error while bypassing Cloudflare or loading the page.
"""
driver.get(url)
if driver.find_element(By.TAG_NAME, "body").get_attribute("class") == "no-js":
if debug.logging:
print("Cloudflare protection detected:", url)
# Open website in a new tab
element = driver.find_element(By.ID, "challenge-body-text")
driver.execute_script(f"""
arguments[0].addEventListener('click', () => {{
window.open(arguments[1]);
}});
""", element, url)
element.click()
time.sleep(5)
# Switch to the new tab and close the old tab
original_window = driver.current_window_handle
for window_handle in driver.window_handles:
if window_handle != original_window:
driver.close()
driver.switch_to.window(window_handle)
break
# Click on the challenge button in the iframe
try:
driver.switch_to.frame(driver.find_element(By.CSS_SELECTOR, "#turnstile-wrapper iframe"))
WebDriverWait(driver, 5).until(
EC.presence_of_element_located((By.CSS_SELECTOR, "#challenge-stage input"))
).click()
except NoSuchElementException:
...
except Exception as e:
if debug.logging:
print(f"Error bypassing Cloudflare: {str(e).splitlines()[0]}")
#driver.switch_to.default_content()
driver.switch_to.window(window_handle)
driver.execute_script("document.href = document.href;")
WebDriverWait(driver, timeout).until(
EC.presence_of_element_located((By.CSS_SELECTOR, "body:not(.no-js)"))
)
class WebDriverSession:
"""
Manages a Selenium WebDriver session, including handling of virtual displays and proxies.
"""
def __init__(
self,
webdriver: WebDriver = None,
user_data_dir: str = None,
headless: bool = False,
virtual_display: bool = False,
proxy: str = None,
options: ChromeOptions = None
):
"""
Initializes a new instance of the WebDriverSession.
Args:
webdriver (WebDriver, optional): A WebDriver instance for the session. Defaults to None.
user_data_dir (str, optional): Directory for user data. Defaults to None.
headless (bool, optional): Whether to run the browser in headless mode. Defaults to False.
virtual_display (bool, optional): Whether to use a virtual display. Defaults to False.
proxy (str, optional): Proxy settings for the browser. Defaults to None.
options (ChromeOptions, optional): ChromeOptions for the browser. Defaults to None.
"""
self.webdriver = webdriver
self.user_data_dir = user_data_dir
self.headless = headless
self.virtual_display = Display(size=(1920, 1080)) if has_pyvirtualdisplay and virtual_display else None
self.proxy = proxy
self.options = options
self.default_driver = None
def reopen(
self,
user_data_dir: str = None,
headless: bool = False,
virtual_display: bool = False
) -> WebDriver:
"""
Reopens the WebDriver session with new settings.
Args:
user_data_dir (str, optional): Directory for user data. Defaults to current value.
headless (bool, optional): Whether to run the browser in headless mode. Defaults to current value.
virtual_display (bool, optional): Whether to use a virtual display. Defaults to current value.
Returns:
WebDriver: The reopened WebDriver instance.
"""
user_data_dir = user_data_dir or self.user_data_dir
if self.default_driver:
self.default_driver.quit()
if not virtual_display and self.virtual_display:
self.virtual_display.stop()
self.virtual_display = None
self.default_driver = get_browser(user_data_dir, headless, self.proxy)
return self.default_driver
def __enter__(self) -> WebDriver:
"""
Context management method for entering a session. Initializes and returns a WebDriver instance.
Returns:
WebDriver: An instance of WebDriver for this session.
"""
if self.webdriver:
return self.webdriver
if self.virtual_display:
self.virtual_display.start()
self.default_driver = get_browser(self.user_data_dir, self.headless, self.proxy, self.options)
return self.default_driver
def __exit__(self, exc_type, exc_val, exc_tb):
"""
Context management method for exiting a session. Closes and quits the WebDriver.
Args:
exc_type: Exception type.
exc_val: Exception value.
exc_tb: Exception traceback.
Note:
Closes the WebDriver and stops the virtual display if used.
"""
if self.default_driver:
try:
self.default_driver.close()
except Exception as e:
if debug.logging:
print(f"Error closing WebDriver: {str(e).splitlines()[0]}")
finally:
self.default_driver.quit()
if self.virtual_display:
self.virtual_display.stop()
def element_send_text(element: WebElement, text: str) -> None:
script = "arguments[0].innerText = arguments[1];"
element.parent.execute_script(script, element, text)
element.send_keys(Keys.ENTER)

setup.py
View file

@ -84,6 +84,16 @@ EXTRA_REQUIRE = {
],
"local": [
"gpt4all"
],
"files": [
"spacy",
"filesplit",
"beautifulsoup4",
"pypdf2",
"docx",
"odfpy",
"ebooklib",
"openpyxl",
]
}
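The new extra can then be pulled in alongside the base package; note the spaCy model used for keyword extraction is a separate download:
pip install -U "g4f[files]"
python -m spacy download en_core_web_sm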