Add new media selection in UI

Add HuggingFace provider
Auto refresh Google Gemini cookies
Add sources to search results
hlohaus 2025-02-26 11:41:00 +01:00
parent 69ab91a63f
commit 1d3a139a53
13 changed files with 326 additions and 105 deletions

View file

@ -835,6 +835,7 @@ A list of all contributors is available [here](https://github.com/xtekky/gpt4fre
- The [`Gemini.py`](https://github.com/xtekky/gpt4free/blob/main/g4f/Provider/needs_auth/Gemini.py) has input from [dsdanielpark/Gemini-API](https://github.com/dsdanielpark/Gemini-API)
- The [`MetaAI.py`](https://github.com/xtekky/gpt4free/blob/main/g4f/Provider/MetaAI.py) file contains code from [meta-ai-api](https://github.com/Strvm/meta-ai-api) by [@Strvm](https://github.com/Strvm)
- The [`proofofwork.py`](https://github.com/xtekky/gpt4free/blob/main/g4f/Provider/openai/proofofwork.py) has input from [missuo/FreeGPT35](https://github.com/missuo/FreeGPT35)
- The [`Gemini.py`](https://github.com/xtekky/gpt4free/blob/main/g4f/Provider/needs_auth/Gemini.py) has input from [HanaokaYuzu/Gemini-API](https://github.com/HanaokaYuzu/Gemini-API)
_Having input implies that the AI's code generation utilized it as one of many sources._

View file

@ -14,6 +14,11 @@ from ..helper import format_image_prompt, get_last_user_message
from .models import default_model, default_image_model, model_aliases, text_models, image_models, vision_models
from ... import debug
provider_together_urls = {
"black-forest-labs/FLUX.1-dev": "https://router.huggingface.co/together/v1/images/generations",
"black-forest-labs/FLUX.1-schnell": "https://router.huggingface.co/together/v1/images/generations",
}
class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://huggingface.co"
parent = "HuggingFace"
@ -63,6 +68,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
messages: Messages,
stream: bool = True,
proxy: str = None,
timeout: int = 600,
api_base: str = "https://api-inference.huggingface.co",
api_key: str = None,
max_tokens: int = 1024,
@ -71,6 +77,8 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
action: str = None,
extra_data: dict = {},
seed: int = None,
width: int = 1024,
height: int = 1024,
**kwargs
) -> AsyncResult:
try:
@ -78,23 +86,35 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
except ModelNotSupportedError:
pass
headers = {
'accept': '*/*',
'accept-language': 'en',
'cache-control': 'no-cache',
'origin': 'https://huggingface.co',
'pragma': 'no-cache',
'priority': 'u=1, i',
'referer': 'https://huggingface.co/chat/',
'sec-ch-ua': '"Not)A;Brand";v="99", "Google Chrome";v="127", "Chromium";v="127"',
'sec-ch-ua-mobile': '?0',
'sec-ch-ua-platform': '"macOS"',
'sec-fetch-dest': 'empty',
'sec-fetch-mode': 'cors',
'sec-fetch-site': 'same-origin',
'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36',
'Accept-Encoding': 'gzip, deflate',
'Content-Type': 'application/json',
}
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
async with StreamSession(
headers=headers,
proxy=proxy,
timeout=timeout
) as session:
try:
if model in provider_together_urls:
data = {
"response_format": "url",
"prompt": format_image_prompt(messages, prompt),
"model": model,
"width": width,
"height": height,
**extra_data
}
async with session.post(provider_together_urls[model], json=data) as response:
if response.status == 404:
raise ModelNotSupportedError(f"Model is not supported: {model}")
await raise_for_status(response)
result = await response.json()
yield ImageResponse([item["url"] for item in result["data"]], data["prompt"])
return
except ModelNotSupportedError:
pass
payload = None
params = {
"return_full_text": False,
@ -103,11 +123,6 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
**extra_data
}
do_continue = action == "continue"
async with StreamSession(
headers=headers,
proxy=proxy,
timeout=600
) as session:
if payload is None:
model_data = await cls.get_model_data(session, model)
pipeline_tag = model_data.get("pipeline_tag")
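
The Together routing above is a straight HTTP call: an OpenAI-style image request is POSTed to the router endpoint, and the URLs in the response body are wrapped into an ImageResponse. A minimal standalone sketch of that call, assuming a valid HuggingFace API key (the helper name and prompt are illustrative; the endpoint, payload keys, and defaults follow the diff):

```python
# Sketch of the Together image route used by HuggingFaceInference above.
import asyncio
import aiohttp

TOGETHER_URL = "https://router.huggingface.co/together/v1/images/generations"

async def generate_image(api_key: str, prompt: str) -> list[str]:
    payload = {
        "response_format": "url",   # ask for hosted URLs instead of base64
        "prompt": prompt,
        "model": "black-forest-labs/FLUX.1-dev",
        "width": 1024,              # new defaults added in this commit
        "height": 1024,
    }
    headers = {"Authorization": f"Bearer {api_key}"}
    async with aiohttp.ClientSession(headers=headers) as session:
        async with session.post(TOGETHER_URL, json=payload) as response:
            response.raise_for_status()  # the provider maps 404 to ModelNotSupportedError
            result = await response.json()
            return [item["url"] for item in result["data"]]

# asyncio.run(generate_image("hf_...", "a lighthouse at dawn"))
```

Note the provider first tries this route and falls back to the regular inference path when the model has no Together mapping.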

View file

@ -6,7 +6,10 @@ import random
import re
import base64
import asyncio
import time
from urllib.parse import quote_plus, unquote_plus
from pathlib import Path
from aiohttp import ClientSession, BaseConnector
try:
@ -17,15 +20,15 @@ except ImportError:
from ... import debug
from ...typing import Messages, Cookies, ImagesType, AsyncResult, AsyncIterator
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt, get_cookies
from ...providers.response import JsonConversation, Reasoning, RequestLogin, ImageResponse
from ...providers.response import JsonConversation, Reasoning, RequestLogin, ImageResponse, YouTube
from ...requests.raise_for_status import raise_for_status
from ...requests.aiohttp import get_connector
from ...requests import get_nodriver
from ...errors import MissingAuthError
from ...image import to_bytes
from ..helper import get_last_user_message
from ...cookies import get_cookies_dir
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt, get_cookies, get_last_user_message
from ... import debug
REQUEST_HEADERS = {
@ -52,6 +55,9 @@ UPLOAD_IMAGE_HEADERS = {
"x-goog-upload-protocol": "resumable",
"x-tenant-id": "bard-storage",
}
GOOGLE_COOKIE_DOMAIN = ".google.com"
ROTATE_COOKIES_URL = "https://accounts.google.com/RotateCookies"
GOOGLE_SID_COOKIE = "__Secure-1PSID"
models = {
"gemini-2.0-flash": {"x-goog-ext-525001261-jspb": '[null,null,null,null,"f299729663a2343f"]'},
@ -87,6 +93,10 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
_snlm0e: str = None
_sid: str = None
auto_refresh = True
refresh_interval = 540
rotate_tasks = {}
@classmethod
async def nodriver_login(cls, proxy: str = None) -> AsyncIterator[str]:
if not has_nodriver:
@ -108,6 +118,29 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
finally:
stop_browser()
@classmethod
async def start_auto_refresh(cls, proxy: str = None) -> None:
"""
Start the background task to automatically refresh cookies.
"""
while True:
try:
new_1psidts = await rotate_1psidts(cls.url, cls._cookies, proxy)
            except Exception as e:
                debug.error(f"Failed to refresh cookies: {e}")
                task = cls.rotate_tasks.get(cls._cookies[GOOGLE_SID_COOKIE])
                if task:
                    task.cancel()
                debug.error(
                    "Failed to refresh cookies. Background auto refresh task canceled."
                )
                return
debug.log(f"Gemini: Cookies refreshed. New __Secure-1PSIDTS: {new_1psidts}")
if new_1psidts:
cls._cookies["__Secure-1PSIDTS"] = new_1psidts
await asyncio.sleep(cls.refresh_interval)
@classmethod
async def create_async_generator(
cls,
@ -122,8 +155,10 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
language: str = "en",
**kwargs
) -> AsyncResult:
cls._cookies = cookies or cls._cookies or get_cookies(GOOGLE_COOKIE_DOMAIN, False, True)
if conversation is not None and getattr(conversation, "model", None) != model:
conversation = None
prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
cls._cookies = cookies or cls._cookies or get_cookies(".google.com", False, True)
base_connector = get_connector(connector, proxy)
async with ClientSession(
@ -144,6 +179,12 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
await cls.fetch_snlm0e(session, cls._cookies)
if not cls._snlm0e:
raise RuntimeError("Invalid cookies. SNlM0e not found")
            if GOOGLE_SID_COOKIE in cls._cookies:
                task = cls.rotate_tasks.get(cls._cookies[GOOGLE_SID_COOKIE])
                if not task:
                    cls.rotate_tasks[cls._cookies[GOOGLE_SID_COOKIE]] = asyncio.create_task(
                        cls.start_auto_refresh()
                    )
images = await cls.upload_images(base_connector, images) if images else None
async with ClientSession(
@ -190,7 +231,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
if not response_part[4]:
continue
if return_conversation:
yield Conversation(response_part[1][0], response_part[1][1], response_part[4][0][0])
yield Conversation(response_part[1][0], response_part[1][1], response_part[4][0][0], model)
def read_recusive(data):
for item in data:
if isinstance(item, list):
@ -222,12 +263,13 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
if match:
image_prompt = match.group(1)
content = content.replace(match.group(0), '')
pattern = r"http://googleusercontent.com/(?:image_generation|youtube)_content/\d+"
pattern = r"http://googleusercontent.com/(?:image_generation|youtube|map)_content/\d+"
content = re.sub(pattern, "", content)
content = content.replace("<!-- end list -->", "")
content = content.replace("https://www.google.com/search?q=http://", "https://")
content = content.replace("https://www.google.com/search?q=https://", "https://")
content = content.replace("https://www.google.com/url?sa=E&source=gmail&q=http://", "http://")
def replace_link(match):
return f"(https://{quote_plus(unquote_plus(match.group(1)), '/?&=#')})"
content = re.sub(r"\(https://www.google.com/(?:search\?q=|url\?sa=E&source=gmail&q=)https?://(.+?)\)", replace_link, content)
if last_content and content.startswith(last_content):
yield content[len(last_content):]
else:
@ -240,6 +282,13 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
yield ImageResponse(images, image_prompt, {"cookies": cls._cookies})
except (TypeError, IndexError, KeyError):
pass
youtube_ids = []
pattern = re.compile(r"http://www.youtube.com/watch\?v=(\w+)")
for match in pattern.finditer(content):
if match.group(1) not in youtube_ids:
youtube_ids.append(match.group(1))
if youtube_ids:
yield YouTube(youtube_ids)
@classmethod
async def synthesize(cls, params: dict, proxy: str = None) -> AsyncIterator[bytes]:
@ -354,11 +403,13 @@ class Conversation(JsonConversation):
def __init__(self,
conversation_id: str,
response_id: str,
choice_id: str
choice_id: str,
model: str
) -> None:
self.conversation_id = conversation_id
self.response_id = response_id
self.choice_id = choice_id
self.model = model
async def iter_filter_base64(chunks: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
search_for = b'[["wrb.fr","XqA3Ic","[\\"'
@ -387,3 +438,34 @@ async def iter_base64_decode(chunks: AsyncIterator[bytes]) -> AsyncIterator[byte
yield base64.b64decode(chunk[:-rest])
if rest > 0:
yield base64.b64decode(buffer+rest*b"=")
async def rotate_1psidts(url, cookies: dict, proxy: str | None = None) -> str:
path = Path(get_cookies_dir())
path.mkdir(parents=True, exist_ok=True)
filename = f"auth_Gemini.json"
path = path / filename
# Check if the cache file was modified in the last minute to avoid 429 Too Many Requests
if not (path.is_file() and time.time() - os.path.getmtime(path) <= 60):
async with ClientSession(proxy=proxy) as client:
response = await client.post(
url=ROTATE_COOKIES_URL,
headers={
"Content-Type": "application/json",
},
cookies=cookies,
data='[000,"-0000000000000000000"]',
)
if response.status == 401:
raise MissingAuthError("Invalid cookies")
response.raise_for_status()
for key, c in response.cookies.items():
cookies[key] = c.value
            # aiohttp returns a Morsel here; extract the string value
            morsel = response.cookies.get("__Secure-1PSIDTS")
            new_1psidts = morsel.value if morsel else None
path.write_text(json.dumps([{
"name": k,
"value": v,
"domain": GOOGLE_COOKIE_DOMAIN,
} for k, v in cookies.items()]))
if new_1psidts:
return new_1psidts
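
Stripped of the provider plumbing and the on-disk cookie cache, the rotation is one POST to the RotateCookies endpoint with a fixed payload, copying the refreshed __Secure-1PSIDTS back into the cookie jar. A minimal sketch of the loop under those assumptions (no 429 guard, no task bookkeeping):

```python
# Minimal cookie-rotation loop; endpoint, payload, and the 540s interval follow the diff.
import asyncio
import aiohttp

ROTATE_COOKIES_URL = "https://accounts.google.com/RotateCookies"

async def refresh_loop(cookies: dict, interval: int = 540) -> None:
    while True:
        async with aiohttp.ClientSession(cookies=cookies) as session:
            async with session.post(
                ROTATE_COOKIES_URL,
                headers={"Content-Type": "application/json"},
                data='[000,"-0000000000000000000"]',
            ) as response:
                if response.status == 401:
                    raise RuntimeError("Invalid cookies")
                response.raise_for_status()
                morsel = response.cookies.get("__Secure-1PSIDTS")
                if morsel:
                    cookies["__Secure-1PSIDTS"] = morsel.value
        await asyncio.sleep(interval)
```

In the provider this loop runs once per __Secure-1PSID value: rotate_tasks keys tasks by that cookie, so repeated create_async_generator calls never spawn a duplicate refresher.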

View file

@ -192,5 +192,5 @@ def read_cookie_files(dirPath: str = None):
new_cookies[c["domain"]] = {}
new_cookies[c["domain"]][c["name"]] = c["value"]
for domain, new_values in new_cookies.items():
debug.log(f"Cookies added: {len(new_values)} from {domain}")
CookiesConfig.cookies[domain] = new_values
debug.log(f"Cookies added: {len(new_values)} from {domain}")

View file

@ -254,6 +254,13 @@
} catch(e) {
console.log(e);
input.setCustomValidity("Invalid Access Token.");
localStorage.removeItem("HuggingFace-api_key");
if (localStorage.getItem("oauth")) {
window.location.href = (await oauthLoginUrl({
clientId: 'ed074164-4f8d-4fb2-8bec-44952707965e',
scopes: ['inference-api']
}));
}
return;
}
localStorage.setItem("HuggingFace-api_key", accessToken);
@ -289,10 +296,13 @@
window.location.reload();
}
} else {
localStorage.removeItem("oauth");
document.getElementById("signin").style.removeProperty("display");
document.getElementById("signin").onclick = async function() {
// prompt=consent to re-trigger the consent screen instead of silently redirecting
window.location.href = (await oauthLoginUrl({clientId: 'ed074164-4f8d-4fb2-8bec-44952707965e', scopes: ['inference-api']})) + "&prompt=consent";
window.location.href = (await oauthLoginUrl({
clientId: 'ed074164-4f8d-4fb2-8bec-44952707965e',
scopes: ['inference-api']
}));
}
}
</script>

View file

@ -33,6 +33,12 @@
};
</script>
<script id="MathJax-script" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js" async></script>
<script>
var tag = document.createElement('script');
tag.src = "https://www.youtube.com/iframe_api";
var firstScriptTag = document.getElementsByTagName('script')[0];
firstScriptTag.parentNode.insertBefore(tag, firstScriptTag);
</script>
<template>
<script type="module" src="https://cdn.jsdelivr.net/npm/mistral-tokenizer-js" async>
import mistralTokenizer from "mistral-tokenizer-js"
@ -211,6 +217,19 @@
<div class="media_player">
<i class="fa-regular fa-x"></i>
</div>
<div class="media-select hidden">
<label class="image-select" for="image" title="">
<input type="file" id="image" name="image" accept="image/*" required/>
<i class="fa-regular fa-image"></i>
</label>
<label class="capture-camera" for="camera">
<input type="file" id="camera" name="camera" accept="image/*" capture="camera" required/>
<i class="fa-solid fa-camera"></i>
</label>
<button class="close">
<i class="fa-solid fa-xmark"></i>
</button>
</div>
<div class="toolbar">
<div id="input-count" class="">
<button class="hide-input">
@ -236,14 +255,9 @@
<div class="box input-area">
<textarea id="message-input" placeholder="Ask a question" cols="30" rows="10"
style="white-space: pre-wrap;resize: none;"></textarea>
<label class="file-label image-label" for="image" title="">
<input type="file" id="image" name="image" accept="image/*" required/>
<label class="file-label image-label">
<i class="fa-regular fa-image"></i>
</label>
<label class="file-label image-label" for="camera">
<input type="file" id="camera" name="camera" accept="image/*" capture="camera" required/>
<i class="fa-solid fa-camera"></i>
</label>
<label class="file-label" for="file">
<input type="file" id="file" name="file" accept=".txt, .html, .xml, .json, .js, .har, .sh, .py, .php, .css, .yaml, .sql, .log, .csv, .twig, .md, .pdf, .docx, .odt, .epub, .xlsx, .zip" required multiple/>
<i class="fa-solid fa-paperclip"></i>

View file

@ -384,7 +384,7 @@ body:not(.white) a:visited{
}
.message .reasoning_text.final {
max-height: 1000px;
max-height: 2000px;
transition: max-height 0.25s ease-in;
}
@ -478,6 +478,13 @@ body:not(.white) a:visited{
padding: 0 4px;
}
.message .content blockquote {
padding: 8px 16px;
margin-bottom: 16px;
color: inherit;
border-left: .25em solid var(--colour-4);
}
.media_player {
display: none;
}
@ -501,6 +508,36 @@ body:not(.white) a:visited{
cursor: pointer;
}
.media-select {
display: flex;
flex-direction: row;
}
.media-select label, .media-select img, .media-select button {
display: flex;
gap: 18px;
align-items: center;
cursor: pointer;
user-select: none;
color: var(--colour-1);
background: var(--colour-4);
border: 1px solid var(--colour-1);
transition: all 0.2s ease;
width: auto;
height: 60px;
margin: 2px;
}
.media-select label, .media-select button {
padding: 8px 12px;
border-radius: var(--border-radius-1);
}
.media-select button.close {
order: 1000;
height: 32px;
}
.count_total {
font-size: 12px;
padding-left: 25px;
@ -691,7 +728,7 @@ input-count .text {
border-color: var(--accent);
}
label[for="image"] {
label.image-label {
top: 32px;
}
@ -699,11 +736,6 @@ label[for="micro"] {
top: 54px;
}
label[for="camera"] {
top: 74px;
display: none;
}
@media (pointer:none), (pointer:coarse) {
label[for="camera"] {
display: block;

View file

@ -7,7 +7,9 @@ const regenerate_button = document.querySelector(`.regenerate`);
const sidebar = document.querySelector(".conversations");
const sidebar_button = document.querySelector(".mobile-sidebar");
const sendButton = document.getElementById("send-button");
const imageInput = document.getElementById("image");
const imageInput = document.querySelector(".image-label");
const mediaSelect = document.querySelector(".media-select");
const imageSelect = document.getElementById("image");
const cameraInput = document.getElementById("camera");
const fileInput = document.getElementById("file");
const microLabel = document.querySelector(".micro-label");
@ -40,6 +42,7 @@ let usage_storage = {};
let reasoning_storage = {};
let generate_storage = {};
let title_ids_storage = {};
let image_storage = {};
let is_demo = false;
let wakeLock = null;
let countTokensEnabled = true;
@ -77,6 +80,8 @@ if (window.markdownit) {
.replaceAll('<code>', '<code class="language-plaintext">')
.replaceAll('&lt;i class=&quot;', '<i class="')
.replaceAll('&quot;&gt;&lt;/i&gt;', '"></i>')
.replaceAll('&lt;iframe type=&quot;text/html&quot; src=&quot;', '<iframe type="text/html" frameborder="0" src="')
.replaceAll('&quot;&gt;&lt;/iframe&gt;', `?enablejsapi=1&origin=${new URL(location.href).origin}` + '"></iframe>')
}
}
@ -426,20 +431,6 @@ const handle_ask = async (do_ask_gpt = true) => {
let message_index = await add_message(window.conversation_id, "user", message);
let message_id = get_message_id();
let images = [];
if (do_ask_gpt) {
if (imageInput.dataset.objects) {
imageInput.dataset.objects.split(" ").forEach((object)=>URL.revokeObjectURL(object))
delete imageInput.dataset.objects;
}
const input = imageInput && imageInput.files.length > 0 ? imageInput : cameraInput
if (input.files.length > 0) {
for (const file of input.files) {
images.push(URL.createObjectURL(file));
}
imageInput.dataset.objects = images.join(" ");
}
}
const message_el = document.createElement("div");
message_el.classList.add("message");
message_el.dataset.index = message_index;
@ -452,7 +443,6 @@ const handle_ask = async (do_ask_gpt = true) => {
<div class="content" id="user_${message_id}">
<div class="content_inner">
${markdown_render(message)}
${images.map((object)=>`<img src="${object}" alt="Image upload">`).join("")}
</div>
<div class="count">
${countTokensEnabled ? count_words_and_tokens(message, get_selected_model()?.value) : ""}
@ -937,8 +927,6 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
html = markdown_render(message_storage[message_id]);
content_map.inner.innerHTML = html;
highlight(content_map.inner);
if (imageInput) imageInput.value = "";
if (cameraInput) cameraInput.value = "";
}
if (message_storage[message_id]) {
const message_provider = message_id in provider_storage ? provider_storage[message_id] : null;
@ -1032,8 +1020,6 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
} else {
api_key = get_api_key_by_provider(provider);
}
const input = imageInput && imageInput.files.length > 0 ? imageInput : cameraInput;
const files = input && input.files.length > 0 ? input.files : null;
const download_images = document.getElementById("download_images")?.checked;
let api_base;
if (provider == "Custom") {
@ -1056,7 +1042,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
api_key: api_key,
api_base: api_base,
ignored: ignored,
}, files, message_id, scroll, finish_message);
}, Object.values(image_storage), message_id, scroll, finish_message);
} catch (e) {
console.error(e);
if (e.name != "AbortError") {
@ -1948,7 +1934,7 @@ async function on_load() {
chatPrompt.value = document.getElementById("systemPrompt")?.value || "";
say_hello();
} else {
load_conversation(window.conversation_id);
//load_conversation(window.conversation_id);
}
load_conversations();
}
@ -2249,12 +2235,51 @@ async function load_version() {
setTimeout(load_version, 1000 * 60 * 60); // 1 hour
}
[imageInput, cameraInput].forEach((el) => {
el.addEventListener('click', async () => {
el.value = '';
if (imageInput.dataset.objects) {
imageInput.dataset.objects.split(" ").forEach((object) => URL.revokeObjectURL(object));
delete imageInput.dataset.objects
function renderMediaSelect() {
const oldImages = mediaSelect.querySelectorAll("a:has(img)");
oldImages.forEach((el)=>el.remove());
Object.entries(image_storage).forEach(([object_url, file]) => {
const link = document.createElement("a");
link.title = file.name;
const img = document.createElement("img");
img.src = object_url;
img.onclick = () => {
img.remove();
delete image_storage[object_url];
URL.revokeObjectURL(object_url)
}
img.onload = () => {
link.title += `\n${img.naturalWidth}x${img.naturalHeight}`;
};
link.appendChild(img);
mediaSelect.appendChild(link);
});
}
imageInput.onclick = () => {
mediaSelect.classList.toggle("hidden");
}
mediaSelect.querySelector(".close").onclick = () => {
if (Object.values(image_storage).length) {
        for (const key in image_storage) {
URL.revokeObjectURL(key);
}
image_storage = {};
renderMediaSelect();
} else {
mediaSelect.classList.add("hidden");
}
}
[imageSelect, cameraInput].forEach((el) => {
el.addEventListener('change', async () => {
if (el.files.length) {
Array.from(el.files).forEach((file) => {
image_storage[URL.createObjectURL(file)] = file;
});
el.value = "";
renderMediaSelect();
}
});
});
@ -2270,7 +2295,7 @@ cameraInput?.addEventListener("click", (e) => {
}
});
imageInput?.addEventListener("click", (e) => {
imageSelect?.addEventListener("click", (e) => {
if (window?.pywebview) {
e.preventDefault();
pywebview.api.choose_image();

View file

@ -883,8 +883,8 @@ demo_models = {
qwq_32b.name: [qwq_32b, [HuggingFace]],
llama_3_3_70b.name: [llama_3_3_70b, [HuggingFace]],
sd_3_5.name: [sd_3_5, [HuggingSpace, HuggingFace]],
flux_dev.name: [flux_dev, [PollinationsImage, HuggingSpace, HuggingFace, G4F]],
flux_schnell.name: [flux_schnell, [PollinationsImage, HuggingFace, HuggingSpace, G4F]],
flux_dev.name: [flux_dev, [PollinationsImage, HuggingFace, HuggingSpace]],
flux_schnell.name: [flux_schnell, [PollinationsImage, HuggingFace, HuggingSpace]],
}
# Create a list of all models and their providers

View file

@ -41,11 +41,10 @@ async def async_generator_to_list(generator: AsyncIterator) -> list:
return [item async for item in generator]
def to_sync_generator(generator: AsyncIterator, stream: bool = True) -> Iterator:
loop = get_running_loop(check_nested=False)
if not stream:
yield from asyncio.run(async_generator_to_list(generator))
return
loop = get_running_loop(check_nested=False)
new_loop = False
if loop is None:
loop = asyncio.new_event_loop()
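
For context, the reordered helper behaves the same from the caller's side: with stream=False it materializes the whole async generator via asyncio.run, otherwise it drives the generator chunk by chunk. A short usage sketch (the import path is an assumption based on the diff context):

```python
# Usage sketch for to_sync_generator: consume an async generator synchronously.
from typing import AsyncIterator

from g4f.providers.asyncio import to_sync_generator  # assumed module path

async def numbers() -> AsyncIterator[int]:
    for i in range(3):
        yield i

for value in to_sync_generator(numbers()):               # streamed, chunk by chunk
    print(value)

print(list(to_sync_generator(numbers(), stream=False)))  # collected eagerly
```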

View file

@ -19,8 +19,7 @@ def quote_url(url: str) -> str:
def quote_title(title: str) -> str:
if title:
title = " ".join(title.split())
return title.replace('[', '').replace(']', '')
return " ".join(title.split())
return ""
def format_link(url: str, title: str = None) -> str:
@ -161,11 +160,21 @@ class Sources(ResponseType):
self.list.append(source)
def __str__(self) -> str:
return "\n\n" + ("\n".join([
f"{idx+1}. {format_link(link['url'], link.get('title', None))}"
return "\n\n\n\n" + ("\n>\n".join([
f"> [{idx}] {format_link(link['url'], link.get('title', None))}"
for idx, link in enumerate(self.list)
]))
class YouTube(ResponseType):
def __init__(self, ids: list[str]) -> None:
self.ids = ids
def __str__(self) -> str:
return "\n\n" + ("\n".join([
f'<iframe type="text/html" src="https://www.youtube.com/embed/{id}"></iframe>'
for id in self.ids
]))
class BaseConversation(ResponseType):
def __str__(self) -> str:
return ""

View file

@ -10,7 +10,7 @@ from typing import Optional, Callable, AsyncIterator
from ..typing import Messages
from ..providers.helper import filter_none
from ..providers.asyncio import to_async_iterator
from ..providers.response import Reasoning
from ..providers.response import Reasoning, FinishReason
from ..providers.types import ProviderType
from ..cookies import get_cookies_dir
from .web_search import do_search, get_search_message
@ -38,11 +38,12 @@ def get_api_key_file(cls) -> Path:
async def async_iter_run_tools(provider: ProviderType, model: str, messages, tool_calls: Optional[list] = None, **kwargs):
# Handle web_search from kwargs
web_search = kwargs.get('web_search')
sources = None
if web_search:
try:
messages = messages.copy()
web_search = web_search if isinstance(web_search, str) and web_search != "true" else None
messages[-1]["content"] = await do_search(messages[-1]["content"], web_search)
messages[-1]["content"], sources = await do_search(messages[-1]["content"], web_search)
except Exception as e:
debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
# Keep web_search in kwargs for provider native support
@ -88,6 +89,8 @@ async def async_iter_run_tools(provider: ProviderType, model: str, messages, too
response = to_async_iterator(create_function(model=model, messages=messages, **kwargs))
async for chunk in response:
yield chunk
if sources is not None:
yield sources
def process_thinking_chunk(chunk: str, start_time: float = 0) -> tuple[float, list]:
"""Process a thinking chunk and return timing and results."""
@ -144,11 +147,12 @@ def iter_run_tools(
) -> AsyncIterator:
# Handle web_search from kwargs
web_search = kwargs.get('web_search')
sources = None
if web_search:
try:
messages = messages.copy()
web_search = web_search if isinstance(web_search, str) and web_search != "true" else None
messages[-1]["content"] = asyncio.run(do_search(messages[-1]["content"], web_search))
messages[-1]["content"], sources = asyncio.run(do_search(messages[-1]["content"], web_search))
except Exception as e:
debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
# Keep web_search in kwargs for provider native support
@ -198,6 +202,12 @@ def iter_run_tools(
thinking_start_time = 0
for chunk in iter_callback(model=model, messages=messages, provider=provider, **kwargs):
if isinstance(chunk, FinishReason):
if sources is not None:
yield sources
sources = None
yield chunk
continue
if not isinstance(chunk, str):
yield chunk
continue
@ -206,3 +216,6 @@ def iter_run_tools(
for result in results:
yield result
if sources is not None:
yield sources
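
The net effect of this hunk: do_search now returns (prompt, Sources), and the tool pipeline emits the Sources object as its own chunk, right after a FinishReason when the provider sends one, otherwise once the stream ends. A hedged sketch of a consumer that separates text from citations (the import path follows the relative imports in the diff):

```python
# Sketch: split a tool-pipeline stream into text and the trailing Sources chunk.
from g4f.providers.response import Sources  # path inferred from the diff

def split_stream(chunks):
    text_parts, sources = [], None
    for chunk in chunks:
        if isinstance(chunk, Sources):
            sources = chunk            # the citation block arrives as one chunk
        elif isinstance(chunk, str):
            text_parts.append(chunk)   # ordinary model output
    return "".join(text_parts), sources
```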

View file

@ -24,7 +24,7 @@ except:
from typing import Iterator
from ..cookies import get_cookies_dir
from ..providers.response import format_link
from ..providers.response import format_link, JsonMixin, Sources
from ..errors import MissingRequirementsError
from .. import debug
@ -33,11 +33,18 @@ Using the provided web search results, to write a comprehensive reply to the use
Make sure to cite sources using [[Number]](Url) notation after the reference. Example: [[0]](http://google.com)
"""
class SearchResults():
class SearchResults(JsonMixin):
def __init__(self, results: list, used_words: int):
self.results = results
self.used_words = used_words
@classmethod
def from_dict(cls, data: dict):
return cls(
[SearchResultEntry(**item) for item in data["results"]],
data["used_words"]
)
def __iter__(self):
yield from self.results
@ -57,7 +64,17 @@ class SearchResults():
def __len__(self) -> int:
return len(self.results)
class SearchResultEntry():
def get_sources(self) -> Sources:
return Sources([{"url": result.url, "title": result.title} for result in self.results])
def get_dict(self):
return {
"results": [result.get_dict() for result in self.results],
"used_words": self.used_words
}
class SearchResultEntry(JsonMixin):
def __init__(self, title: str, url: str, snippet: str, text: str = None):
self.title = title
self.url = url
@ -191,11 +208,11 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backen
return SearchResults(formatted_results, used_words)
async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_INSTRUCTIONS, **kwargs) -> str:
async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_INSTRUCTIONS, **kwargs) -> tuple[str, Sources]:
if instructions and instructions in prompt:
return prompt # We have already added search results
return prompt, None # We have already added search results
if prompt.startswith("##") and query is None:
return prompt # We have no search query
return prompt, None # We have no search query
if query is None:
query = prompt.strip().splitlines()[0] # Use the first line as the search query
json_bytes = json.dumps({"query": query, **kwargs}, sort_keys=True).encode(errors="ignore")
@ -203,14 +220,19 @@ async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_
bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / f"web_search" / f"{datetime.date.today()}"
bucket_dir.mkdir(parents=True, exist_ok=True)
cache_file = bucket_dir / f"{quote_plus(query[:20])}.{md5_hash}.cache"
search_results = None
if cache_file.exists():
with cache_file.open("r") as f:
search_results = f.read()
else:
try:
search_results = SearchResults.from_dict(json.loads(search_results))
except json.JSONDecodeError:
search_results = None
if search_results is None:
search_results = await search(query, **kwargs)
if search_results.results:
with cache_file.open("wb") as f:
f.write(str(search_results).encode(errors="replace"))
with cache_file.open("w") as f:
f.write(json.dumps(search_results.get_dict()))
if instructions:
new_prompt = f"""
{search_results}
@ -227,13 +249,12 @@ User request:
{prompt}
"""
debug.log(f"Web search: '{query.strip()[:50]}...'")
if isinstance(search_results, SearchResults):
debug.log(f"with {len(search_results.results)} Results {search_results.used_words} Words")
return new_prompt
return new_prompt, search_results.get_sources()
def get_search_message(prompt: str, raise_search_exceptions=False, **kwargs) -> str:
try:
return asyncio.run(do_search(prompt, **kwargs))
return asyncio.run(do_search(prompt, **kwargs))[0]
except (DuckDuckGoSearchException, MissingRequirementsError) as e:
if raise_search_exceptions:
raise e
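
Finally, the cache format change is round-trippable: a cached JSON blob rehydrates into a SearchResults instance, which can still produce the Sources citation block. A sketch under the assumption that the class is importable from g4f's web_search module:

```python
# Round-trip sketch for the new JSON search cache.
import json

from g4f.tools.web_search import SearchResults  # assumed module path

payload = {
    "results": [{
        "title": "Example result",
        "url": "https://example.com",
        "snippet": "A short snippet.",
        "text": None,
    }],
    "used_words": 4,
}

restored = SearchResults.from_dict(json.loads(json.dumps(payload)))
print(restored.get_sources())  # renders the new blockquote citation list
```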