Merge pull request #2680 from xtekky/2.1Feb

Improve custom model selection in the UI
Updates for the response of the BackendApi
Update of the demo model list
Improve web search tool
Moved copy_images to /image
This commit is contained in:
H Lohaus 2025-02-03 21:28:17 +01:00 committed by GitHub
commit cd84c49f82
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
40 changed files with 461 additions and 165 deletions

View file

@ -1,5 +1,5 @@
from g4f.providers.base_provider import AbstractProvider, AsyncProvider, AsyncGeneratorProvider
from g4f.image import ImageResponse
from g4f.providers.response import ImageResponse
from g4f.errors import MissingAuthError
class ProviderMock(AbstractProvider):

View file

@ -12,10 +12,10 @@ from datetime import datetime, timezone
from ..typing import AsyncResult, Messages, ImagesType
from ..requests.raise_for_status import raise_for_status
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..image import ImageResponse, to_data_uri
from ..image import to_data_uri
from ..cookies import get_cookies_dir
from .helper import format_prompt, format_image_prompt
from ..providers.response import JsonConversation, Reasoning
from ..providers.response import JsonConversation, ImageResponse
class Conversation(JsonConversation):
validated_value: str = None

View file

@ -24,10 +24,10 @@ from .openai.har_file import get_headers, get_har_files
from ..typing import CreateResult, Messages, ImagesType
from ..errors import MissingRequirementsError, NoValidHarFileError, MissingAuthError
from ..requests.raise_for_status import raise_for_status
from ..providers.response import BaseConversation, JsonConversation, RequestLogin, Parameters
from ..providers.response import BaseConversation, JsonConversation, RequestLogin, Parameters, ImageResponse
from ..providers.asyncio import get_running_loop
from ..requests import get_nodriver
from ..image import ImageResponse, to_bytes, is_accepted_format
from ..image import to_bytes, is_accepted_format
from .helper import get_last_user_message
from .. import debug

View file

@ -5,7 +5,7 @@ import time
import asyncio
from ..typing import AsyncResult, Messages
from ..image import ImageResponse
from ..providers.response import ImageResponse
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin

View file

@ -73,7 +73,6 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
}
await ws.send_str("42" + json.dumps(["perplexity_labs", message_data]))
last_message = 0
is_thinking = False
while True:
message = await ws.receive_str()
if message == "2":

View file

@ -6,7 +6,7 @@ import random
from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..image import ImageResponse
from ..providers.response import ImageResponse
class Prodia(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://app.prodia.com"

View file

@ -7,8 +7,9 @@ import uuid
from ..typing import AsyncResult, Messages, ImageType, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_prompt
from ..image import ImageResponse, ImagePreview, EXTENSIONS_MAP, to_bytes, is_accepted_format
from ..image import EXTENSIONS_MAP, to_bytes, is_accepted_format
from ..requests import StreamSession, FormData, raise_for_status, get_nodriver
from ..providers.response import ImagePreview, ImageResponse
from ..cookies import get_cookies
from ..errors import MissingRequirementsError, ResponseError
from .. import debug

View file

@ -48,18 +48,19 @@ class HuggingFaceAPI(OpenaiTemplate):
if model in cls.model_aliases:
model_name = cls.model_aliases[model]
api_base = f"https://api-inference.huggingface.co/models/{model_name}/v1"
if images is not None:
async with StreamSession(
timeout=30,
) as session:
async with session.get(f"https://huggingface.co/api/models/{model}") as response:
if response.status == 404:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
await raise_for_status(response)
model_data = await response.json()
pipeline_tag = model_data.get("pipeline_tag")
if pipeline_tag != "image-text-to-text":
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag={pipeline_tag}")
async with StreamSession(
timeout=30,
) as session:
async with session.get(f"https://huggingface.co/api/models/{model}") as response:
if response.status == 404:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
await raise_for_status(response)
model_data = await response.json()
pipeline_tag = model_data.get("pipeline_tag")
if images is None and pipeline_tag not in ("text-generation", "image-text-to-text"):
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}")
elif pipeline_tag != "image-text-to-text":
raise ModelNotSupportedError(f"Model does not support images: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}")
start = calculate_lenght(messages)
if start > max_inputs_lenght:
if len(messages) > 6:

View file

@ -9,14 +9,14 @@ from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_prompt
from ...errors import ModelNotFoundError, ModelNotSupportedError, ResponseError
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason
from ...image import ImageResponse
from ...providers.response import FinishReason, ImageResponse
from ..helper import format_image_prompt
from .models import default_model, default_image_model, model_aliases, fallback_models
from ... import debug
class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://huggingface.co"
parent = "HuggingFace"
working = True
default_model = default_model

View file

@ -4,7 +4,7 @@ import random
from ...typing import AsyncResult, Messages
from ...providers.response import ImageResponse
from ...errors import ModelNotSupportedError
from ...errors import ModelNotSupportedError, MissingAuthError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .HuggingChat import HuggingChat
from .HuggingFaceAPI import HuggingFaceAPI
@ -60,6 +60,6 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
try:
async for chunk in HuggingFaceAPI.create_async_generator(model, messages, **kwargs):
yield chunk
except ModelNotSupportedError:
except (ModelNotSupportedError, MissingAuthError):
async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):
yield chunk

View file

@ -4,7 +4,7 @@ import json
from aiohttp import ClientSession
from ...typing import AsyncResult, Messages
from ...image import ImageResponse, ImagePreview
from ...providers.response import ImageResponse, ImagePreview
from ...errors import ResponseError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_image_prompt

View file

@ -4,7 +4,7 @@ from aiohttp import ClientSession
import json
from ...typing import AsyncResult, Messages
from ...image import ImageResponse
from ...providers.response import ImageResponse
from ...errors import ResponseError
from ...requests.raise_for_status import raise_for_status
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin

View file

@ -4,7 +4,7 @@ import json
from aiohttp import ClientSession
from ...typing import AsyncResult, Messages
from ...image import ImageResponse, ImagePreview
from ...providers.response import ImageResponse, ImagePreview
from ...errors import ResponseError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_image_prompt

View file

@ -4,7 +4,7 @@ from aiohttp import ClientSession
import json
from ...typing import AsyncResult, Messages
from ...image import ImageResponse
from ...providers.response import ImageResponse
from ...errors import ResponseError
from ...requests.raise_for_status import raise_for_status
from ..helper import format_image_prompt

View file

@ -1,7 +1,7 @@
from __future__ import annotations
from ...cookies import get_cookies
from ...image import ImageResponse
from ...providers.response import ImageResponse
from ...errors import MissingAuthError
from ...typing import AsyncResult, Messages, Cookies
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import requests
from ...typing import AsyncResult, Messages
from ...requests import StreamSession, raise_for_status
from ...image import ImageResponse
from ...providers.response import ImageResponse
from ..template import OpenaiTemplate
from ..helper import format_image_prompt

View file

@ -48,12 +48,13 @@ try:
raise MissingAuthError()
response.raise_for_status()
return response.json()
has_dsk = True
except ImportError:
pass
has_dsk = False
class DeepSeekAPI(AsyncAuthedProvider):
url = "https://chat.deepseek.com"
working = False
working = has_dsk
needs_auth = True
use_nodriver = True
_access_token = None
@ -91,6 +92,7 @@ class DeepSeekAPI(AsyncAuthedProvider):
if conversation is None:
chat_id = api.create_chat_session()
conversation = JsonConversation(chat_id=chat_id)
yield conversation
is_thinking = 0
for chunk in api.chat_completion(

View file

@ -19,12 +19,12 @@ from ... import debug
from ...typing import Messages, Cookies, ImagesType, AsyncResult, AsyncIterator
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt, get_cookies
from ...providers.response import JsonConversation, SynthesizeData, RequestLogin
from ...providers.response import JsonConversation, SynthesizeData, RequestLogin, ImageResponse
from ...requests.raise_for_status import raise_for_status
from ...requests.aiohttp import get_connector
from ...requests import get_nodriver
from ...errors import MissingAuthError
from ...image import ImageResponse, to_bytes
from ...image import to_bytes
from ..helper import get_last_user_message
from ... import debug

View file

@ -10,7 +10,7 @@ from aiohttp import ClientSession, BaseConnector
from ...typing import AsyncResult, Messages, Cookies
from ...requests import raise_for_status, DEFAULT_HEADERS
from ...image import ImageResponse, ImagePreview
from ...providers.response import ImageResponse, ImagePreview
from ...errors import ResponseError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt, get_connector, format_cookies

View file

@ -6,7 +6,7 @@ import random
import asyncio
import json
from ...image import ImageResponse
from ...providers.response import ImageResponse
from ...errors import MissingRequirementsError, NoValidHarFileError
from ...typing import AsyncResult, Messages
from ...requests.raise_for_status import raise_for_status

View file

@ -22,9 +22,9 @@ from ...typing import AsyncResult, Messages, Cookies, ImagesType
from ...requests.raise_for_status import raise_for_status
from ...requests import StreamSession
from ...requests import get_nodriver
from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format
from ...image import ImageRequest, to_image, to_bytes, is_accepted_format
from ...errors import MissingAuthError, NoValidHarFileError
from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult
from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult, ImageResponse
from ...providers.response import Sources, TitleGeneration, RequestLogin, Parameters, Reasoning
from ..helper import format_cookies
from ..openai.models import default_model, default_image_model, models, image_models, text_models
@ -526,7 +526,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
c = m.get("content", {})
if c.get("content_type") == "text" and m.get("author", {}).get("role") == "tool" and "initial_text" in m.get("metadata", {}):
fields.is_thinking = True
yield Reasoning(status=c.get("metadata", {}).get("initial_text"))
yield Reasoning(status=m.get("metadata", {}).get("initial_text"))
if c.get("content_type") == "multimodal_text":
generated_images = []
for element in c.get("parts"):

View file

@ -5,7 +5,7 @@ import base64
from aiohttp import ClientSession
from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ...image import ImageResponse
from ...providers.response import ImageResponse
from ..helper import format_prompt
class AiChats(AsyncGeneratorProvider, ProviderModelMixin):

View file

@ -6,8 +6,7 @@ from aiohttp import ClientSession
from typing import List
from ...typing import AsyncResult, Messages
from ...image import ImageResponse
from ...providers.response import FinishReason, Usage
from ...providers.response import ImageResponse, FinishReason, Usage
from ...requests.raise_for_status import raise_for_status
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin

View file

@ -5,7 +5,7 @@ import uuid
from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ...image import ImageResponse
from ...providers.response import ImageResponse
from ...requests import StreamSession, raise_for_status
from ...errors import ResponseStatusError

View file

@ -9,7 +9,7 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ...requests.aiohttp import get_connector
from ...requests.raise_for_status import raise_for_status
from ..helper import format_prompt
from ...image import ImageResponse
from ...providers.response import ImageResponse
class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://replicate.com"

View file

@ -10,6 +10,7 @@ from ... import debug
class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
ssl = None
headers = {}
@classmethod
async def create_async_generator(
@ -21,7 +22,7 @@ class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
) -> AsyncResult:
debug.log(f"{cls.__name__}: {api_key}")
async with StreamSession(
headers={"Accept": "text/event-stream"},
headers={"Accept": "text/event-stream", **cls.headers},
) as session:
async with session.post(f"{cls.url}/backend-api/v2/conversation", json={
"model": model,

View file

@ -39,7 +39,8 @@ import g4f.debug
from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, convert_to_provider
from g4f.providers.response import BaseConversation, JsonConversation
from g4f.client.helper import filter_none
from g4f.image import is_data_uri_an_image, images_dir, copy_images
from g4f.image import is_data_uri_an_image
from g4f.image.copy_images import images_dir, copy_images
from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError
from g4f.cookies import read_cookie_files, get_cookies_dir
from g4f.Provider import ProviderType, ProviderUtils, __providers__

View file

@ -9,10 +9,10 @@ import aiohttp
import base64
from typing import Union, AsyncIterator, Iterator, Awaitable, Optional
from ..image import ImageResponse, copy_images
from ..image.copy_images import copy_images
from ..typing import Messages, ImageType
from ..providers.types import ProviderType, BaseRetryProvider
from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData, ToolCalls, Usage
from ..providers.response import ResponseType, ImageResponse, FinishReason, BaseConversation, SynthesizeData, ToolCalls, Usage
from ..errors import NoImageResponseError
from ..providers.retry_provider import IterListProvider
from ..providers.asyncio import to_sync_generator

View file

@ -165,7 +165,7 @@
</div>
<div class="field box">
<label for="systemPrompt" class="label">System prompt</label>
<textarea id="systemPrompt" placeholder="You are a helpful assistant." data-value="If you need to generate images, you can use the following format: ![keywords](/generate/filename.jpg). This will enable the use of an image generation tool."></textarea>
<textarea id="systemPrompt" placeholder="You are a helpful assistant." data-example="If you need to generate images, you can use the following format: ![keywords](/generate/filename.jpg). This will enable the use of an image generation tool."></textarea>
</div>
<div class="field box">
<label for="message-input-height" class="label" title="">Input max. height</label>

View file

@ -36,7 +36,8 @@ let title_storage = {};
let parameters_storage = {};
let finish_storage = {};
let usage_storage = {};
let reasoning_storage = {}
let reasoning_storage = {};
let generate_storage = {};
let is_demo = false;
messageInput.addEventListener("blur", () => {
@ -96,9 +97,13 @@ function filter_message(text) {
}
function filter_message_content(text) {
return text.replace(/ \[aborted\]$/g, "").replace(/ \[error\]$/g, "")
}
function filter_message_image(text) {
return text.replaceAll(
/\/\]\(\/generate\//gm, "/](/images/"
).replace(/ \[aborted\]$/g, "").replace(/ \[error\]$/g, "")
/\]\(\/generate\//gm, "](/images/"
)
}
function fallback_clipboard (text) {
@ -204,6 +209,7 @@ function register_message_images() {
el.onerror = () => {
let indexCommand;
if ((indexCommand = el.src.indexOf("/generate/")) >= 0) {
generate_storage[window.conversation_id] = true;
indexCommand = indexCommand + "/generate/".length + 1;
let newPath = el.src.substring(indexCommand)
let filename = newPath.replace(/(?:\?.+?|$)/, "");
@ -973,7 +979,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
await add_message(
window.conversation_id,
"assistant",
final_message,
filter_message_image(final_message),
message_provider,
message_index,
synthesize_storage[message_id],
@ -999,7 +1005,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
delete controller_storage[message_id];
}
// Reload conversation if no error
if (!error_storage[message_id]) {
if (!error_storage[message_id] && !generate_storage[window.conversation_id]) {
await safe_load_conversation(window.conversation_id, scroll);
}
let cursorDiv = message_el.querySelector(".cursor");
@ -1022,14 +1028,23 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
} else {
api_key = get_api_key_by_provider(provider);
}
if (is_demo && !api_key && provider != "Custom") {
if (is_demo && !api_key) {
api_key = localStorage.getItem("HuggingFace-api_key");
}
if (is_demo && !api_key) {
location.href = "/";
return;
}
const input = imageInput && imageInput.files.length > 0 ? imageInput : cameraInput;
const files = input && input.files.length > 0 ? input.files : null;
const download_images = document.getElementById("download_images")?.checked;
const api_base = provider == "Custom" ? document.getElementById(`${provider}-api_base`).value : null;
let api_base;
if (provider == "Custom") {
api_base = document.getElementById("api_base")?.value;
if (!api_base) {
provider = "";
}
}
const ignored = Array.from(settings.querySelectorAll("input.provider:not(:checked)")).map((el)=>el.value);
await api("conversation", {
id: message_id,
@ -1886,6 +1901,10 @@ async function on_load() {
load_conversation(window.conversation_id);
} else {
chatPrompt.value = document.getElementById("systemPrompt")?.value || "";
example = document.getElementById("systemPrompt")?.dataset.example || ""
if (chatPrompt.value == example) {
messageInput.value = "";
}
let chat_url = new URL(window.location.href)
let chat_params = new URLSearchParams(chat_url.search);
if (chat_params.get("prompt")) {
@ -2493,19 +2512,23 @@ async function load_provider_models(provider=null) {
if (!custom_model.value) {
custom_model.classList.add("hidden");
}
if (provider == "Custom Model" || custom_model.value) {
if (provider.startsWith("Custom") || custom_model.value) {
modelProvider.classList.add("hidden");
modelSelect.classList.add("hidden");
document.getElementById("model3").classList.remove("hidden");
custom_model.classList.remove("hidden");
return;
}
modelProvider.innerHTML = '';
modelProvider.name = `model[${provider}]`;
if (!provider) {
modelProvider.classList.add("hidden");
modelSelect.classList.remove("hidden");
document.getElementById("model3").value = "";
document.getElementById("model3").classList.remove("hidden");
if (custom_model.value) {
modelSelect.classList.add("hidden");
custom_model.classList.remove("hidden");
} else {
modelSelect.classList.remove("hidden");
custom_model.classList.add("hidden");
}
return;
}
const models = await api('models', provider);
@ -2531,11 +2554,9 @@ async function load_provider_models(provider=null) {
modelProvider.value = value;
}
modelProvider.selectedIndex = defaultIndex;
} else if (custom_model.value) {
modelSelect.classList.add("hidden");
} else {
modelProvider.classList.add("hidden");
modelSelect.classList.remove("hidden");
custom_model.classList.remove("hidden")
}
};
providerSelect.addEventListener("change", () => {

View file

@ -8,7 +8,7 @@ from flask import send_from_directory
from inspect import signature
from ...errors import VersionNotFoundError
from ...image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir
from ...image.copy_images import copy_images, ensure_images_dir, images_dir
from ...tools.run_tools import iter_run_tools
from ...Provider import ProviderUtils, __providers__
from ...providers.base_provider import ProviderModelMixin
@ -182,7 +182,9 @@ class Api:
elif isinstance(chunk, Exception):
logger.exception(chunk)
yield self._format_json('message', get_error_message(chunk), error=type(chunk).__name__)
elif isinstance(chunk, (PreviewResponse, ImagePreview)):
elif isinstance(chunk, PreviewResponse):
yield self._format_json("preview", chunk.to_string())
elif isinstance(chunk, ImagePreview):
yield self._format_json("preview", chunk.to_string(), images=chunk.images, alt=chunk.alt)
elif isinstance(chunk, ImageResponse):
images = chunk
@ -207,6 +209,8 @@ class Api:
yield self._format_json("reasoning", **chunk.get_dict())
elif isinstance(chunk, DebugResponse):
yield self._format_json("log", chunk.log)
elif isinstance(chunk, RawResponse):
yield self._format_json(chunk.type, **chunk.get_dict())
else:
yield self._format_json("content", str(chunk))
if debug.logs:
@ -215,6 +219,8 @@ class Api:
debug.logs = []
except Exception as e:
logger.exception(e)
if debug.logging:
debug.log_handler(get_error_message(e))
if debug.logs:
for log in debug.logs:
yield self._format_json("log", str(log))
@ -222,7 +228,7 @@ class Api:
yield self._format_json('error', type(e).__name__, message=get_error_message(e))
def _format_json(self, response_type: str, content = None, **kwargs):
if content is not None:
if content is not None and isinstance(response_type, str):
return {
'type': response_type,
response_type: content,

View file

@ -140,9 +140,9 @@ class Backend_Api(Api):
if model != "default" and model in models.demo_models:
json_data["provider"] = random.choice(models.demo_models[model][1])
else:
json_data["model"] = models.demo_models["default"][0].name
if not model or model == "default":
json_data["model"] = models.demo_models["default"][0].name
json_data["provider"] = random.choice(models.demo_models["default"][1])
kwargs = self._prepare_conversation_kwargs(json_data, kwargs)
return self.app.response_class(
self._create_response_stream(

View file

@ -3,15 +3,10 @@ from __future__ import annotations
import os
import re
import io
import time
import uuid
import base64
import asyncio
import hashlib
from urllib.parse import quote_plus
from io import BytesIO
from pathlib import Path
from aiohttp import ClientSession, ClientError
try:
from PIL.Image import open as open_image, new as new_image
from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90
@ -21,9 +16,7 @@ except ImportError:
from .typing import ImageType, Union, Image, Optional, Cookies
from .errors import MissingRequirementsError
from .providers.response import ImageResponse, ImagePreview
from .requests.aiohttp import get_connector
from . import debug
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'}
@ -237,68 +230,6 @@ def to_data_uri(image: ImageType) -> str:
return f"data:{is_accepted_format(data)};base64,{data_base64}"
return image
# Function to ensure the images directory exists
def ensure_images_dir():
os.makedirs(images_dir, exist_ok=True)
def get_image_extension(image: str) -> str:
match = re.search(r"\.(?:jpe?g|png|webp)", image)
if match:
return match.group(0)
return ".jpg"
async def copy_images(
images: list[str],
cookies: Optional[Cookies] = None,
headers: Optional[dict] = None,
proxy: Optional[str] = None,
alt: str = None,
add_url: bool = True,
target: str = None,
ssl: bool = None
) -> list[str]:
if add_url:
add_url = not cookies
ensure_images_dir()
async with ClientSession(
connector=get_connector(proxy=proxy),
cookies=cookies,
headers=headers,
) as session:
async def copy_image(image: str, target: str = None) -> str:
if target is None or len(images) > 1:
hash = hashlib.sha256(image.encode()).hexdigest()
target = f"{quote_plus('+'.join(alt.split()[:10])[:100], '')}_{hash}" if alt else str(uuid.uuid4())
target = f"{int(time.time())}_{target}{get_image_extension(image)}"
target = os.path.join(images_dir, target)
try:
if image.startswith("data:"):
with open(target, "wb") as f:
f.write(extract_data_uri(image))
else:
try:
async with session.get(image, ssl=ssl) as response:
response.raise_for_status()
with open(target, "wb") as f:
async for chunk in response.content.iter_chunked(4096):
f.write(chunk)
except ClientError as e:
debug.log(f"copy_images failed: {e.__class__.__name__}: {e}")
return image
if "." not in target:
with open(target, "rb") as f:
extension = is_accepted_format(f.read(12)).split("/")[-1]
extension = "jpg" if extension == "jpeg" else extension
new_target = f"{target}.{extension}"
os.rename(target, new_target)
target = new_target
finally:
if "." not in target and os.path.exists(target):
os.unlink(target)
return f"/images/{os.path.basename(target)}{'?url=' + image if add_url and not image.startswith('data:') else ''}"
return await asyncio.gather(*[copy_image(image, target) for image in images])
class ImageDataResponse():
def __init__(
self,

247
g4f/image/__init__.py Normal file
View file

@ -0,0 +1,247 @@
from __future__ import annotations
import os
import re
import io
import base64
from io import BytesIO
from pathlib import Path
try:
from PIL.Image import open as open_image, new as new_image
from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90
has_requirements = True
except ImportError:
has_requirements = False
from ..typing import ImageType, Union, Image
from ..errors import MissingRequirementsError
# File extensions accepted for uploaded or referenced images.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'}

# Maps image MIME types to their canonical file extensions.
EXTENSIONS_MAP: dict[str, str] = {
    "image/png": "png",
    "image/jpeg": "jpg",
    "image/gif": "gif",
    "image/webp": "webp",
}
def to_image(image: ImageType, is_svg: bool = False) -> Image:
    """
    Converts the input image to a PIL Image object.
    Args:
        image (Union[str, bytes, Image]): The input image.
        is_svg (bool): Treat the input as SVG and rasterize it via cairosvg.
    Returns:
        Image: The converted PIL Image object.
    Raises:
        MissingRequirementsError: If pillow (or cairosvg for SVG input) is missing.
    """
    if not has_requirements:
        raise MissingRequirementsError('Install "pillow" package for images')
    if isinstance(image, str) and image.startswith("data:"):
        # Validate the data URI first, then decode its base64 payload to bytes.
        is_data_uri_an_image(image)
        image = extract_data_uri(image)
    if is_svg:
        try:
            import cairosvg
        except ImportError:
            raise MissingRequirementsError('Install "cairosvg" package for svg images')
        if not isinstance(image, bytes):
            # Assumes a file-like object here — TODO confirm callers never pass a path.
            image = image.read()
        # Render the SVG to PNG in memory; PIL cannot open SVG directly.
        buffer = BytesIO()
        cairosvg.svg2png(image, write_to=buffer)
        return open_image(buffer)
    if isinstance(image, bytes):
        # Raises ValueError when the magic bytes are not a supported format.
        is_accepted_format(image)
        return open_image(BytesIO(image))
    elif not isinstance(image, Image):
        # Path or file-like object: open it and force the data to be read now.
        image = open_image(image)
        image.load()
        return image
    return image
def is_allowed_extension(filename: str) -> bool:
    """
    Tell whether *filename* ends in one of the permitted image extensions.

    Args:
        filename (str): The filename to check.

    Returns:
        bool: True if the extension is allowed, False otherwise.
    """
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in ALLOWED_EXTENSIONS
def is_data_uri_an_image(data_uri: str) -> bool:
    """
    Checks if the given data URI represents an image.

    Args:
        data_uri (str): The data URI to check.

    Raises:
        ValueError: If the data URI is invalid or the image format is not allowed.
    """
    # Match once and reuse the result (the original matched the same pattern
    # twice). Using '[\w+]' instead of '\w' lets 'svg+xml' match, so the
    # explicit svg+xml allowance below is actually reachable.
    match = re.match(r'data:image/([\w+]+);base64,', data_uri)
    if not match:
        raise ValueError("Invalid data URI image.")
    # Extract the image format from the data URI
    image_format = match.group(1).lower()
    # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
    if image_format not in ALLOWED_EXTENSIONS and image_format != "svg+xml":
        raise ValueError("Invalid image format (from mime file type).")
def is_accepted_format(binary_data: bytes) -> str:
    """
    Checks if the given binary data represents an image with an accepted format.

    Args:
        binary_data (bytes): The binary data to check.

    Returns:
        str: The detected MIME type.

    Raises:
        ValueError: If the image format is not allowed.
    """
    # Magic-byte signatures checked in order; the 3-byte JPEG marker must be
    # tested before the bare 2-byte one.
    signatures = (
        (b'\xFF\xD8\xFF', "image/jpeg"),
        (b'\x89PNG\r\n\x1a\n', "image/png"),
        (b'GIF87a', "image/gif"),
        (b'GIF89a', "image/gif"),
        (b'\x89JFIF', "image/jpeg"),
        (b'JFIF\x00', "image/jpeg"),
        (b'\xFF\xD8', "image/jpeg"),
    )
    for prefix, mime in signatures:
        if binary_data.startswith(prefix):
            return mime
    # WEBP is a RIFF container: 'RIFF' header plus 'WEBP' at offset 8.
    if binary_data.startswith(b'RIFF') and binary_data[8:12] == b'WEBP':
        return "image/webp"
    raise ValueError("Invalid image format (from magic code).")
def extract_data_uri(data_uri: str) -> bytes:
    """
    Extracts the binary data from the given data URI.

    Args:
        data_uri (str): The data URI.

    Returns:
        bytes: The extracted binary data.
    """
    # Everything after the last comma is the base64 payload; when no comma is
    # present the whole string is treated as the payload.
    _, _, encoded = data_uri.rpartition(",")
    return base64.b64decode(encoded)
def get_orientation(image: Image) -> int:
    """
    Gets the orientation of the given image.

    Args:
        image (Image): The image.

    Returns:
        int: The orientation value, or None if no EXIF orientation is present.
    """
    # Prefer the public getexif() API; fall back to the legacy private method.
    if hasattr(image, 'getexif'):
        exif_data = image.getexif()
    else:
        exif_data = image._getexif()
    if exif_data is None:
        return None
    # 274 is the EXIF tag id for image orientation.
    return exif_data.get(274)
def process_image(image: Image, new_width: int, new_height: int) -> Image:
    """
    Processes the given image by adjusting its orientation and resizing it.
    Args:
        image (Image): The image to process.
        new_width (int): The new width of the image.
        new_height (int): The new height of the image.
    Returns:
        Image: The processed image.
    """
    # Fix orientation
    orientation = get_orientation(image)
    if orientation:
        # EXIF orientation values 5-8 are mirrored variants, so flip first,
        # then apply the rotation for the corresponding value pair.
        if orientation > 4:
            image = image.transpose(FLIP_LEFT_RIGHT)
        if orientation in [3, 4]:
            image = image.transpose(ROTATE_180)
        if orientation in [5, 6]:
            image = image.transpose(ROTATE_270)
        if orientation in [7, 8]:
            image = image.transpose(ROTATE_90)
    # Resize image
    # thumbnail() resizes in place, keeping the aspect ratio within the bounds.
    image.thumbnail((new_width, new_height))
    # Remove transparency
    if image.mode == "RGBA":
        image.load()
        # Composite onto a white background using the alpha channel as mask.
        white = new_image('RGB', image.size, (255, 255, 255))
        white.paste(image, mask=image.split()[-1])
        return white
    # Convert to RGB for jpg format
    elif image.mode != "RGB":
        image = image.convert("RGB")
    return image
def to_bytes(image: ImageType) -> bytes:
    """
    Converts the given image to bytes.
    Args:
        image (ImageType): The image to convert.
    Returns:
        bytes: The image as bytes.
    """
    if isinstance(image, bytes):
        return image
    elif isinstance(image, str) and image.startswith("data:"):
        # Validate the data URI, then decode its base64 payload.
        is_data_uri_an_image(image)
        return extract_data_uri(image)
    elif isinstance(image, Image):
        bytes_io = BytesIO()
        image.save(bytes_io, image.format)
        # NOTE(review): this seeks the PIL image, not the buffer — probably
        # meant bytes_io.seek(0); harmless here since getvalue() returns the
        # full buffer regardless of position.
        image.seek(0)
        return bytes_io.getvalue()
    elif isinstance(image, (str, os.PathLike)):
        return Path(image).read_bytes()
    elif isinstance(image, Path):
        # NOTE(review): unreachable — Path is an os.PathLike, so the branch
        # above already handles it.
        return image.read_bytes()
    else:
        # File-like object: rewind if the stream supports it, then read all.
        try:
            image.seek(0)
        except (AttributeError, io.UnsupportedOperation):
            pass
        return image.read()
def to_data_uri(image: ImageType) -> str:
if not isinstance(image, str):
data = to_bytes(image)
data_base64 = base64.b64encode(data).decode()
return f"data:{is_accepted_format(data)};base64,{data_base64}"
return image
class ImageDataResponse():
    """Container pairing one or more image references with alt text."""

    def __init__(
        self,
        images: Union[str, list],
        alt: str,
    ):
        # A single image may be passed as a bare string.
        self.images = images
        self.alt = alt

    def get_list(self) -> list[str]:
        """Return the images as a list, wrapping a lone string in one."""
        if isinstance(self.images, str):
            return [self.images]
        return self.images
class ImageRequest:
    """Lightweight wrapper around a dict of image-request options."""

    def __init__(
        self,
        options: dict = None
    ):
        # Fixed: the previous mutable default ({}) was shared across all
        # instances created without arguments; use a None sentinel instead.
        self.options = {} if options is None else options

    def get(self, key: str):
        """Return the option stored under ``key``, or None if absent."""
        return self.options.get(key)

84
g4f/image/copy_images.py Normal file
View file

@ -0,0 +1,84 @@
from __future__ import annotations
import os
import time
import uuid
import asyncio
import hashlib
import re
from urllib.parse import quote_plus
from aiohttp import ClientSession, ClientError
from ..typing import Optional, Cookies
from ..requests.aiohttp import get_connector
from ..Provider.template import BackendApi
from . import is_accepted_format, extract_data_uri
from .. import debug
# Define the directory for generated images
images_dir = "./generated_images"
def get_image_extension(image: str) -> str:
    """
    Extract a known image file extension from a URL or filename.

    Args:
        image (str): The URL or filename to inspect.

    Returns:
        str: The lowercased extension including the dot (".jpg", ".jpeg",
        ".png" or ".webp"), or ".jpg" when no known extension is found.
    """
    # Case-insensitive so uppercase URLs like ".PNG" / ".JPG" are recognized.
    match = re.search(r"\.(?:jpe?g|png|webp)", image, re.IGNORECASE)
    if match:
        # Lowercase for consistent local filenames regardless of source casing.
        return match.group(0).lower()
    return ".jpg"
def ensure_images_dir():
    """Create the generated-images directory if it does not exist yet."""
    # exist_ok keeps this race-tolerant if another task creates it first.
    if not os.path.isdir(images_dir):
        os.makedirs(images_dir, exist_ok=True)
async def copy_images(
    images: list[str],
    cookies: Optional[Cookies] = None,
    headers: Optional[dict] = None,
    proxy: Optional[str] = None,
    alt: str = None,
    add_url: bool = True,
    target: str = None,
    ssl: bool = None
) -> list[str]:
    """
    Download or decode the given images into the local images directory and
    return local "/images/<name>" URLs for them.

    Args:
        images: Image sources — "data:" URIs or HTTP(S) URLs.
        cookies: Cookies for the download session; when given, the original
            URL is not appended to the returned path (see add_url below).
        headers: Extra request headers for the download session.
        proxy: Proxy URL for the aiohttp connector.
        alt: Alt text used to build a readable target filename.
        add_url: Append "?url=<source>" to the returned path for non-data
            sources; disabled automatically when cookies are supplied.
        target: Explicit target path; only honored for a single image.
        ssl: SSL verification flag passed to aiohttp.

    Returns:
        list[str]: One local path (or the untouched source URL on download
        failure) per input image, in input order.
    """
    if add_url:
        # With cookies the source URL may be access-restricted; don't leak it.
        add_url = not cookies
    ensure_images_dir()
    async with ClientSession(
        connector=get_connector(proxy=proxy),
        cookies=cookies,
        headers=headers,
    ) as session:
        # headers/ssl are bound as defaults so each call gets the outer values
        # but can override them for BackendApi-hosted images below.
        async def copy_image(image: str, target: str = None, headers: dict = headers, ssl: bool = ssl) -> str:
            if target is None or len(images) > 1:
                # Derive a unique filename: alt-text slug + content hash,
                # or a random UUID when no alt text is available.
                # NOTE(review): "hash" shadows the builtin of the same name.
                hash = hashlib.sha256(image.encode()).hexdigest()
                target = f"{quote_plus('+'.join(alt.split()[:10])[:100], '')}_{hash}" if alt else str(uuid.uuid4())
                target = f"{int(time.time())}_{target}{get_image_extension(image)}"
            target = os.path.join(images_dir, target)
            try:
                if image.startswith("data:"):
                    # Inline data URI: decode straight to disk.
                    with open(target, "wb") as f:
                        f.write(extract_data_uri(image))
                else:
                    try:
                        # Images served by the BackendApi need its auth headers
                        # and SSL setting unless the caller supplied their own.
                        if BackendApi.working and image.startswith(BackendApi.url) and headers is None:
                            headers = BackendApi.headers
                            ssl = BackendApi.ssl
                        async with session.get(image, ssl=ssl, headers=headers) as response:
                            response.raise_for_status()
                            with open(target, "wb") as f:
                                async for chunk in response.content.iter_chunked(4096):
                                    f.write(chunk)
                    except ClientError as e:
                        # Best-effort: on download failure fall back to the
                        # original URL instead of raising.
                        debug.log(f"copy_images failed: {e.__class__.__name__}: {e}")
                        return image
                if "." not in target:
                    # No extension was derivable from the URL: sniff the file
                    # header and rename accordingly.
                    with open(target, "rb") as f:
                        extension = is_accepted_format(f.read(12)).split("/")[-1]
                        extension = "jpg" if extension == "jpeg" else extension
                    new_target = f"{target}.{extension}"
                    os.rename(target, new_target)
                    target = new_target
            finally:
                # Remove an extensionless leftover (e.g. sniffing failed midway).
                if "." not in target and os.path.exists(target):
                    os.unlink(target)
            return f"/images/{os.path.basename(target)}{'?url=' + image if add_url and not image.startswith('data:') else ''}"
        # Process all images concurrently, preserving input order.
        return await asyncio.gather(*[copy_image(image, target) for image in images])

View file

@ -765,11 +765,11 @@ class ModelUtils:
demo_models = {
gpt_4o.name: [gpt_4o, [PollinationsAI, Blackbox]],
"default": [llama_3_2_11b, [HuggingFaceAPI]],
"default": [llama_3_2_11b, [HuggingFace]],
qwen_2_vl_7b.name: [qwen_2_vl_7b, [HuggingFaceAPI]],
qvq_72b.name: [qvq_72b, [HuggingSpace, HuggingFaceAPI]],
deepseek_r1.name: [deepseek_r1, [HuggingFace, HuggingFaceAPI]],
claude_3_haiku.name: [claude_3_haiku, [DDG, Jmuz]],
qvq_72b.name: [qvq_72b, [HuggingSpace]],
deepseek_r1.name: [deepseek_r1, [HuggingFace]],
claude_3_5_sonnet.name: [claude_3_5_sonnet, claude_3_5_sonnet.best_provider.providers],
command_r.name: [command_r, [HuggingSpace]],
command_r_plus.name: [command_r_plus, [HuggingSpace]],
command_r7b.name: [command_r7b, [HuggingSpace]],
@ -779,7 +779,7 @@ demo_models = {
llama_3_3_70b.name: [llama_3_3_70b, [HuggingFace]],
sd_3_5.name: [sd_3_5, [HuggingSpace, HuggingFace]],
flux_dev.name: [flux_dev, [HuggingSpace, HuggingFace]],
flux_schnell.name: [flux_schnell, [HuggingFace]],
flux_schnell.name: [flux_schnell, [HuggingFace, HuggingSpace, PollinationsAI]],
}
# Create a list of all models and his providers

View file

@ -6,7 +6,7 @@ import asyncio
from .. import debug
from ..typing import CreateResult, Messages
from .types import BaseProvider, ProviderType
from ..image import ImageResponse
from ..providers.response import ImageResponse
system_message = """
You can generate images, pictures, photos or img with the DALL-E 3 image generator.

View file

@ -137,7 +137,7 @@ async def get_nodriver(
timeout: int = 120,
browser_executable_path=None,
**kwargs
) -> Browser:
) -> tuple[Browser, callable]:
if not has_nodriver:
raise MissingRequirementsError('Install "nodriver" and "platformdirs" package | pip install -U nodriver platformdirs')
user_data_dir = user_config_dir(f"g4f-{user_data_dir}") if has_platformdirs else None

View file

@ -159,26 +159,27 @@ def iter_run_tools(
chunk = chunk.split("<think>", 1)
if len(chunk) > 0 and chunk[0]:
yield chunk[0]
yield Reasoning(None, "🤔 Is thinking...", is_thinking="<think>")
yield Reasoning(status="🤔 Is thinking...", is_thinking="<think>")
if chunk != "<think>":
if len(chunk) > 1 and chunk[1]:
yield Reasoning(chunk[1])
is_thinking = time.time()
if "</think>" in chunk:
if chunk != "<think>":
chunk = chunk.split("</think>", 1)
if len(chunk) > 0 and chunk[0]:
yield Reasoning(chunk[0])
is_thinking = time.time() - is_thinking
if is_thinking > 1:
yield Reasoning(None, f"Thought for {is_thinking:.2f}s", is_thinking="</think>")
else:
yield Reasoning(None, f"Finished", is_thinking="</think>")
if chunk != "<think>":
if len(chunk) > 1 and chunk[1]:
yield chunk[1]
is_thinking = 0
elif is_thinking:
yield Reasoning(chunk)
else:
yield chunk
if "</think>" in chunk:
if chunk != "<think>":
chunk = chunk.split("</think>", 1)
if len(chunk) > 0 and chunk[0]:
yield Reasoning(chunk[0])
is_thinking = time.time() - is_thinking if is_thinking > 0 else 0
if is_thinking > 1:
yield Reasoning(status=f"Thought for {is_thinking:.2f}s", is_thinking="</think>")
else:
yield Reasoning(status=f"Finished", is_thinking="</think>")
if chunk != "<think>":
if len(chunk) > 1 and chunk[1]:
yield chunk[1]
is_thinking = 0
elif is_thinking:
yield Reasoning(chunk)
else:
yield chunk

View file

@ -194,8 +194,10 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backen
async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_INSTRUCTIONS, **kwargs) -> str:
if instructions and instructions in prompt:
return prompt # We have already added search results
if prompt.startswith("##") and query is None:
return prompt # We have no search query
if query is None:
query = spacy_get_keywords(prompt)
query = prompt.strip().splitlines()[0] # Use the first line as the search query
json_bytes = json.dumps({"query": query, **kwargs}, sort_keys=True).encode(errors="ignore")
md5_hash = hashlib.md5(json_bytes).hexdigest()
bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / f"web_search" / f"{datetime.date.today()}"