Change doctypes style to Google

Fix typo in latest_version
Fix Phind Provider
Add unittest worklow and main tests
This commit is contained in:
Heiner Lohaus 2024-01-14 15:04:37 +01:00
parent 5756586cde
commit 32252def15
10 changed files with 477 additions and 103 deletions

19
.github/workflows/unittest.yml vendored Normal file
View file

@ -0,0 +1,19 @@
name: Unittest
on: [push]
jobs:
build:
name: Build unittest
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.x"
cache: 'pip'
- name: Install requirements
- run: pip install -r requirements.txt
- name: Run tests
run: python -m etc.unittest.main

73
etc/unittest/main.py Normal file
View file

@ -0,0 +1,73 @@
import sys
import pathlib
import unittest
from unittest.mock import MagicMock
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
import g4f
from g4f import ChatCompletion, get_last_provider
from g4f.gui.server.backend import Backend_Api, get_error_message
from g4f.base_provider import BaseProvider
g4f.debug.logging = False
class MockProvider(BaseProvider):
working = True
def create_completion(
model, messages, stream, **kwargs
):
yield "Mock"
async def create_async(
model, messages, **kwargs
):
return "Mock"
class TestBackendApi(unittest.TestCase):
def setUp(self):
self.app = MagicMock()
self.api = Backend_Api(self.app)
def test_version(self):
response = self.api.get_version()
self.assertIn("version", response)
self.assertIn("latest_version", response)
class TestChatCompletion(unittest.TestCase):
def test_create(self):
messages = [{'role': 'user', 'content': 'Hello'}]
result = ChatCompletion.create(g4f.models.default, messages)
self.assertTrue("Hello" in result or "Good" in result)
def test_get_last_provider(self):
messages = [{'role': 'user', 'content': 'Hello'}]
ChatCompletion.create(g4f.models.default, messages, MockProvider)
self.assertEqual(get_last_provider(), MockProvider)
def test_bing_provider(self):
messages = [{'role': 'user', 'content': 'Hello'}]
provider = g4f.Provider.Bing
result = ChatCompletion.create(g4f.models.default, messages, provider)
self.assertTrue("Bing" in result)
class TestChatCompletionAsync(unittest.IsolatedAsyncioTestCase):
async def test_async(self):
messages = [{'role': 'user', 'content': 'Hello'}]
result = await ChatCompletion.create_async(g4f.models.default, messages, MockProvider)
self.assertTrue("Mock" in result)
class TestUtilityFunctions(unittest.TestCase):
def test_get_error_message(self):
g4f.debug.last_provider = g4f.Provider.Bing
exception = Exception("Message")
result = get_error_message(exception)
self.assertEqual("Bing: Exception: Message", result)
if __name__ == '__main__':
unittest.main()

View file

@ -59,12 +59,16 @@ class Phind(AsyncGeneratorProvider):
"rewrittenQuestion": prompt, "rewrittenQuestion": prompt,
"challenge": 0.21132115912208504 "challenge": 0.21132115912208504
} }
async with session.post(f"{cls.url}/api/infer/followup/answer", headers=headers, json=data) as response: async with session.post(f"https://https.api.phind.com/infer/", headers=headers, json=data) as response:
new_line = False new_line = False
async for line in response.iter_lines(): async for line in response.iter_lines():
if line.startswith(b"data: "): if line.startswith(b"data: "):
chunk = line[6:] chunk = line[6:]
if chunk.startswith(b"<PHIND_METADATA>") or chunk.startswith(b"<PHIND_INDICATOR>"): if chunk.startswith(b'<PHIND_DONE/>'):
break
if chunk.startswith(b'<PHIND_WEBRESULTS>') or chunk.startswith(b'<PHIND_FOLLOWUP>'):
pass
elif chunk.startswith(b"<PHIND_METADATA>") or chunk.startswith(b"<PHIND_INDICATOR>"):
pass pass
elif chunk: elif chunk:
yield chunk.decode() yield chunk.decode()

View file

@ -36,6 +36,17 @@ class AbstractProvider(BaseProvider):
) -> str: ) -> str:
""" """
Asynchronously creates a result based on the given model and messages. Asynchronously creates a result based on the given model and messages.
Args:
cls (type): The class on which this method is called.
model (str): The model to use for creation.
messages (Messages): The messages to process.
loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
str: The created result as a string.
""" """
loop = loop or get_event_loop() loop = loop or get_event_loop()
@ -52,6 +63,12 @@ class AbstractProvider(BaseProvider):
def params(cls) -> str: def params(cls) -> str:
""" """
Returns the parameters supported by the provider. Returns the parameters supported by the provider.
Args:
cls (type): The class on which this property is called.
Returns:
str: A string listing the supported parameters.
""" """
sig = signature( sig = signature(
cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
@ -90,6 +107,17 @@ class AsyncProvider(AbstractProvider):
) -> CreateResult: ) -> CreateResult:
""" """
Creates a completion result synchronously. Creates a completion result synchronously.
Args:
cls (type): The class on which this method is called.
model (str): The model to use for creation.
messages (Messages): The messages to process.
stream (bool): Indicates whether to stream the results. Defaults to False.
loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
CreateResult: The result of the completion creation.
""" """
loop = loop or get_event_loop() loop = loop or get_event_loop()
coro = cls.create_async(model, messages, **kwargs) coro = cls.create_async(model, messages, **kwargs)
@ -104,6 +132,17 @@ class AsyncProvider(AbstractProvider):
) -> str: ) -> str:
""" """
Abstract method for creating asynchronous results. Abstract method for creating asynchronous results.
Args:
model (str): The model to use for creation.
messages (Messages): The messages to process.
**kwargs: Additional keyword arguments.
Raises:
NotImplementedError: If this method is not overridden in derived classes.
Returns:
str: The created result as a string.
""" """
raise NotImplementedError() raise NotImplementedError()
@ -126,6 +165,17 @@ class AsyncGeneratorProvider(AsyncProvider):
) -> CreateResult: ) -> CreateResult:
""" """
Creates a streaming completion result synchronously. Creates a streaming completion result synchronously.
Args:
cls (type): The class on which this method is called.
model (str): The model to use for creation.
messages (Messages): The messages to process.
stream (bool): Indicates whether to stream the results. Defaults to True.
loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
CreateResult: The result of the streaming completion creation.
""" """
loop = loop or get_event_loop() loop = loop or get_event_loop()
generator = cls.create_async_generator(model, messages, stream=stream, **kwargs) generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
@ -146,6 +196,15 @@ class AsyncGeneratorProvider(AsyncProvider):
) -> str: ) -> str:
""" """
Asynchronously creates a result from a generator. Asynchronously creates a result from a generator.
Args:
cls (type): The class on which this method is called.
model (str): The model to use for creation.
messages (Messages): The messages to process.
**kwargs: Additional keyword arguments.
Returns:
str: The created result as a string.
""" """
return "".join([ return "".join([
chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs) chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
@ -162,5 +221,17 @@ class AsyncGeneratorProvider(AsyncProvider):
) -> AsyncResult: ) -> AsyncResult:
""" """
Abstract method for creating an asynchronous generator. Abstract method for creating an asynchronous generator.
Args:
model (str): The model to use for creation.
messages (Messages): The messages to process.
stream (bool): Indicates whether to stream the results. Defaults to True.
**kwargs: Additional keyword arguments.
Raises:
NotImplementedError: If this method is not overridden in derived classes.
Returns:
AsyncResult: An asynchronous generator yielding results.
""" """
raise NotImplementedError() raise NotImplementedError()

View file

@ -198,7 +198,7 @@ class CreateImagesBing:
_cookies: Dict[str, str] = {} _cookies: Dict[str, str] = {}
@classmethod @classmethod
def create_completion(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> Generator[str]: def create_completion(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> Generator[str, None, None]:
""" """
Generator for creating imagecompletion based on a prompt. Generator for creating imagecompletion based on a prompt.

View file

@ -8,13 +8,31 @@ from ..base_provider import BaseProvider, ProviderType
system_message = """ system_message = """
You can generate custom images with the DALL-E 3 image generator. You can generate custom images with the DALL-E 3 image generator.
To generate a image with a prompt, do this: To generate an image with a prompt, do this:
<img data-prompt=\"keywords for the image\"> <img data-prompt=\"keywords for the image\">
Don't use images with data uri. It is important to use a prompt instead. Don't use images with data uri. It is important to use a prompt instead.
<img data-prompt=\"image caption\"> <img data-prompt=\"image caption\">
""" """
class CreateImagesProvider(BaseProvider): class CreateImagesProvider(BaseProvider):
"""
Provider class for creating images based on text prompts.
This provider handles image creation requests embedded within message content,
using provided image creation functions.
Attributes:
provider (ProviderType): The underlying provider to handle non-image related tasks.
create_images (callable): A function to create images synchronously.
create_images_async (callable): A function to create images asynchronously.
system_message (str): A message that explains the image creation capability.
include_placeholder (bool): Flag to determine whether to include the image placeholder in the output.
__name__ (str): Name of the provider.
url (str): URL of the provider.
working (bool): Indicates if the provider is operational.
supports_stream (bool): Indicates if the provider supports streaming.
"""
def __init__( def __init__(
self, self,
provider: ProviderType, provider: ProviderType,
@ -23,6 +41,16 @@ class CreateImagesProvider(BaseProvider):
system_message: str = system_message, system_message: str = system_message,
include_placeholder: bool = True include_placeholder: bool = True
) -> None: ) -> None:
"""
Initializes the CreateImagesProvider.
Args:
provider (ProviderType): The underlying provider.
create_images (callable): Function to create images synchronously.
create_async (callable): Function to create images asynchronously.
system_message (str, optional): System message to be prefixed to messages. Defaults to a predefined message.
include_placeholder (bool, optional): Whether to include image placeholders in the output. Defaults to True.
"""
self.provider = provider self.provider = provider
self.create_images = create_images self.create_images = create_images
self.create_images_async = create_async self.create_images_async = create_async
@ -40,6 +68,22 @@ class CreateImagesProvider(BaseProvider):
stream: bool = False, stream: bool = False,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
"""
Creates a completion result, processing any image creation prompts found within the messages.
Args:
model (str): The model to use for creation.
messages (Messages): The messages to process, which may contain image prompts.
stream (bool, optional): Indicates whether to stream the results. Defaults to False.
**kwargs: Additional keywordarguments for the provider.
Yields:
CreateResult: Yields chunks of the processed messages, including image data if applicable.
Note:
This method processes messages to detect image creation prompts. When such a prompt is found,
it calls the synchronous image creation function and includes the resulting image in the output.
"""
messages.insert(0, {"role": "system", "content": self.system_message}) messages.insert(0, {"role": "system", "content": self.system_message})
buffer = "" buffer = ""
for chunk in self.provider.create_completion(model, messages, stream, **kwargs): for chunk in self.provider.create_completion(model, messages, stream, **kwargs):
@ -71,6 +115,21 @@ class CreateImagesProvider(BaseProvider):
messages: Messages, messages: Messages,
**kwargs **kwargs
) -> str: ) -> str:
"""
Asynchronously creates a response, processing any image creation prompts found within the messages.
Args:
model (str): The model to use for creation.
messages (Messages): The messages to process, which may contain image prompts.
**kwargs: Additional keyword arguments for the provider.
Returns:
str: The processed response string, including asynchronously generated image data if applicable.
Note:
This method processes messages to detect image creation prompts. When such a prompt is found,
it calls the asynchronous image creation function and includes the resulting image in the output.
"""
messages.insert(0, {"role": "system", "content": self.system_message}) messages.insert(0, {"role": "system", "content": self.system_message})
response = await self.provider.create_async(model, messages, **kwargs) response = await self.provider.create_async(model, messages, **kwargs)
matches = re.findall(r'(<img data-prompt="(.*?)">)', response) matches = re.findall(r'(<img data-prompt="(.*?)">)', response)

View file

@ -652,9 +652,9 @@ observer.observe(message_input, { attributes: true });
document.title = 'g4f - gui - ' + versions["version"]; document.title = 'g4f - gui - ' + versions["version"];
text = "version ~ " text = "version ~ "
if (versions["version"] != versions["lastet_version"]) { if (versions["version"] != versions["latest_version"]) {
release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["lastet_version"]; release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["latest_version"];
text += '<a href="' + release_url +'" target="_blank" title="New version: ' + versions["lastet_version"] +'">' + versions["version"] + ' 🆕</a>'; text += '<a href="' + release_url +'" target="_blank" title="New version: ' + versions["latest_version"] +'">' + versions["version"] + ' 🆕</a>';
} else { } else {
text += versions["version"]; text += versions["version"];
} }

View file

@ -1,6 +1,7 @@
import logging import logging
import json import json
from flask import request, Flask from flask import request, Flask
from typing import Generator
from g4f import debug, version, models from g4f import debug, version, models
from g4f import _all_models, get_last_provider, ChatCompletion from g4f import _all_models, get_last_provider, ChatCompletion
from g4f.image import is_allowed_extension, to_image from g4f.image import is_allowed_extension, to_image
@ -11,60 +12,123 @@ from .internet import get_search_message
debug.logging = True debug.logging = True
class Backend_Api: class Backend_Api:
"""
Handles various endpoints in a Flask application for backend operations.
This class provides methods to interact with models, providers, and to handle
various functionalities like conversations, error handling, and version management.
Attributes:
app (Flask): A Flask application instance.
routes (dict): A dictionary mapping API endpoints to their respective handlers.
"""
def __init__(self, app: Flask) -> None: def __init__(self, app: Flask) -> None:
"""
Initialize the backend API with the given Flask application.
Args:
app (Flask): Flask application instance to attach routes to.
"""
self.app: Flask = app self.app: Flask = app
self.routes = { self.routes = {
'/backend-api/v2/models': { '/backend-api/v2/models': {
'function': self.models, 'function': self.get_models,
'methods' : ['GET'] 'methods': ['GET']
}, },
'/backend-api/v2/providers': { '/backend-api/v2/providers': {
'function': self.providers, 'function': self.get_providers,
'methods' : ['GET'] 'methods': ['GET']
}, },
'/backend-api/v2/version': { '/backend-api/v2/version': {
'function': self.version, 'function': self.get_version,
'methods' : ['GET'] 'methods': ['GET']
}, },
'/backend-api/v2/conversation': { '/backend-api/v2/conversation': {
'function': self._conversation, 'function': self.handle_conversation,
'methods': ['POST'] 'methods': ['POST']
}, },
'/backend-api/v2/gen.set.summarize:title': { '/backend-api/v2/gen.set.summarize:title': {
'function': self._gen_title, 'function': self.generate_title,
'methods': ['POST'] 'methods': ['POST']
}, },
'/backend-api/v2/error': { '/backend-api/v2/error': {
'function': self.error, 'function': self.handle_error,
'methods': ['POST'] 'methods': ['POST']
} }
} }
def error(self): def handle_error(self):
"""
Initialize the backend API with the given Flask application.
Args:
app (Flask): Flask application instance to attach routes to.
"""
print(request.json) print(request.json)
return 'ok', 200 return 'ok', 200
def models(self): def get_models(self):
"""
Return a list of all models.
Fetches and returns a list of all available models in the system.
Returns:
List[str]: A list of model names.
"""
return _all_models return _all_models
def providers(self): def get_providers(self):
return [ """
provider.__name__ for provider in __providers__ if provider.working Return a list of all working providers.
] """
return [provider.__name__ for provider in __providers__ if provider.working]
def version(self): def get_version(self):
"""
Returns the current and latest version of the application.
Returns:
dict: A dictionary containing the current and latest version.
"""
return { return {
"version": version.utils.current_version, "version": version.utils.current_version,
"lastet_version": version.get_latest_version(), "latest_version": version.get_latest_version(),
} }
def _gen_title(self): def generate_title(self):
return { """
'title': '' Generates and returns a title based on the request data.
}
Returns:
dict: A dictionary with the generated title.
"""
return {'title': ''}
def _conversation(self): def handle_conversation(self):
"""
Handles conversation requests and streams responses back.
Returns:
Response: A Flask response object for streaming.
"""
kwargs = self._prepare_conversation_kwargs()
return self.app.response_class(
self._create_response_stream(kwargs),
mimetype='text/event-stream'
)
def _prepare_conversation_kwargs(self):
"""
Prepares arguments for chat completion based on the request data.
Reads the request and prepares the necessary arguments for handling
a chat completion request.
Returns:
dict: Arguments prepared for chat completion.
"""
kwargs = {} kwargs = {}
if 'image' in request.files: if 'image' in request.files:
file = request.files['image'] file = request.files['image']
@ -87,47 +151,70 @@ class Backend_Api:
messages[-1]["content"] = get_search_message(messages[-1]["content"]) messages[-1]["content"] = get_search_message(messages[-1]["content"])
model = json_data.get('model') model = json_data.get('model')
model = model if model else models.default model = model if model else models.default
provider = json_data.get('provider', '').replace('g4f.Provider.', '')
provider = provider if provider and provider != "Auto" else None
patch = patch_provider if json_data.get('patch_provider') else None patch = patch_provider if json_data.get('patch_provider') else None
def try_response(): return {
try: "model": model,
first = True "provider": provider,
for chunk in ChatCompletion.create( "messages": messages,
model=model, "stream": True,
provider=provider, "ignore_stream_and_auth": True,
messages=messages, "patch_provider": patch,
stream=True, **kwargs
ignore_stream_and_auth=True, }
patch_provider=patch,
**kwargs
):
if first:
first = False
yield json.dumps({
'type' : 'provider',
'provider': get_last_provider(True)
}) + "\n"
if isinstance(chunk, Exception):
logging.exception(chunk)
yield json.dumps({
'type' : 'message',
'message': get_error_message(chunk),
}) + "\n"
else:
yield json.dumps({
'type' : 'content',
'content': str(chunk),
}) + "\n"
except Exception as e:
logging.exception(e)
yield json.dumps({
'type' : 'error',
'error': get_error_message(e)
})
return self.app.response_class(try_response(), mimetype='text/event-stream') def _create_response_stream(self, kwargs) -> Generator[str, None, None]:
"""
Creates and returns a streaming response for the conversation.
Args:
kwargs (dict): Arguments for creating the chat completion.
Yields:
str: JSON formatted response chunks for the stream.
Raises:
Exception: If an error occurs during the streaming process.
"""
try:
first = True
for chunk in ChatCompletion.create(**kwargs):
if first:
first = False
yield self._format_json('provider', get_last_provider(True))
if isinstance(chunk, Exception):
logging.exception(chunk)
yield self._format_json('message', get_error_message(chunk))
else:
yield self._format_json('content', str(chunk))
except Exception as e:
logging.exception(e)
yield self._format_json('error', get_error_message(e))
def _format_json(self, response_type: str, content) -> str:
"""
Formats and returns a JSON response.
Args:
response_type (str): The type of the response.
content: The content to be included in the response.
Returns:
str: A JSON formatted string.
"""
return json.dumps({
'type': response_type,
response_type: content
}) + "\n"
def get_error_message(exception: Exception) -> str: def get_error_message(exception: Exception) -> str:
"""
Generates a formatted error message from an exception.
Args:
exception (Exception): The exception to format.
Returns:
str: A formatted error message string.
"""
return f"{get_last_provider().__name__}: {type(exception).__name__}: {exception}" return f"{get_last_provider().__name__}: {type(exception).__name__}: {exception}"

View file

@ -7,10 +7,16 @@ from .errors import VersionNotFoundError
def get_pypi_version(package_name: str) -> str: def get_pypi_version(package_name: str) -> str:
""" """
Get the latest version of a package from PyPI. Retrieves the latest version of a package from PyPI.
:param package_name: The name of the package. Args:
:return: The latest version of the package as a string. package_name (str): The name of the package for which to retrieve the version.
Returns:
str: The latest version of the specified package from PyPI.
Raises:
VersionNotFoundError: If there is an error in fetching the version from PyPI.
""" """
try: try:
response = requests.get(f"https://pypi.org/pypi/{package_name}/json").json() response = requests.get(f"https://pypi.org/pypi/{package_name}/json").json()
@ -20,10 +26,16 @@ def get_pypi_version(package_name: str) -> str:
def get_github_version(repo: str) -> str: def get_github_version(repo: str) -> str:
""" """
Get the latest release version from a GitHub repository. Retrieves the latest release version from a GitHub repository.
:param repo: The name of the GitHub repository. Args:
:return: The latest release version as a string. repo (str): The name of the GitHub repository.
Returns:
str: The latest release version from the specified GitHub repository.
Raises:
VersionNotFoundError: If there is an error in fetching the version from GitHub.
""" """
try: try:
response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest").json() response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest").json()
@ -31,11 +43,16 @@ def get_github_version(repo: str) -> str:
except requests.RequestException as e: except requests.RequestException as e:
raise VersionNotFoundError(f"Failed to get GitHub release version: {e}") raise VersionNotFoundError(f"Failed to get GitHub release version: {e}")
def get_latest_version(): def get_latest_version() -> str:
""" """
Get the latest release version from PyPI or the GitHub repository. Retrieves the latest release version of the 'g4f' package from PyPI or GitHub.
:return: The latest release version as a string. Returns:
str: The latest release version of 'g4f'.
Note:
The function first tries to fetch the version from PyPI. If the package is not found,
it retrieves the version from the GitHub repository.
""" """
try: try:
# Is installed via package manager? # Is installed via package manager?
@ -47,14 +64,19 @@ def get_latest_version():
class VersionUtils: class VersionUtils:
""" """
Utility class for managing and comparing package versions. Utility class for managing and comparing package versions of 'g4f'.
""" """
@cached_property @cached_property
def current_version(self) -> str: def current_version(self) -> str:
""" """
Get the current version of the g4f package. Retrieves the current version of the 'g4f' package.
:return: The current version as a string. Returns:
str: The current version of 'g4f'.
Raises:
VersionNotFoundError: If the version cannot be determined from the package manager,
Docker environment, or git repository.
""" """
# Read from package manager # Read from package manager
try: try:
@ -79,15 +101,19 @@ class VersionUtils:
@cached_property @cached_property
def latest_version(self) -> str: def latest_version(self) -> str:
""" """
Get the latest version of the g4f package. Retrieves the latest version of the 'g4f' package.
:return: The latest version as a string. Returns:
str: The latest version of 'g4f'.
""" """
return get_latest_version() return get_latest_version()
def check_version(self) -> None: def check_version(self) -> None:
""" """
Check if the current version is up to date with the latest version. Checks if the current version of 'g4f' is up to date with the latest version.
Note:
If a newer version is available, it prints a message with the new version and update instructions.
""" """
try: try:
if self.current_version != self.latest_version: if self.current_version != self.latest_version:

View file

@ -21,13 +21,16 @@ def get_browser(
options: ChromeOptions = None options: ChromeOptions = None
) -> WebDriver: ) -> WebDriver:
""" """
Creates and returns a Chrome WebDriver with the specified options. Creates and returns a Chrome WebDriver with specified options.
:param user_data_dir: Directory for user data. If None, uses default directory. Args:
:param headless: Boolean indicating whether to run the browser in headless mode. user_data_dir (str, optional): Directory for user data. If None, uses default directory.
:param proxy: Proxy settings for the browser. headless (bool, optional): Whether to run the browser in headless mode. Defaults to False.
:param options: ChromeOptions object with specific browser options. proxy (str, optional): Proxy settings for the browser. Defaults to None.
:return: An instance of WebDriver. options (ChromeOptions, optional): ChromeOptions object with specific browser options. Defaults to None.
Returns:
WebDriver: An instance of WebDriver configured with the specified options.
""" """
if user_data_dir is None: if user_data_dir is None:
user_data_dir = user_config_dir("g4f") user_data_dir = user_config_dir("g4f")
@ -49,10 +52,13 @@ def get_browser(
def get_driver_cookies(driver: WebDriver) -> dict: def get_driver_cookies(driver: WebDriver) -> dict:
""" """
Retrieves cookies from the given WebDriver. Retrieves cookies from the specified WebDriver.
:param driver: WebDriver from which to retrieve cookies. Args:
:return: A dictionary of cookies. driver (WebDriver): The WebDriver instance from which to retrieve cookies.
Returns:
dict: A dictionary containing cookies with their names as keys and values as cookie values.
""" """
return {cookie["name"]: cookie["value"] for cookie in driver.get_cookies()} return {cookie["name"]: cookie["value"] for cookie in driver.get_cookies()}
@ -60,9 +66,13 @@ def bypass_cloudflare(driver: WebDriver, url: str, timeout: int) -> None:
""" """
Attempts to bypass Cloudflare protection when accessing a URL using the provided WebDriver. Attempts to bypass Cloudflare protection when accessing a URL using the provided WebDriver.
:param driver: The WebDriver to use. Args:
:param url: URL to access. driver (WebDriver): The WebDriver to use for accessing the URL.
:param timeout: Time in seconds to wait for the page to load. url (str): The URL to access.
timeout (int): Time in seconds to wait for the page to load.
Raises:
Exception: If there is an error while bypassing Cloudflare or loading the page.
""" """
driver.get(url) driver.get(url)
if driver.find_element(By.TAG_NAME, "body").get_attribute("class") == "no-js": if driver.find_element(By.TAG_NAME, "body").get_attribute("class") == "no-js":
@ -86,6 +96,7 @@ class WebDriverSession:
""" """
Manages a Selenium WebDriver session, including handling of virtual displays and proxies. Manages a Selenium WebDriver session, including handling of virtual displays and proxies.
""" """
def __init__( def __init__(
self, self,
webdriver: WebDriver = None, webdriver: WebDriver = None,
@ -95,6 +106,17 @@ class WebDriverSession:
proxy: str = None, proxy: str = None,
options: ChromeOptions = None options: ChromeOptions = None
): ):
"""
Initializes a new instance of the WebDriverSession.
Args:
webdriver (WebDriver, optional): A WebDriver instance for the session. Defaults to None.
user_data_dir (str, optional): Directory for user data. Defaults to None.
headless (bool, optional): Whether to run the browser in headless mode. Defaults to False.
virtual_display (bool, optional): Whether to use a virtual display. Defaults to False.
proxy (str, optional): Proxy settings for the browser. Defaults to None.
options (ChromeOptions, optional): ChromeOptions for the browser. Defaults to None.
"""
self.webdriver = webdriver self.webdriver = webdriver
self.user_data_dir = user_data_dir self.user_data_dir = user_data_dir
self.headless = headless self.headless = headless
@ -110,14 +132,17 @@ class WebDriverSession:
virtual_display: bool = False virtual_display: bool = False
) -> WebDriver: ) -> WebDriver:
""" """
Reopens the WebDriver session with the specified parameters. Reopens the WebDriver session with new settings.
:param user_data_dir: Directory for user data. Args:
:param headless: Boolean indicating whether to run the browser in headless mode. user_data_dir (str, optional): Directory for user data. Defaults to current value.
:param virtual_display: Boolean indicating whether to use a virtual display. headless (bool, optional): Whether to run the browser in headless mode. Defaults to current value.
:return: An instance of WebDriver. virtual_display (bool, optional): Whether to use a virtual display. Defaults to current value.
Returns:
WebDriver: The reopened WebDriver instance.
""" """
user_data_dir = user_data_dir or self.user_data_dir user_data_dir = user_data_data_dir or self.user_data_dir
if self.default_driver: if self.default_driver:
self.default_driver.quit() self.default_driver.quit()
if not virtual_display and self.virtual_display: if not virtual_display and self.virtual_display:
@ -128,8 +153,10 @@ class WebDriverSession:
def __enter__(self) -> WebDriver: def __enter__(self) -> WebDriver:
""" """
Context management method for entering a session. Context management method for entering a session. Initializes and returns a WebDriver instance.
:return: An instance of WebDriver.
Returns:
WebDriver: An instance of WebDriver for this session.
""" """
if self.webdriver: if self.webdriver:
return self.webdriver return self.webdriver
@ -141,6 +168,14 @@ class WebDriverSession:
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
""" """
Context management method for exiting a session. Closes and quits the WebDriver. Context management method for exiting a session. Closes and quits the WebDriver.
Args:
exc_type: Exception type.
exc_val: Exception value.
exc_tb: Exception traceback.
Note:
Closes the WebDriver and stops the virtual display if used.
""" """
if self.default_driver: if self.default_driver:
try: try: