Mirror of https://github.com/xtekky/gpt4free.git (synced 2025-12-06 02:30:41 -08:00)
Refactor search and response handling; introduce CachedSearch and DDGS classes for improved web search functionality and response management. Add PlainTextResponse for handling plain text responses. Update requirements and setup for new dependencies.
parent 305f47314f
commit 6b210f44f9
17 changed files with 505 additions and 313 deletions
@@ -12,9 +12,6 @@ try:
     from .needs_auth.mini_max import HailuoAI, MiniMax
 except ImportError as e:
     debug.error("MiniMax providers not loaded:", e)
-
-from .template import OpenaiTemplate, BackendApi
-from .qwen.QwenCode import QwenCode
 try:
     from .not_working import *
 except ImportError as e:

@@ -36,6 +33,8 @@ try:
 except ImportError as e:
     debug.error("Search providers not loaded:", e)

+from .template import OpenaiTemplate, BackendApi
+from .qwen.QwenCode import QwenCode
 from .deprecated.ARTA import ARTA
 from .deprecated.Blackbox import Blackbox
 from .deprecated.DuckDuckGo import DuckDuckGo
@@ -25,7 +25,7 @@ except ImportError:
 from ...typing import AsyncResult, Messages, MediaListType
 from ...requests import StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
 from ...errors import ModelNotFoundError, CloudflareError, MissingAuthError, MissingRequirementsError
-from ...providers.response import FinishReason, Usage, JsonConversation, ImageResponse, Reasoning
+from ...providers.response import FinishReason, Usage, JsonConversation, ImageResponse, Reasoning, PlainTextResponse, JsonRequest
 from ...tools.media import merge_media
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin
 from ..helper import get_last_user_message

@@ -675,6 +675,7 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
             ],
             "modality": "image" if is_image_model else "chat"
         }
+        yield JsonRequest.from_dict(data)
         try:
             async with StreamSession(**args, timeout=timeout) as session:
                 async with session.post(

@@ -686,6 +687,7 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
                     args["cookies"] = merge_cookies(args["cookies"], response)
                     async for chunk in response.iter_lines():
                         line = chunk.decode()
+                        yield PlainTextResponse(line)
                         if line.startswith("af:"):
                             yield JsonConversation(message_ids=[modelAMessageId])
                         elif line.startswith("a0:"):

@@ -693,6 +695,9 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
                             if chunk == "hasArenaError":
                                 raise ModelNotFoundError("LMArena Beta encountered an error: hasArenaError")
                             yield chunk
+                        elif line.startswith("ag:"):
+                            chunk = json.loads(line[3:])
+                            yield Reasoning(chunk)
                         elif line.startswith("a2:"):
                             yield ImageResponse([image.get("image") for image in json.loads(line[3:])], prompt)
                         elif line.startswith("ad:"):
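Illustration (not part of the commit): the hunks above raise every raw stream line as a PlainTextResponse and then decode the payload after the two-letter prefix ("a0:", "ag:", "a2:", ...) with json.loads(line[3:]). The helper name and sample payloads below are hypothetical; only the prefixes and the decoding pattern come from the diff.

import json

def handle_line(line: str):
    # Route one stream line by its prefix, mirroring the dispatch shown above.
    if line.startswith("a0:"):   # plain content chunk
        return ("content", json.loads(line[3:]))
    if line.startswith("ag:"):   # reasoning chunk -> Reasoning(...)
        return ("reasoning", json.loads(line[3:]))
    if line.startswith("a2:"):   # image payloads -> ImageResponse(...)
        return ("images", [item.get("image") for item in json.loads(line[3:])])
    return ("raw", line)         # everything else would stay visible as plain text

print(handle_line('ag:"thinking about the prompt"'))
print(handle_line('a2:[{"image": "https://example.com/cat.png"}]'))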
g4f/Provider/search/CachedSearch.py (new file, 103 lines)

@@ -0,0 +1,103 @@
from __future__ import annotations

import json
import hashlib
from pathlib import Path
from urllib.parse import quote_plus
from datetime import date

from ...typing import AsyncResult, Messages, Optional
from ..base_provider import AsyncGeneratorProvider, AuthFileMixin
from ...cookies import get_cookies_dir
from ..helper import format_media_prompt
from .DDGS import DDGS, SearchResults, SearchResultEntry
from .SearXNG import SearXNG
from ... import debug

async def search(
    query: str,
    max_results: int = 5,
    max_words: int = 2500,
    backend: str = "auto",
    add_text: bool = True,
    timeout: int = 5,
    region: str = "us-en",
    provider: str = "DDG"
) -> SearchResults:
    """
    Performs a web search and returns search results.
    """
    if provider == "SearXNG":
        debug.log(f"[SearXNG] Using local container for query: {query}")
        results_texts = []
        async for chunk in SearXNG.create_async_generator(
            "SearXNG",
            [{"role": "user", "content": query}],
            max_results=max_results,
            max_words=max_words,
            add_text=add_text
        ):
            if isinstance(chunk, str):
                results_texts.append(chunk)
        used_words = sum(text.count(" ") for text in results_texts)
        return SearchResults([
            SearchResultEntry(
                title=f"Result {i + 1}",
                url="",
                snippet=text,
                text=text
            ) for i, text in enumerate(results_texts)
        ], used_words=used_words)

    return await anext(DDGS.create_async_generator(
        provider,
        [],
        prompt=query,
        max_results=max_results,
        max_words=max_words,
        add_text=add_text,
        timeout=timeout,
        region=region,
        backend=backend
    ))

class CachedSearch(AsyncGeneratorProvider, AuthFileMixin):
    working = True

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        prompt: str = None,
        **kwargs
    ) -> AsyncResult:
        """
        Combines search results with the user prompt, using caching for improved efficiency.
        """
        prompt = format_media_prompt(messages, prompt)
        search_parameters = ["max_results", "max_words", "add_text", "timeout", "region"]
        search_parameters = {k: v for k, v in kwargs.items() if k in search_parameters}
        json_bytes = json.dumps({"model": model, "query": prompt, **search_parameters}, sort_keys=True).encode(errors="ignore")
        md5_hash = hashlib.md5(json_bytes).hexdigest()
        cache_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "web_search" / f"{date.today()}"
        cache_dir.mkdir(parents=True, exist_ok=True)
        cache_file = cache_dir / f"{quote_plus(prompt[:20])}.{md5_hash}.cache"

        search_results: Optional[SearchResults] = None
        if cache_file.exists():
            with cache_file.open("r") as f:
                try:
                    search_results = SearchResults.from_dict(json.loads(f.read()))
                except json.JSONDecodeError:
                    search_results = None

        if search_results is None:
            if model:
                search_parameters["provider"] = model
            search_results = await search(prompt, **search_parameters)
            if search_results.results:
                with cache_file.open("w") as f:
                    f.write(json.dumps(search_results.get_dict()))

        yield search_results
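Usage sketch (not part of the commit): search() either takes the SearXNG path or delegates to the DDGS provider, while CachedSearch.create_async_generator wraps it with the per-day on-disk cache. The import path and keyword arguments come from the listing above; the query string is an example and the optional search dependencies must be installed.

import asyncio
from g4f.Provider.search.CachedSearch import CachedSearch, search

async def main():
    # Direct, uncached search through the default "DDG" backend.
    results = await search("open source llm inference", max_results=3, max_words=1000)
    print(len(results), "results,", results.used_words, "words")

    # Cached variant: the query hash picks the cache file under .scrape_cache/web_search/<date>/.
    async for chunk in CachedSearch.create_async_generator("", [], prompt="open source llm inference"):
        print(chunk.get_sources())  # chunk is a SearchResults instance

asyncio.run(main())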
g4f/Provider/search/DDGS.py (new file, 228 lines)

@@ -0,0 +1,228 @@
from __future__ import annotations

import hashlib
import asyncio
from pathlib import Path
from typing import Iterator, List, Optional
from urllib.parse import urlparse, quote_plus
from aiohttp import ClientSession, ClientTimeout, ClientError
from datetime import date
import asyncio

# Optional dependencies using the new 'ddgs' package name
try:
    from ddgs import DDGS as DDGSClient
    from bs4 import BeautifulSoup
    has_requirements = True
except ImportError:
    has_requirements = False

from ...typing import Messages, AsyncResult
from ...cookies import get_cookies_dir
from ...providers.response import format_link, JsonMixin, Sources
from ...errors import MissingRequirementsError
from ...providers.base_provider import AsyncGeneratorProvider
from ..helper import format_media_prompt

def scrape_text(html: str, max_words: Optional[int] = None, add_source: bool = True, count_images: int = 2) -> Iterator[str]:
    """
    Parses the provided HTML and yields text fragments.
    """
    soup = BeautifulSoup(html, "html.parser")
    for selector in [
        "main", ".main-content-wrapper", ".main-content", ".emt-container-inner",
        ".content-wrapper", "#content", "#mainContent",
    ]:
        selected = soup.select_one(selector)
        if selected:
            soup = selected
            break

    for remove_selector in [".c-globalDisclosure"]:
        unwanted = soup.select_one(remove_selector)
        if unwanted:
            unwanted.extract()

    image_selector = "img[alt][src^=http]:not([alt='']):not(.avatar):not([width])"
    image_link_selector = f"a:has({image_selector})"
    seen_texts = []

    for element in soup.select(f"h1, h2, h3, h4, h5, h6, p, pre, table:not(:has(p)), ul:not(:has(p)), {image_link_selector}"):
        if count_images > 0:
            image = element.select_one(image_selector)
            if image:
                title = str(element.get("title", element.text))
                if title:
                    yield f"!{format_link(image['src'], title)}\n"
                    if max_words is not None:
                        max_words -= 10
                    count_images -= 1
                continue

        for line in element.get_text(" ").splitlines():
            words = [word for word in line.split() if word]
            if not words:
                continue
            joined_line = " ".join(words)
            if joined_line in seen_texts:
                continue
            if max_words is not None:
                max_words -= len(words)
                if max_words <= 0:
                    break
            yield joined_line + "\n"
            seen_texts.append(joined_line)

    if add_source:
        canonical_link = soup.find("link", rel="canonical")
        if canonical_link and "href" in canonical_link.attrs:
            link = canonical_link["href"]
            domain = urlparse(link).netloc
            yield f"\nSource: [{domain}]({link})"

async def fetch_and_scrape(session: ClientSession, url: str, max_words: Optional[int] = None, add_source: bool = False, proxy: str = None) -> str:
    """
    Fetches a URL and returns the scraped text, using caching to avoid redundant downloads.
    """
    try:
        cache_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "fetch_and_scrape"
        cache_dir.mkdir(parents=True, exist_ok=True)
        md5_hash = hashlib.md5(url.encode(errors="ignore")).hexdigest()
        cache_file = cache_dir / f"{quote_plus(url.split('?')[0].split('//')[1].replace('/', ' ')[:48])}.{date.today()}.{md5_hash[:16]}.cache"
        if cache_file.exists():
            return cache_file.read_text()

        async with session.get(url, proxy=proxy) as response:
            if response.status == 200:
                html = await response.text(errors="replace")
                scraped_text = "".join(scrape_text(html, max_words, add_source))
                with open(cache_file, "wb") as f:
                    f.write(scraped_text.encode(errors="replace"))
                return scraped_text
    except (ClientError, asyncio.TimeoutError):
        return ""
    return ""

class SearchResults(JsonMixin):
    """
    Represents a collection of search result entries along with the count of used words.
    """
    def __init__(self, results: List[SearchResultEntry], used_words: int):
        self.results = results
        self.used_words = used_words

    @classmethod
    def from_dict(cls, data: dict) -> SearchResults:
        return cls(
            [SearchResultEntry(**item) for item in data["results"]],
            data["used_words"]
        )

    def __iter__(self) -> Iterator[SearchResultEntry]:
        yield from self.results

    def __str__(self) -> str:
        # Build a string representation of the search results with markdown formatting.
        output = []
        for idx, result in enumerate(self.results):
            parts = [
                f"### Title: {result.title}",
                "",
                result.text if result.text else result.snippet,
                "",
                f"> **Source:** [[{idx}]]({result.url})"
            ]
            output.append("\n".join(parts))
        return "\n\n\n\n".join(output)

    def __len__(self) -> int:
        return len(self.results)

    def get_sources(self) -> Sources:
        return Sources([{"url": result.url, "title": result.title} for result in self.results])

    def get_dict(self) -> dict:
        return {
            "results": [result.get_dict() for result in self.results],
            "used_words": self.used_words
        }

class SearchResultEntry(JsonMixin):
    """
    Represents a single search result entry.
    """
    def __init__(self, title: str, url: str, snippet: str, text: Optional[str] = None):
        self.title = title
        self.url = url
        self.snippet = snippet
        self.text = text

    def set_text(self, text: str) -> None:
        self.text = text

class DDGS(AsyncGeneratorProvider):
    working = has_requirements

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        prompt: str = None,
        proxy: str = None,
        timeout: int = 30,
        region: str = None,
        backend: str = None,
        max_results: int = 5,
        max_words: int = 2500,
        add_text: bool = True,
        **kwargs
    ) -> AsyncResult:
        if not has_requirements:
            raise MissingRequirementsError('Install "ddgs" and "beautifulsoup4" | pip install -U g4f[search]')

        prompt = format_media_prompt(messages, prompt)
        results: List[SearchResultEntry] = []

        # Use the new DDGS() context manager style
        with DDGSClient() as ddgs:
            for result in ddgs.text(
                prompt,
                region=region,
                safesearch="moderate",
                timelimit="y",
                max_results=max_results,
                backend=backend,
            ):
                if ".google." in result["href"]:
                    continue
                results.append(SearchResultEntry(
                    title=result["title"],
                    url=result["href"],
                    snippet=result["body"]
                ))

        if add_text:
            tasks = []
            async with ClientSession(timeout=ClientTimeout(timeout)) as session:
                for entry in results:
                    tasks.append(fetch_and_scrape(session, entry.url, int(max_words / (max_results - 1)), False, proxy=proxy))
                texts = await asyncio.gather(*tasks)

        formatted_results: List[SearchResultEntry] = []
        used_words = 0
        left_words = max_words
        for i, entry in enumerate(results):
            if add_text:
                entry.text = texts[i]
            left_words -= entry.title.count(" ") + 5
            if entry.text:
                left_words -= entry.text.count(" ")
            else:
                left_words -= entry.snippet.count(" ")
            if left_words < 0:
                break
            used_words = max_words - left_words
            formatted_results.append(entry)

        yield SearchResults(formatted_results, used_words)
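Usage sketch (not part of the commit): the DDGS provider yields a single SearchResults object whose __str__ renders the markdown summary and whose get_sources() feeds a Sources response. Argument names mirror the signature above; the query text and region/backend values are examples taken from the CachedSearch defaults.

import asyncio
from g4f.Provider.search.DDGS import DDGS  # requires the ddgs and beautifulsoup4 extras

async def main():
    async for results in DDGS.create_async_generator(
        "DDG", [], prompt="python asyncio tutorial",
        max_results=3, max_words=1200, add_text=False,
        region="us-en", backend="auto",
    ):
        print(str(results)[:500])      # markdown-formatted entries
        for entry in results:          # SearchResults is iterable
            print("-", entry.title, entry.url)

asyncio.run(main())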
@@ -1,11 +1,14 @@
 from __future__ import annotations

 import os
 import aiohttp
 import asyncio

 from ...typing import Messages, AsyncResult
 from ...providers.base_provider import AsyncGeneratorProvider
 from ...providers.response import FinishReason
-from ...tools.web_search import fetch_and_scrape
 from ..helper import format_media_prompt
+from .DDGS import fetch_and_scrape
 from ... import debug

 class SearXNG(AsyncGeneratorProvider):

@@ -20,7 +23,7 @@ class SearXNG(AsyncGeneratorProvider):
         prompt: str = None,
         proxy: str = None,
         timeout: int = 30,
-        language: str = "it",
+        language: str = None,
         max_results: int = 5,
         max_words: int = 2500,
         add_text: bool = True,
@@ -1,3 +1,4 @@
+from .CachedSearch import CachedSearch
 from .GoogleSearch import GoogleSearch
 from .SearXNG import SearXNG
 from .YouTube import YouTube
@@ -278,6 +278,8 @@ class Api:
                     yield self._format_json("request", chunk.get_dict())
                 elif isinstance(chunk, JsonResponse):
                     yield self._format_json("response", chunk.get_dict())
+                elif isinstance(chunk, PlainTextResponse):
+                    yield self._format_json("response", chunk.text)
                 else:
                     yield self._format_json("content", str(chunk))
        except MissingAuthError as e:
@@ -16,7 +16,6 @@ from ..requests.aiohttp import get_connector
 from ..image import MEDIA_TYPE_MAP, EXTENSIONS_MAP
 from ..tools.files import secure_filename
 from ..providers.response import ImageResponse, AudioResponse, VideoResponse, quote_url
-from ..Provider.template import BackendApi
 from . import is_accepted_format, extract_data_uri
 from .. import debug

@@ -171,15 +170,8 @@ async def copy_media(
             with open(target_path, "wb") as f:
                 f.write(extract_data_uri(image))
         elif not os.path.exists(target_path) or os.lstat(target_path).st_size <= 0:
-            # Apply BackendApi settings if needed
-            if BackendApi.working and image.startswith(BackendApi.url):
-                request_headers = BackendApi.headers if headers is None else headers
-                request_ssl = BackendApi.ssl
-            else:
-                request_headers = headers
-                request_ssl = ssl
             # Use aiohttp to fetch the image
-            async with session.get(image, ssl=request_ssl, headers=request_headers) as response:
+            async with session.get(image, ssl=ssl) as response:
                 response.raise_for_status()
                 if target is None:
                     filename = update_filename(response, filename)
@@ -21,7 +21,7 @@ from .response import BaseConversation, AuthResult
 from .helper import concat_chunks
 from ..cookies import get_cookies_dir
 from ..errors import ModelNotFoundError, ResponseError, MissingAuthError, NoValidHarFileError, PaymentRequiredError, CloudflareError
-from ..tools.run_tools import AuthManager
+from ..tools.auth import AuthManager
 from .. import debug

 SAFE_PARAMETERS = [
@@ -231,10 +231,13 @@ class DebugResponse(HiddenResponse):
         """Initialize with a log message."""
         self.log = log

+class PlainTextResponse(HiddenResponse):
+    def __init__(self, text: str) -> None:
+        self.text = text
+
 class ContinueResponse(HiddenResponse):
-    def __init__(self, log: str) -> None:
-        """Initialize with a log message."""
-        self.log = log
+    def __init__(self, text: str) -> None:
+        self.text = text

 class Reasoning(ResponseType):
     def __init__(
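Illustration (not part of the commit): because PlainTextResponse subclasses HiddenResponse, it is not meant to appear in the visible completion. A consumer iterating a provider generator can filter on the type, roughly like this (a sketch, assuming the g4f.providers.response import path; split_chunks and the sample data are hypothetical).

from g4f.providers.response import PlainTextResponse

def split_chunks(chunks):
    # Separate visible text chunks from hidden plain-text debug lines.
    visible, hidden = [], []
    for chunk in chunks:
        if isinstance(chunk, PlainTextResponse):
            hidden.append(chunk.text)   # raw stream line, kept out of the answer
        elif isinstance(chunk, str):
            visible.append(chunk)
    return "".join(visible), hidden

text, raw_lines = split_chunks(["Hello ", PlainTextResponse('a0:"Hello "'), "world"])
print(text)        # -> "Hello world"
print(raw_lines)   # -> ['a0:"Hello "']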
g4f/tools/auth.py (new file, 32 lines)

@@ -0,0 +1,32 @@
from __future__ import annotations

import os
from typing import Optional

from ..providers.types import ProviderType
from .. import debug

class AuthManager:
    """Handles API key management"""
    aliases = {
        "GeminiPro": "Gemini",
        "PollinationsAI": "Pollinations",
        "OpenaiAPI": "Openai",
        "PuterJS": "Puter",
    }

    @classmethod
    def load_api_key(cls, provider: ProviderType) -> Optional[str]:
        """Load API key from config file"""
        if not provider.needs_auth and not hasattr(provider, "login_url"):
            return None
        provider_name = provider.get_parent()
        env_var = f"{provider_name.upper()}_API_KEY"
        api_key = os.environ.get(env_var)
        if not api_key and provider_name in cls.aliases:
            env_var = f"{cls.aliases[provider_name].upper()}_API_KEY"
            api_key = os.environ.get(env_var)
        if api_key:
            debug.log(f"Loading API key for {provider_name} from environment variable {env_var}")
            return api_key
        return None
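Usage sketch (not part of the commit): the lookup tries <PROVIDER>_API_KEY first and then the alias name from the table above. The GeminiPro export path and the key value below are assumptions for illustration only.

import os
from g4f.Provider import GeminiPro          # assumed export path
from g4f.tools.auth import AuthManager

# GEMINIPRO_API_KEY is checked first; the alias GeminiPro -> Gemini allows the shorter name.
os.environ["GEMINI_API_KEY"] = "example-key"

print(AuthManager.load_api_key(GeminiPro))  # -> "example-key" (debug.log records which variable matched)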
g4f/tools/fetch_and_scrape.py (new file, 98 lines)

@@ -0,0 +1,98 @@
from __future__ import annotations

import hashlib
import asyncio
from pathlib import Path
from typing import Iterator, Optional
from urllib.parse import urlparse, quote_plus
from aiohttp import ClientSession, ClientError
from datetime import date
import asyncio

try:
    from bs4 import BeautifulSoup
    has_requirements = True
except ImportError:
    has_requirements = False

from ..cookies import get_cookies_dir
from ..providers.response import format_link

def scrape_text(html: str, max_words: Optional[int] = None, add_source: bool = True, count_images: int = 2) -> Iterator[str]:
    """
    Parses the provided HTML and yields text fragments.
    """
    soup = BeautifulSoup(html, "html.parser")
    for selector in [
        "main", ".main-content-wrapper", ".main-content", ".emt-container-inner",
        ".content-wrapper", "#content", "#mainContent",
    ]:
        selected = soup.select_one(selector)
        if selected:
            soup = selected
            break

    for remove_selector in [".c-globalDisclosure"]:
        unwanted = soup.select_one(remove_selector)
        if unwanted:
            unwanted.extract()

    image_selector = "img[alt][src^=http]:not([alt='']):not(.avatar):not([width])"
    image_link_selector = f"a:has({image_selector})"
    seen_texts = []

    for element in soup.select(f"h1, h2, h3, h4, h5, h6, p, pre, table:not(:has(p)), ul:not(:has(p)), {image_link_selector}"):
        if count_images > 0:
            image = element.select_one(image_selector)
            if image:
                title = str(element.get("title", element.text))
                if title:
                    yield f"!{format_link(image['src'], title)}\n"
                    if max_words is not None:
                        max_words -= 10
                    count_images -= 1
                continue

        for line in element.get_text(" ").splitlines():
            words = [word for word in line.split() if word]
            if not words:
                continue
            joined_line = " ".join(words)
            if joined_line in seen_texts:
                continue
            if max_words is not None:
                max_words -= len(words)
                if max_words <= 0:
                    break
            yield joined_line + "\n"
            seen_texts.append(joined_line)

    if add_source:
        canonical_link = soup.find("link", rel="canonical")
        if canonical_link and "href" in canonical_link.attrs:
            link = canonical_link["href"]
            domain = urlparse(link).netloc
            yield f"\nSource: [{domain}]({link})"

async def fetch_and_scrape(session: ClientSession, url: str, max_words: Optional[int] = None, add_source: bool = False, proxy: str = None) -> str:
    """
    Fetches a URL and returns the scraped text, using caching to avoid redundant downloads.
    """
    try:
        cache_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "fetch_and_scrape"
        cache_dir.mkdir(parents=True, exist_ok=True)
        md5_hash = hashlib.md5(url.encode(errors="ignore")).hexdigest()
        cache_file = cache_dir / f"{quote_plus(url.split('?')[0].split('//')[1].replace('/', ' ')[:48])}.{date.today()}.{md5_hash[:16]}.cache"
        if cache_file.exists():
            return cache_file.read_text()

        async with session.get(url, proxy=proxy) as response:
            if response.status == 200:
                html = await response.text(errors="replace")
                scraped_text = "".join(scrape_text(html, max_words, add_source))
                with open(cache_file, "wb") as f:
                    f.write(scraped_text.encode(errors="replace"))
                return scraped_text
    except (ClientError, asyncio.TimeoutError):
        return ""
    return ""
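Quick check (not part of the commit): scrape_text() is a generator over normalized text fragments. The HTML snippet below is invented for illustration; beautifulsoup4 is required.

from g4f.tools.fetch_and_scrape import scrape_text

html = """
<html><body><main>
  <h1>Example post</h1>
  <p>First paragraph with   messy   whitespace.</p>
  <p>First paragraph with   messy   whitespace.</p>
</main></body></html>
"""

# Yields whitespace-normalized lines; the duplicated paragraph is emitted only once.
for fragment in scrape_text(html, max_words=50, add_source=False):
    print(repr(fragment))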
@@ -74,7 +74,7 @@ try:
 except ImportError:
     has_markitdown = False

-from .web_search import scrape_text
+from .fetch_and_scrape import scrape_text
 from ..files import secure_filename, get_bucket_dir
 from ..image import is_allowed_extension
 from ..requests.aiohttp import get_connector
@@ -23,6 +23,7 @@ from ..providers.response import Reasoning, FinishReason, Sources, Usage, Provid
 from ..providers.types import ProviderType
 from ..cookies import get_cookies_dir
 from .web_search import do_search, get_search_message
+from .auth import AuthManager
 from .files import read_bucket, get_bucket_dir
 from .. import debug

@@ -130,31 +131,6 @@ class ToolHandler:

         return messages, sources, extra_kwargs

-class AuthManager:
-    """Handles API key management"""
-    aliases = {
-        "GeminiPro": "Gemini",
-        "PollinationsAI": "Pollinations",
-        "OpenaiAPI": "Openai",
-        "PuterJS": "Puter",
-    }
-
-    @classmethod
-    def load_api_key(cls, provider: ProviderType) -> Optional[str]:
-        """Load API key from config file"""
-        if not provider.needs_auth and not hasattr(provider, "login_url"):
-            return None
-        provider_name = provider.get_parent()
-        env_var = f"{provider_name.upper()}_API_KEY"
-        api_key = os.environ.get(env_var)
-        if not api_key and provider_name in cls.aliases:
-            env_var = f"{cls.aliases[provider_name].upper()}_API_KEY"
-            api_key = os.environ.get(env_var)
-        if api_key:
-            debug.log(f"Loading API key for {provider_name} from environment variable {env_var}")
-            return api_key
-        return None
-
 class ThinkingProcessor:
     """Processes thinking chunks"""
@@ -1,32 +1,16 @@
from __future__ import annotations

from aiohttp import ClientSession, ClientTimeout, ClientError
import json
import hashlib
from pathlib import Path
from urllib.parse import urlparse, quote_plus
from datetime import date
import asyncio
from typing import Optional

# Optional dependencies using the new 'ddgs' package name
try:
    from ddgs import DDGS
    from ddgs.exceptions import DDGSException
    from bs4 import BeautifulSoup
    has_requirements = True
except ImportError:
    has_requirements = False
    from typing import Type as DDGSException

try:
    import spacy
    has_spacy = True
except ImportError:
    has_spacy = False

from typing import Iterator, List, Optional
from ..cookies import get_cookies_dir
from ..providers.response import format_link, JsonMixin, Sources
from ..providers.response import Sources
from ..errors import MissingRequirementsError
from ..Provider.search.CachedSearch import CachedSearch
from .. import debug

DEFAULT_INSTRUCTIONS = """
@@ -34,267 +18,30 @@ Using the provided web search results, to write a comprehensive reply to the use
 Make sure to add the sources of cites using [[Number]](Url) notation after the reference. Example: [[0]](http://google.com)
 """

-class SearchResults(JsonMixin):
-    """
-    Represents a collection of search result entries along with the count of used words.
-    """
-    def __init__(self, results: List[SearchResultEntry], used_words: int):
-        self.results = results
-        self.used_words = used_words
-
-    @classmethod
-    def from_dict(cls, data: dict) -> SearchResults:
-        return cls(
-            [SearchResultEntry(**item) for item in data["results"]],
-            data["used_words"]
-        )
-
-    def __iter__(self) -> Iterator[SearchResultEntry]:
-        yield from self.results
-
-    def __str__(self) -> str:
-        # Build a string representation of the search results with markdown formatting.
-        output = []
-        for idx, result in enumerate(self.results):
-            parts = [
-                f"Title: {result.title}",
-                "",
-                result.text if result.text else result.snippet,
-                "",
-                f"Source: [[{idx}]]({result.url})"
-            ]
-            output.append("\n".join(parts))
-        return "\n\n\n".join(output)
-
-    def __len__(self) -> int:
-        return len(self.results)
-
-    def get_sources(self) -> Sources:
-        return Sources([{"url": result.url, "title": result.title} for result in self.results])
-
-    def get_dict(self) -> dict:
-        return {
-            "results": [result.get_dict() for result in self.results],
-            "used_words": self.used_words
-        }
-
-class SearchResultEntry(JsonMixin):
-    """
-    Represents a single search result entry.
-    """
-    def __init__(self, title: str, url: str, snippet: str, text: Optional[str] = None):
-        self.title = title
-        self.url = url
-        self.snippet = snippet
-        self.text = text
-
-    def set_text(self, text: str) -> None:
-        self.text = text
-
-def scrape_text(html: str, max_words: Optional[int] = None, add_source: bool = True, count_images: int = 2) -> Iterator[str]:
-    """
-    Parses the provided HTML and yields text fragments.
-    """
-    soup = BeautifulSoup(html, "html.parser")
-    for selector in [
-        "main", ".main-content-wrapper", ".main-content", ".emt-container-inner",
-        ".content-wrapper", "#content", "#mainContent",
-    ]:
-        selected = soup.select_one(selector)
-        if selected:
-            soup = selected
-            break
-
-    for remove_selector in [".c-globalDisclosure"]:
-        unwanted = soup.select_one(remove_selector)
-        if unwanted:
-            unwanted.extract()
-
-    image_selector = "img[alt][src^=http]:not([alt='']):not(.avatar):not([width])"
-    image_link_selector = f"a:has({image_selector})"
-    seen_texts = []
-
-    for element in soup.select(f"h1, h2, h3, h4, h5, h6, p, pre, table:not(:has(p)), ul:not(:has(p)), {image_link_selector}"):
-        if count_images > 0:
-            image = element.select_one(image_selector)
-            if image:
-                title = str(element.get("title", element.text))
-                if title:
-                    yield f"!{format_link(image['src'], title)}\n"
-                    if max_words is not None:
-                        max_words -= 10
-                    count_images -= 1
-                continue
-
-        for line in element.get_text(" ").splitlines():
-            words = [word for word in line.split() if word]
-            if not words:
-                continue
-            joined_line = " ".join(words)
-            if joined_line in seen_texts:
-                continue
-            if max_words is not None:
-                max_words -= len(words)
-                if max_words <= 0:
-                    break
-            yield joined_line + "\n"
-            seen_texts.append(joined_line)
-
-    if add_source:
-        canonical_link = soup.find("link", rel="canonical")
-        if canonical_link and "href" in canonical_link.attrs:
-            link = canonical_link["href"]
-            domain = urlparse(link).netloc
-            yield f"\nSource: [{domain}]({link})"
-
-async def fetch_and_scrape(session: ClientSession, url: str, max_words: Optional[int] = None, add_source: bool = False) -> str:
-    """
-    Fetches a URL and returns the scraped text, using caching to avoid redundant downloads.
-    """
-    try:
-        cache_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "fetch_and_scrape"
-        cache_dir.mkdir(parents=True, exist_ok=True)
-        md5_hash = hashlib.md5(url.encode(errors="ignore")).hexdigest()
-        cache_file = cache_dir / f"{quote_plus(url.split('?')[0].split('//')[1].replace('/', ' ')[:48])}.{date.today()}.{md5_hash[:16]}.cache"
-        if cache_file.exists():
-            return cache_file.read_text()
-
-        async with session.get(url) as response:
-            if response.status == 200:
-                html = await response.text(errors="replace")
-                scraped_text = "".join(scrape_text(html, max_words, add_source))
-                with open(cache_file, "wb") as f:
-                    f.write(scraped_text.encode(errors="replace"))
-                return scraped_text
-    except (ClientError, asyncio.TimeoutError):
-        return ""
-    return ""
-
-async def search(
-    query: str,
-    max_results: int = 5,
-    max_words: int = 2500,
-    backend: str = "auto",
-    add_text: bool = True,
-    timeout: int = 5,
-    region: str = "us-en",
-    provider: str = "DDG"
-) -> SearchResults:
-    """
-    Performs a web search and returns search results.
-    """
-    if provider == "SearXNG":
-        from ..Provider.SearXNG import SearXNG
-        debug.log(f"[SearXNG] Using local container for query: {query}")
-        results_texts = []
-        async for chunk in SearXNG.create_async_generator(
-            "SearXNG",
-            [{"role": "user", "content": query}],
-            max_results=max_results,
-            max_words=max_words,
-            add_text=add_text
-        ):
-            if isinstance(chunk, str):
-                results_texts.append(chunk)
-        used_words = sum(text.count(" ") for text in results_texts)
-        return SearchResults([
-            SearchResultEntry(
-                title=f"Result {i + 1}",
-                url="",
-                snippet=text,
-                text=text
-            ) for i, text in enumerate(results_texts)
-        ], used_words=used_words)
-
-    debug.log(f"[DuckDuckGo] Using local container for query: {query}")
-
-    if not has_requirements:
-        raise MissingRequirementsError('Install "ddgs" and "beautifulsoup4" | pip install -U g4f[search]')
-
-    results: List[SearchResultEntry] = []
-    # Use the new DDGS() context manager style
-    with DDGS() as ddgs:
-        for result in ddgs.text(
-            query,
-            region=region,
-            safesearch="moderate",
-            timelimit="y",
-            max_results=max_results,
-            backend=backend,
-        ):
-            if ".google." in result["href"]:
-                continue
-            results.append(SearchResultEntry(
-                title=result["title"],
-                url=result["href"],
-                snippet=result["body"]
-            ))
-
-    if add_text:
-        tasks = []
-        async with ClientSession(timeout=ClientTimeout(timeout)) as session:
-            for entry in results:
-                tasks.append(fetch_and_scrape(session, entry.url, int(max_words / (max_results - 1)), False))
-            texts = await asyncio.gather(*tasks)
-
-    formatted_results: List[SearchResultEntry] = []
-    used_words = 0
-    left_words = max_words
-    for i, entry in enumerate(results):
-        if add_text:
-            entry.text = texts[i]
-        left_words -= entry.title.count(" ") + 5
-        if entry.text:
-            left_words -= entry.text.count(" ")
-        else:
-            left_words -= entry.snippet.count(" ")
-        if left_words < 0:
-            break
-        used_words = max_words - left_words
-        formatted_results.append(entry)
-
-    return SearchResults(formatted_results, used_words)
-
 async def do_search(
     prompt: str,
     query: Optional[str] = None,
     instructions: str = DEFAULT_INSTRUCTIONS,
     **kwargs
 ) -> tuple[str, Optional[Sources]]:
-    """
-    Combines search results with the user prompt, using caching for improved efficiency.
-    """
-    if not isinstance(prompt, str):
-        return prompt, None
+    if not prompt or not isinstance(prompt, str):
+        return

     if instructions and instructions in prompt:
-        return prompt, None
+        return

     if prompt.startswith("##") and query is None:
-        return prompt, None
+        return

     if query is None:
         query = prompt.strip().splitlines()[0]

-    json_bytes = json.dumps({"query": query, **kwargs}, sort_keys=True).encode(errors="ignore")
-    md5_hash = hashlib.md5(json_bytes).hexdigest()
-    cache_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "web_search" / f"{date.today()}"
-    cache_dir.mkdir(parents=True, exist_ok=True)
-    cache_file = cache_dir / f"{quote_plus(query[:20])}.{md5_hash}.cache"
-
-    search_results: Optional[SearchResults] = None
-    if cache_file.exists():
-        with cache_file.open("r") as f:
-            try:
-                search_results = SearchResults.from_dict(json.loads(f.read()))
-            except json.JSONDecodeError:
-                search_results = None
-
-    if search_results is None:
-        search_results = await search(query, **kwargs)
-        if search_results.results:
-            with cache_file.open("w") as f:
-                f.write(json.dumps(search_results.get_dict()))
+    search_results = await anext(CachedSearch.create_async_generator(
+        "",
+        [],
+        prompt=query,
+        **kwargs
+    ))

     if instructions:
         new_prompt = f"{search_results}\n\nInstruction: {instructions}\n\nUser request:\n{prompt}"
@@ -303,6 +50,7 @@ async def do_search(

     debug.log(f"Web search: '{query.strip()[:50]}...'")
+    debug.log(f"with {len(search_results.results)} Results {search_results.used_words} Words")

     return new_prompt.strip(), search_results.get_sources()

 def get_search_message(prompt: str, raise_search_exceptions: bool = False, **kwargs) -> str:
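Usage sketch (not part of the commit): after the refactor, do_search is a thin wrapper that builds the query, lets CachedSearch fetch or re-use the results, and returns the augmented prompt plus a Sources object. The prompt text below is an example; keyword arguments are forwarded to the search backends, which must have their optional dependencies installed.

import asyncio
from g4f.tools.web_search import do_search

async def main():
    new_prompt, sources = await do_search(
        "What changed in Python 3.13?",
        max_results=3,
        add_text=False,       # skip full-page scraping, use snippets only
    )
    print(new_prompt[:300])   # search results + DEFAULT_INSTRUCTIONS + original request
    print(sources)            # Sources object built from the result URLs

asyncio.run(main())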
@@ -19,3 +19,4 @@ python-multipart
a2wsgi
python-dotenv
ddgs
aiofile
setup.py (3 changed lines)

@@ -38,7 +38,8 @@ EXTRA_REQUIRE = {
         "plyer",
         "setuptools",
         "markitdown[all]",
-        "python-dotenv"
+        "python-dotenv",
+        "aiofile"
     ],
     'slim': [
         "curl_cffi>=0.6.2",