From 0638cbc17559583894a2564612a5083e1ce63d87 Mon Sep 17 00:00:00 2001 From: hlohaus <983577+hlohaus@users.noreply.github.com> Date: Mon, 3 Feb 2025 20:23:21 +0100 Subject: [PATCH 1/3] Improve select custom model in UI Updates for the response of the BackendApi Update of the demo model list Improve web search tool Moved copy_images to /image --- g4f/Provider/Blackbox.py | 4 +- g4f/Provider/Copilot.py | 4 +- g4f/Provider/ImageLabs.py | 2 +- g4f/Provider/PerplexityLabs.py | 1 - g4f/Provider/Prodia.py | 2 +- g4f/Provider/You.py | 3 +- g4f/Provider/hf/HuggingFaceAPI.py | 25 +- g4f/Provider/hf/HuggingFaceInference.py | 4 +- g4f/Provider/hf/__init__.py | 4 +- .../hf_space/BlackForestLabsFlux1Dev.py | 2 +- .../hf_space/BlackForestLabsFlux1Schnell.py | 2 +- .../hf_space/StableDiffusion35Large.py | 2 +- .../hf_space/VoodoohopFlux1Schnell.py | 2 +- g4f/Provider/needs_auth/BingCreateImages.py | 2 +- g4f/Provider/needs_auth/DeepInfra.py | 2 +- g4f/Provider/needs_auth/DeepSeekAPI.py | 6 +- g4f/Provider/needs_auth/Gemini.py | 4 +- g4f/Provider/needs_auth/MetaAI.py | 2 +- g4f/Provider/needs_auth/MicrosoftDesigner.py | 2 +- g4f/Provider/needs_auth/OpenaiChat.py | 6 +- g4f/Provider/not_working/AiChats.py | 2 +- g4f/api/__init__.py | 3 +- g4f/client/__init__.py | 4 +- g4f/gui/client/index.html | 2 +- g4f/gui/client/static/js/chat.v1.js | 51 ++-- g4f/gui/server/api.py | 12 +- g4f/gui/server/backend_api.py | 4 +- g4f/image.py | 69 ----- g4f/image/__init__.py | 247 ++++++++++++++++++ g4f/image/copy_images.py | 84 ++++++ g4f/models.py | 10 +- g4f/tools/run_tools.py | 37 +-- g4f/tools/web_search.py | 4 +- 33 files changed, 453 insertions(+), 157 deletions(-) create mode 100644 g4f/image/__init__.py create mode 100644 g4f/image/copy_images.py diff --git a/g4f/Provider/Blackbox.py b/g4f/Provider/Blackbox.py index 61982eea..559ce7ba 100644 --- a/g4f/Provider/Blackbox.py +++ b/g4f/Provider/Blackbox.py @@ -12,10 +12,10 @@ from datetime import datetime, timezone from ..typing import AsyncResult, Messages, ImagesType from ..requests.raise_for_status import raise_for_status from .base_provider import AsyncGeneratorProvider, ProviderModelMixin -from ..image import ImageResponse, to_data_uri +from ..image import to_data_uri from ..cookies import get_cookies_dir from .helper import format_prompt, format_image_prompt -from ..providers.response import JsonConversation, Reasoning +from ..providers.response import JsonConversation, ImageResponse class Conversation(JsonConversation): validated_value: str = None diff --git a/g4f/Provider/Copilot.py b/g4f/Provider/Copilot.py index f63b0563..68808384 100644 --- a/g4f/Provider/Copilot.py +++ b/g4f/Provider/Copilot.py @@ -24,10 +24,10 @@ from .openai.har_file import get_headers, get_har_files from ..typing import CreateResult, Messages, ImagesType from ..errors import MissingRequirementsError, NoValidHarFileError, MissingAuthError from ..requests.raise_for_status import raise_for_status -from ..providers.response import BaseConversation, JsonConversation, RequestLogin, Parameters +from ..providers.response import BaseConversation, JsonConversation, RequestLogin, Parameters, ImageResponse from ..providers.asyncio import get_running_loop from ..requests import get_nodriver -from ..image import ImageResponse, to_bytes, is_accepted_format +from ..image import to_bytes, is_accepted_format from .helper import get_last_user_message from .. import debug diff --git a/g4f/Provider/ImageLabs.py b/g4f/Provider/ImageLabs.py index 7c8ed14b..5701e5ef 100644 --- a/g4f/Provider/ImageLabs.py +++ b/g4f/Provider/ImageLabs.py @@ -5,7 +5,7 @@ import time import asyncio from ..typing import AsyncResult, Messages -from ..image import ImageResponse +from ..providers.response import ImageResponse from .base_provider import AsyncGeneratorProvider, ProviderModelMixin diff --git a/g4f/Provider/PerplexityLabs.py b/g4f/Provider/PerplexityLabs.py index 1d06784d..0e703450 100644 --- a/g4f/Provider/PerplexityLabs.py +++ b/g4f/Provider/PerplexityLabs.py @@ -73,7 +73,6 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin): } await ws.send_str("42" + json.dumps(["perplexity_labs", message_data])) last_message = 0 - is_thinking = False while True: message = await ws.receive_str() if message == "2": diff --git a/g4f/Provider/Prodia.py b/g4f/Provider/Prodia.py index 76516e2f..912f0005 100644 --- a/g4f/Provider/Prodia.py +++ b/g4f/Provider/Prodia.py @@ -6,7 +6,7 @@ import random from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider, ProviderModelMixin -from ..image import ImageResponse +from ..providers.response import ImageResponse class Prodia(AsyncGeneratorProvider, ProviderModelMixin): url = "https://app.prodia.com" diff --git a/g4f/Provider/You.py b/g4f/Provider/You.py index 52546280..0b517150 100644 --- a/g4f/Provider/You.py +++ b/g4f/Provider/You.py @@ -7,8 +7,9 @@ import uuid from ..typing import AsyncResult, Messages, ImageType, Cookies from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from .helper import format_prompt -from ..image import ImageResponse, ImagePreview, EXTENSIONS_MAP, to_bytes, is_accepted_format +from ..image import EXTENSIONS_MAP, to_bytes, is_accepted_format from ..requests import StreamSession, FormData, raise_for_status, get_nodriver +from ..providers.response import ImagePreview, ImageResponse from ..cookies import get_cookies from ..errors import MissingRequirementsError, ResponseError from .. import debug diff --git a/g4f/Provider/hf/HuggingFaceAPI.py b/g4f/Provider/hf/HuggingFaceAPI.py index 9aa01dc2..08e03e44 100644 --- a/g4f/Provider/hf/HuggingFaceAPI.py +++ b/g4f/Provider/hf/HuggingFaceAPI.py @@ -48,18 +48,19 @@ class HuggingFaceAPI(OpenaiTemplate): if model in cls.model_aliases: model_name = cls.model_aliases[model] api_base = f"https://api-inference.huggingface.co/models/{model_name}/v1" - if images is not None: - async with StreamSession( - timeout=30, - ) as session: - async with session.get(f"https://huggingface.co/api/models/{model}") as response: - if response.status == 404: - raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}") - await raise_for_status(response) - model_data = await response.json() - pipeline_tag = model_data.get("pipeline_tag") - if pipeline_tag != "image-text-to-text": - raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag={pipeline_tag}") + async with StreamSession( + timeout=30, + ) as session: + async with session.get(f"https://huggingface.co/api/models/{model}") as response: + if response.status == 404: + raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}") + await raise_for_status(response) + model_data = await response.json() + pipeline_tag = model_data.get("pipeline_tag") + if images is None and pipeline_tag not in ("text-generation", "image-text-to-text"): + raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}") + elif pipeline_tag != "image-text-to-text": + raise ModelNotSupportedError(f"Model does not support images: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}") start = calculate_lenght(messages) if start > max_inputs_lenght: if len(messages) > 6: diff --git a/g4f/Provider/hf/HuggingFaceInference.py b/g4f/Provider/hf/HuggingFaceInference.py index 68e6adff..c89a114e 100644 --- a/g4f/Provider/hf/HuggingFaceInference.py +++ b/g4f/Provider/hf/HuggingFaceInference.py @@ -9,14 +9,14 @@ from ...typing import AsyncResult, Messages from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_prompt from ...errors import ModelNotFoundError, ModelNotSupportedError, ResponseError from ...requests import StreamSession, raise_for_status -from ...providers.response import FinishReason -from ...image import ImageResponse +from ...providers.response import FinishReason, ImageResponse from ..helper import format_image_prompt from .models import default_model, default_image_model, model_aliases, fallback_models from ... import debug class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin): url = "https://huggingface.co" + parent = "HuggingFace" working = True default_model = default_model diff --git a/g4f/Provider/hf/__init__.py b/g4f/Provider/hf/__init__.py index 5b69d133..260f02df 100644 --- a/g4f/Provider/hf/__init__.py +++ b/g4f/Provider/hf/__init__.py @@ -4,7 +4,7 @@ import random from ...typing import AsyncResult, Messages from ...providers.response import ImageResponse -from ...errors import ModelNotSupportedError +from ...errors import ModelNotSupportedError, MissingAuthError from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from .HuggingChat import HuggingChat from .HuggingFaceAPI import HuggingFaceAPI @@ -60,6 +60,6 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): try: async for chunk in HuggingFaceAPI.create_async_generator(model, messages, **kwargs): yield chunk - except ModelNotSupportedError: + except (ModelNotSupportedError, MissingAuthError): async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs): yield chunk \ No newline at end of file diff --git a/g4f/Provider/hf_space/BlackForestLabsFlux1Dev.py b/g4f/Provider/hf_space/BlackForestLabsFlux1Dev.py index 48add133..60e6820d 100644 --- a/g4f/Provider/hf_space/BlackForestLabsFlux1Dev.py +++ b/g4f/Provider/hf_space/BlackForestLabsFlux1Dev.py @@ -4,7 +4,7 @@ import json from aiohttp import ClientSession from ...typing import AsyncResult, Messages -from ...image import ImageResponse, ImagePreview +from ...providers.response import ImageResponse, ImagePreview from ...errors import ResponseError from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..helper import format_image_prompt diff --git a/g4f/Provider/hf_space/BlackForestLabsFlux1Schnell.py b/g4f/Provider/hf_space/BlackForestLabsFlux1Schnell.py index 2a1a2583..1f051bf7 100644 --- a/g4f/Provider/hf_space/BlackForestLabsFlux1Schnell.py +++ b/g4f/Provider/hf_space/BlackForestLabsFlux1Schnell.py @@ -4,7 +4,7 @@ from aiohttp import ClientSession import json from ...typing import AsyncResult, Messages -from ...image import ImageResponse +from ...providers.response import ImageResponse from ...errors import ResponseError from ...requests.raise_for_status import raise_for_status from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin diff --git a/g4f/Provider/hf_space/StableDiffusion35Large.py b/g4f/Provider/hf_space/StableDiffusion35Large.py index 82d349d8..3644244e 100644 --- a/g4f/Provider/hf_space/StableDiffusion35Large.py +++ b/g4f/Provider/hf_space/StableDiffusion35Large.py @@ -4,7 +4,7 @@ import json from aiohttp import ClientSession from ...typing import AsyncResult, Messages -from ...image import ImageResponse, ImagePreview +from ...providers.response import ImageResponse, ImagePreview from ...errors import ResponseError from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..helper import format_image_prompt diff --git a/g4f/Provider/hf_space/VoodoohopFlux1Schnell.py b/g4f/Provider/hf_space/VoodoohopFlux1Schnell.py index 3496f718..2661714a 100644 --- a/g4f/Provider/hf_space/VoodoohopFlux1Schnell.py +++ b/g4f/Provider/hf_space/VoodoohopFlux1Schnell.py @@ -4,7 +4,7 @@ from aiohttp import ClientSession import json from ...typing import AsyncResult, Messages -from ...image import ImageResponse +from ...providers.response import ImageResponse from ...errors import ResponseError from ...requests.raise_for_status import raise_for_status from ..helper import format_image_prompt diff --git a/g4f/Provider/needs_auth/BingCreateImages.py b/g4f/Provider/needs_auth/BingCreateImages.py index b198d201..efce7ea9 100644 --- a/g4f/Provider/needs_auth/BingCreateImages.py +++ b/g4f/Provider/needs_auth/BingCreateImages.py @@ -1,7 +1,7 @@ from __future__ import annotations from ...cookies import get_cookies -from ...image import ImageResponse +from ...providers.response import ImageResponse from ...errors import MissingAuthError from ...typing import AsyncResult, Messages, Cookies from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin diff --git a/g4f/Provider/needs_auth/DeepInfra.py b/g4f/Provider/needs_auth/DeepInfra.py index b1d7c1d2..cc2d35e9 100644 --- a/g4f/Provider/needs_auth/DeepInfra.py +++ b/g4f/Provider/needs_auth/DeepInfra.py @@ -3,7 +3,7 @@ from __future__ import annotations import requests from ...typing import AsyncResult, Messages from ...requests import StreamSession, raise_for_status -from ...image import ImageResponse +from ...providers.response import ImageResponse from ..template import OpenaiTemplate from ..helper import format_image_prompt diff --git a/g4f/Provider/needs_auth/DeepSeekAPI.py b/g4f/Provider/needs_auth/DeepSeekAPI.py index f53be8cc..2c5a8bf7 100644 --- a/g4f/Provider/needs_auth/DeepSeekAPI.py +++ b/g4f/Provider/needs_auth/DeepSeekAPI.py @@ -48,12 +48,13 @@ try: raise MissingAuthError() response.raise_for_status() return response.json() + has_dsk = True except ImportError: - pass + has_dsk = False class DeepSeekAPI(AsyncAuthedProvider): url = "https://chat.deepseek.com" - working = False + working = has_dsk needs_auth = True use_nodriver = True _access_token = None @@ -91,6 +92,7 @@ class DeepSeekAPI(AsyncAuthedProvider): if conversation is None: chat_id = api.create_chat_session() conversation = JsonConversation(chat_id=chat_id) + yield conversation is_thinking = 0 for chunk in api.chat_completion( diff --git a/g4f/Provider/needs_auth/Gemini.py b/g4f/Provider/needs_auth/Gemini.py index 8c555e85..2d120ac4 100644 --- a/g4f/Provider/needs_auth/Gemini.py +++ b/g4f/Provider/needs_auth/Gemini.py @@ -19,12 +19,12 @@ from ... import debug from ...typing import Messages, Cookies, ImagesType, AsyncResult, AsyncIterator from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..helper import format_prompt, get_cookies -from ...providers.response import JsonConversation, SynthesizeData, RequestLogin +from ...providers.response import JsonConversation, SynthesizeData, RequestLogin, ImageResponse from ...requests.raise_for_status import raise_for_status from ...requests.aiohttp import get_connector from ...requests import get_nodriver from ...errors import MissingAuthError -from ...image import ImageResponse, to_bytes +from ...image import to_bytes from ..helper import get_last_user_message from ... import debug diff --git a/g4f/Provider/needs_auth/MetaAI.py b/g4f/Provider/needs_auth/MetaAI.py index a8b6708e..710a4d38 100644 --- a/g4f/Provider/needs_auth/MetaAI.py +++ b/g4f/Provider/needs_auth/MetaAI.py @@ -10,7 +10,7 @@ from aiohttp import ClientSession, BaseConnector from ...typing import AsyncResult, Messages, Cookies from ...requests import raise_for_status, DEFAULT_HEADERS -from ...image import ImageResponse, ImagePreview +from ...providers.response import ImageResponse, ImagePreview from ...errors import ResponseError from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..helper import format_prompt, get_connector, format_cookies diff --git a/g4f/Provider/needs_auth/MicrosoftDesigner.py b/g4f/Provider/needs_auth/MicrosoftDesigner.py index 3ce8e618..19df3ecd 100644 --- a/g4f/Provider/needs_auth/MicrosoftDesigner.py +++ b/g4f/Provider/needs_auth/MicrosoftDesigner.py @@ -6,7 +6,7 @@ import random import asyncio import json -from ...image import ImageResponse +from ...providers.response import ImageResponse from ...errors import MissingRequirementsError, NoValidHarFileError from ...typing import AsyncResult, Messages from ...requests.raise_for_status import raise_for_status diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index 81961bdf..8c216558 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -22,9 +22,9 @@ from ...typing import AsyncResult, Messages, Cookies, ImagesType from ...requests.raise_for_status import raise_for_status from ...requests import StreamSession from ...requests import get_nodriver -from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format +from ...image import ImageRequest, to_image, to_bytes, is_accepted_format from ...errors import MissingAuthError, NoValidHarFileError -from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult +from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult, ImageResponse from ...providers.response import Sources, TitleGeneration, RequestLogin, Parameters, Reasoning from ..helper import format_cookies from ..openai.models import default_model, default_image_model, models, image_models, text_models @@ -526,7 +526,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin): c = m.get("content", {}) if c.get("content_type") == "text" and m.get("author", {}).get("role") == "tool" and "initial_text" in m.get("metadata", {}): fields.is_thinking = True - yield Reasoning(status=c.get("metadata", {}).get("initial_text")) + yield Reasoning(status=m.get("metadata", {}).get("initial_text")) if c.get("content_type") == "multimodal_text": generated_images = [] for element in c.get("parts"): diff --git a/g4f/Provider/not_working/AiChats.py b/g4f/Provider/not_working/AiChats.py index 51a85c91..e0bbc199 100644 --- a/g4f/Provider/not_working/AiChats.py +++ b/g4f/Provider/not_working/AiChats.py @@ -5,7 +5,7 @@ import base64 from aiohttp import ClientSession from ...typing import AsyncResult, Messages from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin -from ...image import ImageResponse +from ...providers.response import ImageResponse from ..helper import format_prompt class AiChats(AsyncGeneratorProvider, ProviderModelMixin): diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py index 29cc5e0a..c22edccf 100644 --- a/g4f/api/__init__.py +++ b/g4f/api/__init__.py @@ -39,7 +39,8 @@ import g4f.debug from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, convert_to_provider from g4f.providers.response import BaseConversation, JsonConversation from g4f.client.helper import filter_none -from g4f.image import is_data_uri_an_image, images_dir, copy_images +from g4f.image import is_data_uri_an_image +from g4f.image.copy_images import images_dir, copy_images from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError from g4f.cookies import read_cookie_files, get_cookies_dir from g4f.Provider import ProviderType, ProviderUtils, __providers__ diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py index 50a614c3..7ab87da9 100644 --- a/g4f/client/__init__.py +++ b/g4f/client/__init__.py @@ -9,10 +9,10 @@ import aiohttp import base64 from typing import Union, AsyncIterator, Iterator, Awaitable, Optional -from ..image import ImageResponse, copy_images +from ..image.copy_images import copy_images from ..typing import Messages, ImageType from ..providers.types import ProviderType, BaseRetryProvider -from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData, ToolCalls, Usage +from ..providers.response import ResponseType, ImageResponse, FinishReason, BaseConversation, SynthesizeData, ToolCalls, Usage from ..errors import NoImageResponseError from ..providers.retry_provider import IterListProvider from ..providers.asyncio import to_sync_generator diff --git a/g4f/gui/client/index.html b/g4f/gui/client/index.html index c7e445ce..73e640a5 100644 --- a/g4f/gui/client/index.html +++ b/g4f/gui/client/index.html @@ -165,7 +165,7 @@
- +
diff --git a/g4f/gui/client/static/js/chat.v1.js b/g4f/gui/client/static/js/chat.v1.js index 738d4922..81706005 100644 --- a/g4f/gui/client/static/js/chat.v1.js +++ b/g4f/gui/client/static/js/chat.v1.js @@ -36,7 +36,8 @@ let title_storage = {}; let parameters_storage = {}; let finish_storage = {}; let usage_storage = {}; -let reasoning_storage = {} +let reasoning_storage = {}; +let generate_storage = {}; let is_demo = false; messageInput.addEventListener("blur", () => { @@ -96,9 +97,13 @@ function filter_message(text) { } function filter_message_content(text) { + return text.replace(/ \[aborted\]$/g, "").replace(/ \[error\]$/g, "") +} + +function filter_message_image(text) { return text.replaceAll( - /\/\]\(\/generate\//gm, "/](/images/" - ).replace(/ \[aborted\]$/g, "").replace(/ \[error\]$/g, "") + /\]\(\/generate\//gm, "](/images/" + ) } function fallback_clipboard (text) { @@ -204,6 +209,7 @@ function register_message_images() { el.onerror = () => { let indexCommand; if ((indexCommand = el.src.indexOf("/generate/")) >= 0) { + generate_storage[window.conversation_id] = true; indexCommand = indexCommand + "/generate/".length + 1; let newPath = el.src.substring(indexCommand) let filename = newPath.replace(/(?:\?.+?|$)/, ""); @@ -973,7 +979,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi await add_message( window.conversation_id, "assistant", - final_message, + filter_message_image(final_message), message_provider, message_index, synthesize_storage[message_id], @@ -999,7 +1005,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi delete controller_storage[message_id]; } // Reload conversation if no error - if (!error_storage[message_id]) { + if (!error_storage[message_id] && !generate_storage[window.conversation_id]) { await safe_load_conversation(window.conversation_id, scroll); } let cursorDiv = message_el.querySelector(".cursor"); @@ -1022,14 +1028,23 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi } else { api_key = get_api_key_by_provider(provider); } - if (is_demo && !api_key && provider != "Custom") { + if (is_demo && !api_key) { + api_key = localStorage.getItem("HuggingFace-api_key"); + } + if (is_demo && !api_key) { location.href = "/"; return; } const input = imageInput && imageInput.files.length > 0 ? imageInput : cameraInput; const files = input && input.files.length > 0 ? input.files : null; const download_images = document.getElementById("download_images")?.checked; - const api_base = provider == "Custom" ? document.getElementById(`${provider}-api_base`).value : null; + let api_base; + if (provider == "Custom") { + api_base = document.getElementById("api_base")?.value; + if (!api_base) { + provider = ""; + } + } const ignored = Array.from(settings.querySelectorAll("input.provider:not(:checked)")).map((el)=>el.value); await api("conversation", { id: message_id, @@ -1886,6 +1901,10 @@ async function on_load() { load_conversation(window.conversation_id); } else { chatPrompt.value = document.getElementById("systemPrompt")?.value || ""; + example = document.getElementById("systemPrompt")?.dataset.example || "" + if (chatPrompt.value == example) { + messageInput.value = ""; + } let chat_url = new URL(window.location.href) let chat_params = new URLSearchParams(chat_url.search); if (chat_params.get("prompt")) { @@ -2493,19 +2512,23 @@ async function load_provider_models(provider=null) { if (!custom_model.value) { custom_model.classList.add("hidden"); } - if (provider == "Custom Model" || custom_model.value) { + if (provider.startsWith("Custom") || custom_model.value) { modelProvider.classList.add("hidden"); modelSelect.classList.add("hidden"); - document.getElementById("model3").classList.remove("hidden"); + custom_model.classList.remove("hidden"); return; } modelProvider.innerHTML = ''; modelProvider.name = `model[${provider}]`; if (!provider) { modelProvider.classList.add("hidden"); - modelSelect.classList.remove("hidden"); - document.getElementById("model3").value = ""; - document.getElementById("model3").classList.remove("hidden"); + if (custom_model.value) { + modelSelect.classList.add("hidden"); + custom_model.classList.remove("hidden"); + } else { + modelSelect.classList.remove("hidden"); + custom_model.classList.add("hidden"); + } return; } const models = await api('models', provider); @@ -2531,11 +2554,9 @@ async function load_provider_models(provider=null) { modelProvider.value = value; } modelProvider.selectedIndex = defaultIndex; - } else if (custom_model.value) { - modelSelect.classList.add("hidden"); } else { modelProvider.classList.add("hidden"); - modelSelect.classList.remove("hidden"); + custom_model.classList.remove("hidden") } }; providerSelect.addEventListener("change", () => { diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py index 188617bf..c576541c 100644 --- a/g4f/gui/server/api.py +++ b/g4f/gui/server/api.py @@ -8,7 +8,7 @@ from flask import send_from_directory from inspect import signature from ...errors import VersionNotFoundError -from ...image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir +from ...image.copy_images import copy_images, ensure_images_dir, images_dir from ...tools.run_tools import iter_run_tools from ...Provider import ProviderUtils, __providers__ from ...providers.base_provider import ProviderModelMixin @@ -182,7 +182,9 @@ class Api: elif isinstance(chunk, Exception): logger.exception(chunk) yield self._format_json('message', get_error_message(chunk), error=type(chunk).__name__) - elif isinstance(chunk, (PreviewResponse, ImagePreview)): + elif isinstance(chunk, PreviewResponse): + yield self._format_json("preview", chunk.to_string()) + elif isinstance(chunk, ImagePreview): yield self._format_json("preview", chunk.to_string(), images=chunk.images, alt=chunk.alt) elif isinstance(chunk, ImageResponse): images = chunk @@ -207,6 +209,8 @@ class Api: yield self._format_json("reasoning", **chunk.get_dict()) elif isinstance(chunk, DebugResponse): yield self._format_json("log", chunk.log) + elif isinstance(chunk, RawResponse): + yield self._format_json(chunk.type, **chunk.get_dict()) else: yield self._format_json("content", str(chunk)) if debug.logs: @@ -215,6 +219,8 @@ class Api: debug.logs = [] except Exception as e: logger.exception(e) + if debug.logging: + debug.log_handler(get_error_message(e)) if debug.logs: for log in debug.logs: yield self._format_json("log", str(log)) @@ -222,7 +228,7 @@ class Api: yield self._format_json('error', type(e).__name__, message=get_error_message(e)) def _format_json(self, response_type: str, content = None, **kwargs): - if content is not None: + if content is not None and isinstance(response_type, str): return { 'type': response_type, response_type: content, diff --git a/g4f/gui/server/backend_api.py b/g4f/gui/server/backend_api.py index 16d27798..44127b96 100644 --- a/g4f/gui/server/backend_api.py +++ b/g4f/gui/server/backend_api.py @@ -140,9 +140,9 @@ class Backend_Api(Api): if model != "default" and model in models.demo_models: json_data["provider"] = random.choice(models.demo_models[model][1]) else: - json_data["model"] = models.demo_models["default"][0].name + if not model or model == "default": + json_data["model"] = models.demo_models["default"][0].name json_data["provider"] = random.choice(models.demo_models["default"][1]) - kwargs = self._prepare_conversation_kwargs(json_data, kwargs) return self.app.response_class( self._create_response_stream( diff --git a/g4f/image.py b/g4f/image.py index a99b1169..e9341953 100644 --- a/g4f/image.py +++ b/g4f/image.py @@ -3,15 +3,10 @@ from __future__ import annotations import os import re import io -import time -import uuid import base64 -import asyncio -import hashlib from urllib.parse import quote_plus from io import BytesIO from pathlib import Path -from aiohttp import ClientSession, ClientError try: from PIL.Image import open as open_image, new as new_image from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90 @@ -21,9 +16,7 @@ except ImportError: from .typing import ImageType, Union, Image, Optional, Cookies from .errors import MissingRequirementsError -from .providers.response import ImageResponse, ImagePreview from .requests.aiohttp import get_connector -from . import debug ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'} @@ -237,68 +230,6 @@ def to_data_uri(image: ImageType) -> str: return f"data:{is_accepted_format(data)};base64,{data_base64}" return image -# Function to ensure the images directory exists -def ensure_images_dir(): - os.makedirs(images_dir, exist_ok=True) - -def get_image_extension(image: str) -> str: - match = re.search(r"\.(?:jpe?g|png|webp)", image) - if match: - return match.group(0) - return ".jpg" - -async def copy_images( - images: list[str], - cookies: Optional[Cookies] = None, - headers: Optional[dict] = None, - proxy: Optional[str] = None, - alt: str = None, - add_url: bool = True, - target: str = None, - ssl: bool = None -) -> list[str]: - if add_url: - add_url = not cookies - ensure_images_dir() - async with ClientSession( - connector=get_connector(proxy=proxy), - cookies=cookies, - headers=headers, - ) as session: - async def copy_image(image: str, target: str = None) -> str: - if target is None or len(images) > 1: - hash = hashlib.sha256(image.encode()).hexdigest() - target = f"{quote_plus('+'.join(alt.split()[:10])[:100], '')}_{hash}" if alt else str(uuid.uuid4()) - target = f"{int(time.time())}_{target}{get_image_extension(image)}" - target = os.path.join(images_dir, target) - try: - if image.startswith("data:"): - with open(target, "wb") as f: - f.write(extract_data_uri(image)) - else: - try: - async with session.get(image, ssl=ssl) as response: - response.raise_for_status() - with open(target, "wb") as f: - async for chunk in response.content.iter_chunked(4096): - f.write(chunk) - except ClientError as e: - debug.log(f"copy_images failed: {e.__class__.__name__}: {e}") - return image - if "." not in target: - with open(target, "rb") as f: - extension = is_accepted_format(f.read(12)).split("/")[-1] - extension = "jpg" if extension == "jpeg" else extension - new_target = f"{target}.{extension}" - os.rename(target, new_target) - target = new_target - finally: - if "." not in target and os.path.exists(target): - os.unlink(target) - return f"/images/{os.path.basename(target)}{'?url=' + image if add_url and not image.startswith('data:') else ''}" - - return await asyncio.gather(*[copy_image(image, target) for image in images]) - class ImageDataResponse(): def __init__( self, diff --git a/g4f/image/__init__.py b/g4f/image/__init__.py new file mode 100644 index 00000000..38e83b8b --- /dev/null +++ b/g4f/image/__init__.py @@ -0,0 +1,247 @@ +from __future__ import annotations + +import os +import re +import io +import base64 +from io import BytesIO +from pathlib import Path +try: + from PIL.Image import open as open_image, new as new_image + from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90 + has_requirements = True +except ImportError: + has_requirements = False + +from ..typing import ImageType, Union, Image +from ..errors import MissingRequirementsError + +ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'} + +EXTENSIONS_MAP: dict[str, str] = { + "image/png": "png", + "image/jpeg": "jpg", + "image/gif": "gif", + "image/webp": "webp", +} + +def to_image(image: ImageType, is_svg: bool = False) -> Image: + """ + Converts the input image to a PIL Image object. + + Args: + image (Union[str, bytes, Image]): The input image. + + Returns: + Image: The converted PIL Image object. + """ + if not has_requirements: + raise MissingRequirementsError('Install "pillow" package for images') + + if isinstance(image, str) and image.startswith("data:"): + is_data_uri_an_image(image) + image = extract_data_uri(image) + + if is_svg: + try: + import cairosvg + except ImportError: + raise MissingRequirementsError('Install "cairosvg" package for svg images') + if not isinstance(image, bytes): + image = image.read() + buffer = BytesIO() + cairosvg.svg2png(image, write_to=buffer) + return open_image(buffer) + + if isinstance(image, bytes): + is_accepted_format(image) + return open_image(BytesIO(image)) + elif not isinstance(image, Image): + image = open_image(image) + image.load() + return image + + return image + +def is_allowed_extension(filename: str) -> bool: + """ + Checks if the given filename has an allowed extension. + + Args: + filename (str): The filename to check. + + Returns: + bool: True if the extension is allowed, False otherwise. + """ + return '.' in filename and \ + filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + +def is_data_uri_an_image(data_uri: str) -> bool: + """ + Checks if the given data URI represents an image. + + Args: + data_uri (str): The data URI to check. + + Raises: + ValueError: If the data URI is invalid or the image format is not allowed. + """ + # Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif) + if not re.match(r'data:image/(\w+);base64,', data_uri): + raise ValueError("Invalid data URI image.") + # Extract the image format from the data URI + image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1).lower() + # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif) + if image_format not in ALLOWED_EXTENSIONS and image_format != "svg+xml": + raise ValueError("Invalid image format (from mime file type).") + +def is_accepted_format(binary_data: bytes) -> str: + """ + Checks if the given binary data represents an image with an accepted format. + + Args: + binary_data (bytes): The binary data to check. + + Raises: + ValueError: If the image format is not allowed. + """ + if binary_data.startswith(b'\xFF\xD8\xFF'): + return "image/jpeg" + elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'): + return "image/png" + elif binary_data.startswith(b'GIF87a') or binary_data.startswith(b'GIF89a'): + return "image/gif" + elif binary_data.startswith(b'\x89JFIF') or binary_data.startswith(b'JFIF\x00'): + return "image/jpeg" + elif binary_data.startswith(b'\xFF\xD8'): + return "image/jpeg" + elif binary_data.startswith(b'RIFF') and binary_data[8:12] == b'WEBP': + return "image/webp" + else: + raise ValueError("Invalid image format (from magic code).") + +def extract_data_uri(data_uri: str) -> bytes: + """ + Extracts the binary data from the given data URI. + + Args: + data_uri (str): The data URI. + + Returns: + bytes: The extracted binary data. + """ + data = data_uri.split(",")[-1] + data = base64.b64decode(data) + return data + +def get_orientation(image: Image) -> int: + """ + Gets the orientation of the given image. + + Args: + image (Image): The image. + + Returns: + int: The orientation value. + """ + exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif() + if exif_data is not None: + orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF + if orientation is not None: + return orientation + +def process_image(image: Image, new_width: int, new_height: int) -> Image: + """ + Processes the given image by adjusting its orientation and resizing it. + + Args: + image (Image): The image to process. + new_width (int): The new width of the image. + new_height (int): The new height of the image. + + Returns: + Image: The processed image. + """ + # Fix orientation + orientation = get_orientation(image) + if orientation: + if orientation > 4: + image = image.transpose(FLIP_LEFT_RIGHT) + if orientation in [3, 4]: + image = image.transpose(ROTATE_180) + if orientation in [5, 6]: + image = image.transpose(ROTATE_270) + if orientation in [7, 8]: + image = image.transpose(ROTATE_90) + # Resize image + image.thumbnail((new_width, new_height)) + # Remove transparency + if image.mode == "RGBA": + image.load() + white = new_image('RGB', image.size, (255, 255, 255)) + white.paste(image, mask=image.split()[-1]) + return white + # Convert to RGB for jpg format + elif image.mode != "RGB": + image = image.convert("RGB") + return image + +def to_bytes(image: ImageType) -> bytes: + """ + Converts the given image to bytes. + + Args: + image (ImageType): The image to convert. + + Returns: + bytes: The image as bytes. + """ + if isinstance(image, bytes): + return image + elif isinstance(image, str) and image.startswith("data:"): + is_data_uri_an_image(image) + return extract_data_uri(image) + elif isinstance(image, Image): + bytes_io = BytesIO() + image.save(bytes_io, image.format) + image.seek(0) + return bytes_io.getvalue() + elif isinstance(image, (str, os.PathLike)): + return Path(image).read_bytes() + elif isinstance(image, Path): + return image.read_bytes() + else: + try: + image.seek(0) + except (AttributeError, io.UnsupportedOperation): + pass + return image.read() + +def to_data_uri(image: ImageType) -> str: + if not isinstance(image, str): + data = to_bytes(image) + data_base64 = base64.b64encode(data).decode() + return f"data:{is_accepted_format(data)};base64,{data_base64}" + return image + +class ImageDataResponse(): + def __init__( + self, + images: Union[str, list], + alt: str, + ): + self.images = images + self.alt = alt + + def get_list(self) -> list[str]: + return [self.images] if isinstance(self.images, str) else self.images + +class ImageRequest: + def __init__( + self, + options: dict = {} + ): + self.options = options + + def get(self, key: str): + return self.options.get(key) \ No newline at end of file diff --git a/g4f/image/copy_images.py b/g4f/image/copy_images.py new file mode 100644 index 00000000..051d10c7 --- /dev/null +++ b/g4f/image/copy_images.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import os +import time +import uuid +import asyncio +import hashlib +import re +from urllib.parse import quote_plus +from aiohttp import ClientSession, ClientError + +from ..typing import Optional, Cookies +from ..requests.aiohttp import get_connector +from ..Provider.template import BackendApi +from . import is_accepted_format, extract_data_uri +from .. import debug + +# Define the directory for generated images +images_dir = "./generated_images" + +def get_image_extension(image: str) -> str: + match = re.search(r"\.(?:jpe?g|png|webp)", image) + if match: + return match.group(0) + return ".jpg" + +# Function to ensure the images directory exists +def ensure_images_dir(): + os.makedirs(images_dir, exist_ok=True) + +async def copy_images( + images: list[str], + cookies: Optional[Cookies] = None, + headers: Optional[dict] = None, + proxy: Optional[str] = None, + alt: str = None, + add_url: bool = True, + target: str = None, + ssl: bool = None +) -> list[str]: + if add_url: + add_url = not cookies + ensure_images_dir() + async with ClientSession( + connector=get_connector(proxy=proxy), + cookies=cookies, + headers=headers, + ) as session: + async def copy_image(image: str, target: str = None, headers: dict = headers, ssl: bool = ssl) -> str: + if target is None or len(images) > 1: + hash = hashlib.sha256(image.encode()).hexdigest() + target = f"{quote_plus('+'.join(alt.split()[:10])[:100], '')}_{hash}" if alt else str(uuid.uuid4()) + target = f"{int(time.time())}_{target}{get_image_extension(image)}" + target = os.path.join(images_dir, target) + try: + if image.startswith("data:"): + with open(target, "wb") as f: + f.write(extract_data_uri(image)) + else: + try: + if BackendApi.working and image.startswith(BackendApi.url) and headers is None: + headers = BackendApi.headers + ssl = BackendApi.ssl + async with session.get(image, ssl=ssl, headers=headers) as response: + response.raise_for_status() + with open(target, "wb") as f: + async for chunk in response.content.iter_chunked(4096): + f.write(chunk) + except ClientError as e: + debug.log(f"copy_images failed: {e.__class__.__name__}: {e}") + return image + if "." not in target: + with open(target, "rb") as f: + extension = is_accepted_format(f.read(12)).split("/")[-1] + extension = "jpg" if extension == "jpeg" else extension + new_target = f"{target}.{extension}" + os.rename(target, new_target) + target = new_target + finally: + if "." not in target and os.path.exists(target): + os.unlink(target) + return f"/images/{os.path.basename(target)}{'?url=' + image if add_url and not image.startswith('data:') else ''}" + + return await asyncio.gather(*[copy_image(image, target) for image in images]) diff --git a/g4f/models.py b/g4f/models.py index c127c302..adfcfaa6 100644 --- a/g4f/models.py +++ b/g4f/models.py @@ -765,11 +765,11 @@ class ModelUtils: demo_models = { gpt_4o.name: [gpt_4o, [PollinationsAI, Blackbox]], - "default": [llama_3_2_11b, [HuggingFaceAPI]], + "default": [llama_3_2_11b, [HuggingFace]], qwen_2_vl_7b.name: [qwen_2_vl_7b, [HuggingFaceAPI]], - qvq_72b.name: [qvq_72b, [HuggingSpace, HuggingFaceAPI]], - deepseek_r1.name: [deepseek_r1, [HuggingFace, HuggingFaceAPI]], - claude_3_haiku.name: [claude_3_haiku, [DDG, Jmuz]], + qvq_72b.name: [qvq_72b, [HuggingSpace]], + deepseek_r1.name: [deepseek_r1, [HuggingFace]], + claude_3_5_sonnet.name: [claude_3_5_sonnet, claude_3_5_sonnet.best_provider.providers], command_r.name: [command_r, [HuggingSpace]], command_r_plus.name: [command_r_plus, [HuggingSpace]], command_r7b.name: [command_r7b, [HuggingSpace]], @@ -779,7 +779,7 @@ demo_models = { llama_3_3_70b.name: [llama_3_3_70b, [HuggingFace]], sd_3_5.name: [sd_3_5, [HuggingSpace, HuggingFace]], flux_dev.name: [flux_dev, [HuggingSpace, HuggingFace]], - flux_schnell.name: [flux_schnell, [HuggingFace]], + flux_schnell.name: [flux_schnell, [HuggingFace, HuggingSpace, PollinationsAI]], } # Create a list of all models and his providers diff --git a/g4f/tools/run_tools.py b/g4f/tools/run_tools.py index 54a9b237..786342d2 100644 --- a/g4f/tools/run_tools.py +++ b/g4f/tools/run_tools.py @@ -159,26 +159,27 @@ def iter_run_tools( chunk = chunk.split("", 1) if len(chunk) > 0 and chunk[0]: yield chunk[0] - yield Reasoning(None, "🤔 Is thinking...", is_thinking="") + yield Reasoning(status="🤔 Is thinking...", is_thinking="") if chunk != "": if len(chunk) > 1 and chunk[1]: yield Reasoning(chunk[1]) is_thinking = time.time() - if "" in chunk: - if chunk != "": - chunk = chunk.split("", 1) - if len(chunk) > 0 and chunk[0]: - yield Reasoning(chunk[0]) - is_thinking = time.time() - is_thinking - if is_thinking > 1: - yield Reasoning(None, f"Thought for {is_thinking:.2f}s", is_thinking="") - else: - yield Reasoning(None, f"Finished", is_thinking="") - if chunk != "": - if len(chunk) > 1 and chunk[1]: - yield chunk[1] - is_thinking = 0 - elif is_thinking: - yield Reasoning(chunk) else: - yield chunk + if "" in chunk: + if chunk != "": + chunk = chunk.split("", 1) + if len(chunk) > 0 and chunk[0]: + yield Reasoning(chunk[0]) + is_thinking = time.time() - is_thinking if is_thinking > 0 else 0 + if is_thinking > 1: + yield Reasoning(status=f"Thought for {is_thinking:.2f}s", is_thinking="") + else: + yield Reasoning(status=f"Finished", is_thinking="") + if chunk != "": + if len(chunk) > 1 and chunk[1]: + yield chunk[1] + is_thinking = 0 + elif is_thinking: + yield Reasoning(chunk) + else: + yield chunk diff --git a/g4f/tools/web_search.py b/g4f/tools/web_search.py index f1d2043b..9d3e2621 100644 --- a/g4f/tools/web_search.py +++ b/g4f/tools/web_search.py @@ -194,8 +194,10 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backen async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_INSTRUCTIONS, **kwargs) -> str: if instructions and instructions in prompt: return prompt # We have already added search results + if prompt.startswith("##") and query is None: + return prompt # We have no search query if query is None: - query = spacy_get_keywords(prompt) + query = prompt.strip().splitlines()[0] # Use the first line as the search query json_bytes = json.dumps({"query": query, **kwargs}, sort_keys=True).encode(errors="ignore") md5_hash = hashlib.md5(json_bytes).hexdigest() bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / f"web_search" / f"{datetime.date.today()}" From 9c8aedbeb1f33c6b6f87f59892f76cfff94306d8 Mon Sep 17 00:00:00 2001 From: hlohaus <983577+hlohaus@users.noreply.github.com> Date: Mon, 3 Feb 2025 20:23:55 +0100 Subject: [PATCH 2/3] Improve select custom model in UI Updates for the response of the BackendApi Update of the demo model list Improve web search tool Moved copy_images to /image --- g4f/Provider/not_working/Airforce.py | 3 +-- g4f/Provider/not_working/AmigoChat.py | 2 +- g4f/Provider/not_working/ReplicateHome.py | 2 +- g4f/Provider/template/BackendApi.py | 3 ++- g4f/providers/create_images.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/g4f/Provider/not_working/Airforce.py b/g4f/Provider/not_working/Airforce.py index 862e59d9..003b7fa2 100644 --- a/g4f/Provider/not_working/Airforce.py +++ b/g4f/Provider/not_working/Airforce.py @@ -6,8 +6,7 @@ from aiohttp import ClientSession from typing import List from ...typing import AsyncResult, Messages -from ...image import ImageResponse -from ...providers.response import FinishReason, Usage +from ...providers.response import ImageResponse, FinishReason, Usage from ...requests.raise_for_status import raise_for_status from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin diff --git a/g4f/Provider/not_working/AmigoChat.py b/g4f/Provider/not_working/AmigoChat.py index 31d1b10b..77190188 100644 --- a/g4f/Provider/not_working/AmigoChat.py +++ b/g4f/Provider/not_working/AmigoChat.py @@ -5,7 +5,7 @@ import uuid from ...typing import AsyncResult, Messages from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin -from ...image import ImageResponse +from ...providers.response import ImageResponse from ...requests import StreamSession, raise_for_status from ...errors import ResponseStatusError diff --git a/g4f/Provider/not_working/ReplicateHome.py b/g4f/Provider/not_working/ReplicateHome.py index e8a99e83..cafc4ce0 100644 --- a/g4f/Provider/not_working/ReplicateHome.py +++ b/g4f/Provider/not_working/ReplicateHome.py @@ -9,7 +9,7 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ...requests.aiohttp import get_connector from ...requests.raise_for_status import raise_for_status from ..helper import format_prompt -from ...image import ImageResponse +from ...providers.response import ImageResponse class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin): url = "https://replicate.com" diff --git a/g4f/Provider/template/BackendApi.py b/g4f/Provider/template/BackendApi.py index 0bca9a22..91067006 100644 --- a/g4f/Provider/template/BackendApi.py +++ b/g4f/Provider/template/BackendApi.py @@ -10,6 +10,7 @@ from ... import debug class BackendApi(AsyncGeneratorProvider, ProviderModelMixin): ssl = None + headers = {} @classmethod async def create_async_generator( @@ -21,7 +22,7 @@ class BackendApi(AsyncGeneratorProvider, ProviderModelMixin): ) -> AsyncResult: debug.log(f"{cls.__name__}: {api_key}") async with StreamSession( - headers={"Accept": "text/event-stream"}, + headers={"Accept": "text/event-stream", **cls.headers}, ) as session: async with session.post(f"{cls.url}/backend-api/v2/conversation", json={ "model": model, diff --git a/g4f/providers/create_images.py b/g4f/providers/create_images.py index 29db9435..ab43436e 100644 --- a/g4f/providers/create_images.py +++ b/g4f/providers/create_images.py @@ -6,7 +6,7 @@ import asyncio from .. import debug from ..typing import CreateResult, Messages from .types import BaseProvider, ProviderType -from ..image import ImageResponse +from ..providers.response import ImageResponse system_message = """ You can generate images, pictures, photos or img with the DALL-E 3 image generator. From b86bf0dcf7c288760dc93e44e41a4ee5d0d5859b Mon Sep 17 00:00:00 2001 From: hlohaus <983577+hlohaus@users.noreply.github.com> Date: Mon, 3 Feb 2025 21:24:37 +0100 Subject: [PATCH 3/3] Fix include in mocks.py --- etc/unittest/mocks.py | 2 +- g4f/requests/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/etc/unittest/mocks.py b/etc/unittest/mocks.py index 50d1a5a4..a903ebfd 100644 --- a/etc/unittest/mocks.py +++ b/etc/unittest/mocks.py @@ -1,5 +1,5 @@ from g4f.providers.base_provider import AbstractProvider, AsyncProvider, AsyncGeneratorProvider -from g4f.image import ImageResponse +from g4f.providers.response import ImageResponse from g4f.errors import MissingAuthError class ProviderMock(AbstractProvider): diff --git a/g4f/requests/__init__.py b/g4f/requests/__init__.py index d1f5d024..88b8a47a 100644 --- a/g4f/requests/__init__.py +++ b/g4f/requests/__init__.py @@ -137,7 +137,7 @@ async def get_nodriver( timeout: int = 120, browser_executable_path=None, **kwargs -) -> Browser: +) -> tuple[Browser, callable]: if not has_nodriver: raise MissingRequirementsError('Install "nodriver" and "platformdirs" package | pip install -U nodriver platformdirs') user_data_dir = user_config_dir(f"g4f-{user_data_dir}") if has_platformdirs else None