diff --git a/docs/pydantic_ai.md b/docs/pydantic_ai.md index f7a1dcec..3f8af6a5 100644 --- a/docs/pydantic_ai.md +++ b/docs/pydantic_ai.md @@ -109,15 +109,18 @@ This example shows how to initialize an agent with a specific model (`gpt-4o`) a from pydantic import BaseModel from pydantic_ai import Agent from pydantic_ai.models import ModelSettings -from g4f.integration.pydantic_ai import patch_infer_model +from g4f.integration.pydantic_ai import AIModel +from g4f.Provider import PollinationsAI -patch_infer_model("your_api_key") class MyModel(BaseModel): city: str country: str -agent = Agent('g4f:Groq:llama3-70b-8192', result_type=MyModel, model_settings=ModelSettings(temperature=0)) +nt = Agent(AIModel( + "gpt-4o", # Specify the provider and model + PollinationsAI # Use a supported provider to handle tool-based response formatting +), result_type=MyModel, model_settings=ModelSettings(temperature=0)) if __name__ == '__main__': result = agent.run_sync('The windy city in the US of A.') @@ -152,7 +155,7 @@ class MyModel(BaseModel): # Create the agent for a model with tool support (using one tool) agent = Agent(AIModel( - "PollinationsAI:openai", # Specify the provider and model + "OpenaiChat:gpt-4o", # Specify the provider and model ToolSupportProvider # Use ToolSupportProvider to handle tool-based response formatting ), result_type=MyModel, model_settings=ModelSettings(temperature=0)) diff --git a/g4f/Provider/ARTA.py b/g4f/Provider/ARTA.py new file mode 100644 index 00000000..993a85a1 --- /dev/null +++ b/g4f/Provider/ARTA.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +import os +import time +import json +from pathlib import Path +from aiohttp import ClientSession +import asyncio + +from ..typing import AsyncResult, Messages +from ..providers.response import ImageResponse, Reasoning +from ..errors import ResponseError +from ..cookies import get_cookies_dir +from .base_provider import AsyncGeneratorProvider, ProviderModelMixin +from .helper import format_image_prompt + +class ARTA(AsyncGeneratorProvider, ProviderModelMixin): + url = "https://img-gen-prod.ai-arta.com" + auth_url = "https://www.googleapis.com/identitytoolkit/v3/relyingparty/signupNewUser?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ" + token_refresh_url = "https://securetoken.googleapis.com/v1/token?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ" + image_generation_url = "https://img-gen-prod.ai-arta.com/api/v1/text2image" + status_check_url = "https://img-gen-prod.ai-arta.com/api/v1/text2image/{record_id}/status" + + working = True + + default_model = "Flux" + default_image_model = default_model + model_aliases = { + "flux": "Flux", + "medieval": "Medieval", + "vincent_van_gogh": "Vincent Van Gogh", + "f_dev": "F Dev", + "low_poly": "Low Poly", + "dreamshaper_xl": "Dreamshaper-xl", + "anima_pencil_xl": "Anima-pencil-xl", + "biomech": "Biomech", + "trash_polka": "Trash Polka", + "no_style": "No Style", + "cheyenne_xl": "Cheyenne-xl", + "chicano": "Chicano", + "embroidery_tattoo": "Embroidery tattoo", + "red_and_black": "Red and Black", + "fantasy_art": "Fantasy Art", + "watercolor": "Watercolor", + "dotwork": "Dotwork", + "old_school_colored": "Old school colored", + "realistic_tattoo": "Realistic tattoo", + "japanese_2": "Japanese_2", + "realistic_stock_xl": "Realistic-stock-xl", + "f_pro": "F Pro", + "revanimated": "RevAnimated", + "katayama_mix_xl": "Katayama-mix-xl", + "sdxl_l": "SDXL L", + "cor_epica_xl": "Cor-epica-xl", + "anime_tattoo": "Anime tattoo", + "new_school": "New School", + "death_metal": "Death metal", + "old_school": "Old School", + "juggernaut_xl": "Juggernaut-xl", + "photographic": "Photographic", + "sdxl_1_0": "SDXL 1.0", + "graffiti": "Graffiti", + "mini_tattoo": "Mini tattoo", + "surrealism": "Surrealism", + "neo_traditional": "Neo-traditional", + "on_limbs_black": "On limbs black", + "yamers_realistic_xl": "Yamers-realistic-xl", + "pony_xl": "Pony-xl", + "playground_xl": "Playground-xl", + "anything_xl": "Anything-xl", + "flame_design": "Flame design", + "kawaii": "Kawaii", + "cinematic_art": "Cinematic Art", + "professional": "Professional", + "flux_black_ink": "Flux Black Ink" + } + image_models = [*model_aliases.keys()] + models = image_models + + @classmethod + def get_auth_file(cls): + path = Path(get_cookies_dir()) + path.mkdir(exist_ok=True) + filename = f"auth_{cls.__name__}.json" + return path / filename + + @classmethod + async def create_token(cls, path: Path, proxy: str | None = None): + async with ClientSession() as session: + # Step 1: Generate Authentication Token + auth_payload = {"clientType": "CLIENT_TYPE_ANDROID"} + async with session.post(cls.auth_url, json=auth_payload, proxy=proxy) as auth_response: + auth_data = await auth_response.json() + auth_token = auth_data.get("idToken") + #refresh_token = auth_data.get("refreshToken") + if not auth_token: + raise ResponseError("Failed to obtain authentication token.") + json.dump(auth_data, path.open("w")) + return auth_data + + @classmethod + async def refresh_token(cls, refresh_token: str, proxy: str = None) -> tuple[str, str]: + async with ClientSession() as session: + payload = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + } + async with session.post(cls.token_refresh_url, data=payload, proxy=proxy) as response: + response_data = await response.json() + return response_data.get("id_token"), response_data.get("refresh_token") + + @classmethod + async def read_and_refresh_token(cls, proxy: str | None = None) -> str: + path = cls.get_auth_file() + if path.is_file(): + auth_data = json.load(path.open("rb")) + diff = time.time() - os.path.getmtime(path) + expiresIn = int(auth_data.get("expiresIn")) + if diff < expiresIn: + if diff > expiresIn / 2: + auth_data["idToken"], auth_data["refreshToken"] = await cls.refresh_token(auth_data.get("refreshToken"), proxy) + json.dump(auth_data, path.open("w")) + return auth_data + return await cls.create_token(path, proxy) + + @classmethod + async def create_async_generator( + cls, + model: str, + messages: Messages, + proxy: str = None, + prompt: str = None, + negative_prompt: str = "blurry, deformed hands, ugly", + images_num: int = 1, + guidance_scale: int = 7, + num_inference_steps: int = 30, + aspect_ratio: str = "1:1", + **kwargs + ) -> AsyncResult: + model = cls.get_model(model) + prompt = format_image_prompt(messages, prompt) + + # Step 1: Get Authentication Token + auth_data = await cls.read_and_refresh_token(proxy) + + async with ClientSession() as session: + # Step 2: Generate Images + image_payload = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "style": model, + "images_num": str(images_num), + "cfg_scale": str(guidance_scale), + "steps": str(num_inference_steps), + "aspect_ratio": aspect_ratio, + } + + headers = { + "Authorization": auth_data.get("idToken"), + } + + async with session.post(cls.image_generation_url, data=image_payload, headers=headers, proxy=proxy) as image_response: + image_data = await image_response.json() + record_id = image_data.get("record_id") + + if not record_id: + raise ResponseError(f"Failed to initiate image generation: {image_data}") + + # Step 3: Check Generation Status + status_url = cls.status_check_url.format(record_id=record_id) + counter = 0 + while True: + async with session.get(status_url, headers=headers, proxy=proxy) as status_response: + status_data = await status_response.json() + status = status_data.get("status") + + if status == "DONE": + image_urls = [image["url"] for image in status_data.get("response", [])] + yield Reasoning(status="Finished") + yield ImageResponse(images=image_urls, alt=prompt) + return + elif status in ("IN_QUEUE", "IN_PROGRESS"): + yield Reasoning(status=("Waiting" if status == "IN_QUEUE" else "Generating") + "." * counter) + await asyncio.sleep(5) # Poll every 5 seconds + counter += 1 + if counter > 3: + counter = 0 + else: + raise ResponseError(f"Image generation failed with status: {status}") \ No newline at end of file diff --git a/g4f/Provider/PollinationsAI.py b/g4f/Provider/PollinationsAI.py index 3f043322..9bd169f7 100644 --- a/g4f/Provider/PollinationsAI.py +++ b/g4f/Provider/PollinationsAI.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import random import requests from urllib.parse import quote_plus @@ -13,7 +14,7 @@ from ..image import to_data_uri from ..errors import ModelNotFoundError from ..requests.raise_for_status import raise_for_status from ..requests.aiohttp import get_connector -from ..providers.response import ImageResponse, ImagePreview, FinishReason, Usage, Audio +from ..providers.response import ImageResponse, ImagePreview, FinishReason, Usage, Audio, ToolCalls from .. import debug DEFAULT_HEADERS = { @@ -52,7 +53,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): text_models = [default_model] image_models = [default_image_model] extra_image_models = ["flux-pro", "flux-dev", "flux-schnell", "midjourney", "dall-e-3"] - vision_models = [default_vision_model, "gpt-4o-mini", "o1-mini"] + vision_models = [default_vision_model, "gpt-4o-mini", "o1-mini", "openai", "openai-large"] extra_text_models = vision_models _models_loaded = False model_aliases = { @@ -138,6 +139,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): cls, model: str, messages: Messages, + stream: bool = False, proxy: str = None, prompt: str = None, width: int = 1024, @@ -154,6 +156,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): frequency_penalty: float = None, response_format: Optional[dict] = None, cache: bool = False, + extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias"], **kwargs ) -> AsyncResult: cls.get_models() @@ -193,6 +196,9 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): response_format=response_format, seed=seed, cache=cache, + stream=stream, + extra_parameters=extra_parameters, + **kwargs ): yield result @@ -246,7 +252,10 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): frequency_penalty: float, response_format: Optional[dict], seed: Optional[int], - cache: bool + cache: bool, + stream: bool, + extra_parameters: list[str], + **kwargs ) -> AsyncResult: if not cache and seed is None: seed = random.randint(9999, 99999999) @@ -267,6 +276,13 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): messages[-1] = last_message async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session: + if model in cls.audio_models or stream: + #data["voice"] = random.choice(cls.audio_models[model]) + url = cls.text_api_endpoint + stream = False + else: + url = cls.openai_endpoint + extra_parameters = {param: kwargs[param] for param in extra_parameters if param in kwargs} data = filter_none(**{ "messages": messages, "model": model, @@ -275,17 +291,11 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): "top_p": top_p, "frequency_penalty": frequency_penalty, "jsonMode": json_mode, - "stream": False, + "stream": stream, "seed": seed, - "cache": cache + "cache": cache, + **extra_parameters }) - if "gemini" in model: - data.pop("seed") - if model in cls.audio_models: - #data["voice"] = random.choice(cls.audio_models[model]) - url = cls.text_api_endpoint - else: - url = cls.openai_endpoint async with session.post(url, json=data) as response: await raise_for_status(response) if response.headers["content-type"] == "audio/mpeg": @@ -294,16 +304,36 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): elif response.headers["content-type"].startswith("text/plain"): yield await response.text() return + elif response.headers["content-type"].startswith("text/event-stream"): + async for line in response.content: + if line.startswith(b"data: "): + if line[6:].startswith(b"[DONE]"): + break + result = json.loads(line[6:]) + choice = result.get("choices", [{}])[0] + content = choice.get("delta", {}).get("content") + if content: + yield content + if "usage" in result: + yield Usage(**result["usage"]) + finish_reason = choice.get("finish_reason") + if finish_reason: + yield FinishReason(finish_reason) + return result = await response.json() choice = result["choices"][0] message = choice.get("message", {}) content = message.get("content", "") - if "" in content and "" not in content: - yield "" + if "tool_calls" in message: + yield ToolCalls(message["tool_calls"]) - if content: - yield content.replace("\\(", "(").replace("\\)", ")") + if content is not None: + if "" in content and "" not in content: + yield "" + + if content: + yield content.replace("\\(", "(").replace("\\)", ")") if "usage" in result: yield Usage(**result["usage"]) diff --git a/g4f/Provider/__init__.py b/g4f/Provider/__init__.py index 0915c796..6e39f861 100644 --- a/g4f/Provider/__init__.py +++ b/g4f/Provider/__init__.py @@ -15,6 +15,7 @@ from .mini_max import HailuoAI, MiniMax from .template import OpenaiTemplate, BackendApi from .AllenAI import AllenAI +from .ARTA import ARTA from .Blackbox import Blackbox from .ChatGLM import ChatGLM from .ChatGpt import ChatGpt diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index 4670e77e..3aa37bef 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -623,8 +623,9 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin): page.add_handler(nodriver.cdp.network.RequestWillBeSent, on_request) page = await browser.get(cls.url) user_agent = await page.evaluate("window.navigator.userAgent") - await page.select("textarea.text-token-text-primary", 240) - await page.evaluate("document.querySelector('textarea.text-token-text-primary').value = 'Hello'") + await page.select("#prompt-textarea", 240) + await page.evaluate("document.getElementById('prompt-textarea').innerText = 'Hello'") + await page.select("[data-testid=\"send-button\"]", 30) await page.evaluate("document.querySelector('[data-testid=\"send-button\"]').click()") while True: body = await page.evaluate("JSON.stringify(window.__remixContext)") diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py index 0c4dd1dd..162a7d20 100644 --- a/g4f/client/__init__.py +++ b/g4f/client/__init__.py @@ -276,7 +276,7 @@ class Completions: def create( self, messages: Messages, - model: str, + model: str = "", provider: Optional[ProviderType] = None, stream: Optional[bool] = False, proxy: Optional[str] = None, @@ -330,7 +330,7 @@ class Completions: def stream( self, messages: Messages, - model: str, + model: str = "", **kwargs ) -> IterResponse: return self.create(messages, model, stream=True, **kwargs) @@ -564,7 +564,7 @@ class AsyncCompletions: def create( self, messages: Messages, - model: str, + model: str = "", provider: Optional[ProviderType] = None, stream: Optional[bool] = False, proxy: Optional[str] = None, @@ -619,7 +619,7 @@ class AsyncCompletions: def stream( self, messages: Messages, - model: str, + model: str = "", **kwargs ) -> AsyncIterator[ChatCompletionChunk]: return self.create(messages, model, stream=True, **kwargs) diff --git a/g4f/gui/client/static/css/style.css b/g4f/gui/client/static/css/style.css index 8b608970..7ab0b1b9 100644 --- a/g4f/gui/client/static/css/style.css +++ b/g4f/gui/client/static/css/style.css @@ -112,7 +112,7 @@ body:not(.white) a:visited{ .new_version { position: absolute; - right: 0; + left: 0; top: 0; padding: 10px; font-weight: 500; @@ -143,6 +143,7 @@ body:not(.white) a:visited{ .conversation { width: 100%; + height: 100%; display: flex; flex-direction: column; gap: 5px; @@ -238,8 +239,7 @@ body:not(.white) a:visited{ #close_provider_forms { max-width: 210px; - margin-left: auto; - margin-right: 8px; + margin-left: 12px; margin-top: 12px; } @@ -1584,19 +1584,6 @@ form .field.saved .fa-xmark { } } - -/* Basic adaptation */ -.row { - flex-wrap: wrap; - gap: 10px; -} - -.conversations, .settings, .conversation { - flex: 1 1 300px; - min-width: 0; - height: 100%; -} - /* Media queries for mobile devices */ @media (max-width: 768px) { .row { @@ -1608,11 +1595,6 @@ form .field.saved .fa-xmark { max-width: 100%; margin: 0; } - - .conversation { - order: -1; - min-height: 80vh; - } } @media (max-width: 480px) { diff --git a/g4f/gui/client/static/js/chat.v1.js b/g4f/gui/client/static/js/chat.v1.js index a82ab43a..d7117c50 100644 --- a/g4f/gui/client/static/js/chat.v1.js +++ b/g4f/gui/client/static/js/chat.v1.js @@ -259,6 +259,10 @@ function register_message_images() { const register_message_buttons = async () => { message_box.querySelectorAll(".message .content .provider").forEach(async (el) => { + if (el.dataset.click) { + return + } + el.dataset.click = true; const provider_forms = document.querySelector(".provider_forms"); const provider_form = provider_forms.querySelector(`#${el.dataset.provider}-form`); const provider_link = el.querySelector("a"); @@ -279,6 +283,10 @@ const register_message_buttons = async () => { }); message_box.querySelectorAll(".message .fa-xmark").forEach(async (el) => el.addEventListener("click", async () => { + if (el.dataset.click) { + return + } + el.dataset.click = true; const message_el = get_message_el(el); await remove_message(window.conversation_id, message_el.dataset.index); message_el.remove(); @@ -286,6 +294,10 @@ const register_message_buttons = async () => { })); message_box.querySelectorAll(".message .fa-clipboard").forEach(async (el) => el.addEventListener("click", async () => { + if (el.dataset.click) { + return + } + el.dataset.click = true; let message_el = get_message_el(el); let response = await fetch(message_el.dataset.object_url); let copyText = await response.text(); @@ -304,6 +316,10 @@ const register_message_buttons = async () => { })) message_box.querySelectorAll(".message .fa-file-export").forEach(async (el) => el.addEventListener("click", async () => { + if (el.dataset.click) { + return + } + el.dataset.click = true; const elem = window.document.createElement('a'); let filename = `chat ${new Date().toLocaleString()}.txt`.replaceAll(":", "-"); const conversation = await get_conversation(window.conversation_id); @@ -323,6 +339,10 @@ const register_message_buttons = async () => { })) message_box.querySelectorAll(".message .fa-volume-high").forEach(async (el) => el.addEventListener("click", async () => { + if (el.dataset.click) { + return + } + el.dataset.click = true; const message_el = get_message_el(el); let audio; if (message_el.dataset.synthesize_url) { @@ -344,6 +364,10 @@ const register_message_buttons = async () => { })); message_box.querySelectorAll(".message .regenerate_button").forEach(async (el) => el.addEventListener("click", async () => { + if (el.dataset.click) { + return + } + el.dataset.click = true; const message_el = get_message_el(el); el.classList.add("clicked"); setTimeout(() => el.classList.remove("clicked"), 1000); @@ -351,6 +375,10 @@ const register_message_buttons = async () => { })); message_box.querySelectorAll(".message .continue_button").forEach(async (el) => el.addEventListener("click", async () => { + if (el.dataset.click) { + return + } + el.dataset.click = true; if (!el.disabled) { el.disabled = true; const message_el = get_message_el(el); @@ -361,11 +389,19 @@ const register_message_buttons = async () => { )); message_box.querySelectorAll(".message .fa-whatsapp").forEach(async (el) => el.addEventListener("click", async () => { + if (el.dataset.click) { + return + } + el.dataset.click = true; const text = get_message_el(el).innerText; window.open(`https://wa.me/?text=${encodeURIComponent(text)}`, '_blank'); })); message_box.querySelectorAll(".message .fa-print").forEach(async (el) => el.addEventListener("click", async () => { + if (el.dataset.click) { + return + } + el.dataset.click = true; const message_el = get_message_el(el); el.classList.add("clicked"); message_box.scrollTop = 0; @@ -378,6 +414,10 @@ const register_message_buttons = async () => { })); message_box.querySelectorAll(".message .reasoning_title").forEach(async (el) => el.addEventListener("click", async () => { + if (el.dataset.click) { + return + } + el.dataset.click = true; let text_el = el.parentElement.querySelector(".reasoning_text"); if (text_el) { text_el.classList[text_el.classList.contains("hidden") ? "remove" : "add"]("hidden"); @@ -569,9 +609,9 @@ const prepare_messages = (messages, message_index = -1, do_continue = false, do_ } // Remove history, only add new user messages - let filtered_messages = []; // The message_index is null on count total tokens - if (document.getElementById('history')?.checked && do_filter && message_index != null) { + if (!do_continue && document.getElementById('history')?.checked && do_filter && message_index != null) { + let filtered_messages = []; while (last_message = messages.pop()) { if (last_message["role"] == "user") { filtered_messages.push(last_message); @@ -630,9 +670,9 @@ async function load_provider_parameters(provider) { form_el.id = form_id; form_el.classList.add("hidden"); appStorage.setItem(form_el.id, JSON.stringify(parameters_storage[provider])); - let old_form = message_box.querySelector(`#${provider}-form`); + let old_form = document.getElementById(form_id); if (old_form) { - provider_forms.removeChild(old_form); + old_form.remove(); } Object.entries(parameters_storage[provider]).forEach(([key, value]) => { let el_id = `${provider}-${key}`; @@ -649,7 +689,7 @@ async function load_provider_parameters(provider) { saved_value = value; } field_el.innerHTML = `${key}: - + `; form_el.appendChild(field_el); @@ -679,15 +719,15 @@ async function load_provider_parameters(provider) { placeholder = value == null ? "null" : value; } field_el.innerHTML = ``; - if (Number.isInteger(value) && value != 1) { - max = value >= 4096 ? 8192 : 4096; - field_el.innerHTML += `${escapeHtml(value)}`; + if (Number.isInteger(value)) { + max = value == 42 || value >= 4096 ? 8192 : value >= 100 ? 4096 : value == 1 ? 10 : 100; + field_el.innerHTML += `${escapeHtml(value)}`; field_el.innerHTML += ``; } else if (typeof value == "number") { - field_el.innerHTML += `${escapeHtml(value)}`; + field_el.innerHTML += `${escapeHtml(value)}`; field_el.innerHTML += ``; } else { - field_el.innerHTML += ``; + field_el.innerHTML += ``; field_el.innerHTML += ``; input_el = field_el.querySelector("textarea"); if (value != null) { @@ -723,6 +763,7 @@ async function load_provider_parameters(provider) { input_el = field_el.querySelector("input"); input_el.dataset.value = value; input_el.value = saved_value; + input_el.nextElementSibling.value = input_el.value; input_el.oninput = () => { input_el.nextElementSibling.value = input_el.value; field_el.classList.add("saved"); @@ -1008,6 +1049,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi } await safe_remove_cancel_button(); await register_message_images(); + await register_message_buttons(); await load_conversations(); regenerate_button.classList.remove("regenerate-hidden"); } @@ -1035,6 +1077,18 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi } } const ignored = Array.from(settings.querySelectorAll("input.provider:not(:checked)")).map((el)=>el.value); + let extra_parameters = {}; + document.getElementById(`${provider}-form`)?.querySelectorAll(".saved input, .saved textarea").forEach(async (el) => { + let value = el.type == "checkbox" ? el.checked : el.value; + extra_parameters[el.name] = value; + if (el.type == "textarea") { + try { + extra_parameters[el.name] = await JSON.parse(value); + } catch (e) { + } + } + }); + console.log(extra_parameters); await api("conversation", { id: message_id, conversation_id: window.conversation_id, @@ -1048,6 +1102,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi api_key: api_key, api_base: api_base, ignored: ignored, + ...extra_parameters }, Object.values(image_storage), message_id, scroll, finish_message); } catch (e) { console.error(e); diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py index 51d65426..c237fc1b 100644 --- a/g4f/gui/server/api.py +++ b/g4f/gui/server/api.py @@ -89,25 +89,17 @@ class Api: ensure_images_dir() return send_from_directory(os.path.abspath(images_dir), name) - def _prepare_conversation_kwargs(self, json_data: dict, kwargs: dict): + def _prepare_conversation_kwargs(self, json_data: dict): + kwargs = {**json_data} model = json_data.get('model') provider = json_data.get('provider') messages = json_data.get('messages') - api_key = json_data.get("api_key") - if api_key: - kwargs["api_key"] = api_key - api_base = json_data.get("api_base") - if api_base: - kwargs["api_base"] = api_base kwargs["tool_calls"] = [{ "function": { "name": "bucket_tool" }, "type": "function" }] - web_search = json_data.get('web_search') - if web_search: - kwargs["web_search"] = web_search action = json_data.get('action') if action == "continue": kwargs["tool_calls"].append({ @@ -117,19 +109,13 @@ class Api: "type": "function" }) conversation = json_data.get("conversation") - if conversation is not None: + if isinstance(conversation, dict): kwargs["conversation"] = JsonConversation(**conversation) else: conversation_id = json_data.get("conversation_id") if conversation_id and provider: if provider in conversations and conversation_id in conversations[provider]: kwargs["conversation"] = conversations[provider][conversation_id] - - if json_data.get("ignored"): - kwargs["ignored"] = json_data["ignored"] - if json_data.get("action"): - kwargs["action"] = json_data["action"] - return { "model": model, "provider": provider, diff --git a/g4f/gui/server/backend_api.py b/g4f/gui/server/backend_api.py index 1dc66223..8510e614 100644 --- a/g4f/gui/server/backend_api.py +++ b/g4f/gui/server/backend_api.py @@ -106,17 +106,16 @@ class Backend_Api(Api): Returns: Response: A Flask response object for streaming. """ - kwargs = {} + if "json" in request.form: + json_data = json.loads(request.form['json']) + else: + json_data = request.json if "files" in request.files: images = [] for file in request.files.getlist('files'): if file.filename != '' and is_allowed_extension(file.filename): images.append((to_image(file.stream, file.filename.endswith('.svg')), file.filename)) - kwargs['images'] = images - if "json" in request.form: - json_data = json.loads(request.form['json']) - else: - json_data = request.json + json_data['images'] = images if app.demo and not json_data.get("provider"): model = json_data.get("model") @@ -126,9 +125,7 @@ class Backend_Api(Api): if not model or model == "default": json_data["model"] = models.demo_models["default"][0].name json_data["provider"] = random.choice(models.demo_models["default"][1]) - if "images" in json_data: - kwargs["images"] = json_data["images"] - kwargs = self._prepare_conversation_kwargs(json_data, kwargs) + kwargs = self._prepare_conversation_kwargs(json_data) return self.app.response_class( self._create_response_stream( kwargs, diff --git a/g4f/gui/server/js_api.py b/g4f/gui/server/js_api.py index 02132ffe..ed47f09b 100644 --- a/g4f/gui/server/js_api.py +++ b/g4f/gui/server/js_api.py @@ -21,12 +21,12 @@ from .api import Api class JsApi(Api): - def get_conversation(self, options: dict, message_id: str = None, scroll: bool = None, **kwargs) -> Iterator: + def get_conversation(self, options: dict, message_id: str = None, scroll: bool = None) -> Iterator: window = webview.windows[0] if hasattr(self, "image") and self.image is not None: - kwargs["image"] = open(self.image, "rb") + options["image"] = open(self.image, "rb") for message in self._create_response_stream( - self._prepare_conversation_kwargs(options, kwargs), + self._prepare_conversation_kwargs(options), options.get("conversation_id"), options.get('provider') ): diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py index 7ae5b44a..9393cd47 100644 --- a/g4f/providers/base_provider.py +++ b/g4f/providers/base_provider.py @@ -34,7 +34,7 @@ SAFE_PARAMETERS = [ "api_key", "api_base", "seed", "width", "height", "proof_token", "max_retries", "web_search", "guidance_scale", "num_inference_steps", "randomize_seed", - "safe", "enhance", "private", + "safe", "enhance", "private", "aspect_ratio", "images_num", ] BASIC_PARAMETERS = { diff --git a/g4f/providers/tool_support.py b/g4f/providers/tool_support.py index 2aa83cef..21465728 100644 --- a/g4f/providers/tool_support.py +++ b/g4f/providers/tool_support.py @@ -3,11 +3,10 @@ from __future__ import annotations import json from ..typing import AsyncResult, Messages, ImagesType -from ..providers.asyncio import to_async_iterator from ..client.service import get_model_and_provider from ..client.helper import filter_json from .base_provider import AsyncGeneratorProvider -from .response import ToolCalls, FinishReason +from .response import ToolCalls, FinishReason, Usage class ToolSupportProvider(AsyncGeneratorProvider): working = True @@ -45,6 +44,7 @@ class ToolSupportProvider(AsyncGeneratorProvider): finish = None chunks = [] + has_usage = False async for chunk in provider.get_async_create_function()( model, messages, @@ -53,14 +53,20 @@ class ToolSupportProvider(AsyncGeneratorProvider): response_format=response_format, **kwargs ): - if isinstance(chunk, FinishReason): + if isinstance(chunk, str): + chunks.append(chunk) + elif isinstance(chunk, Usage): + yield chunk + has_usage = True + elif isinstance(chunk, FinishReason): finish = chunk break - elif isinstance(chunk, str): - chunks.append(chunk) else: yield chunk + if not has_usage: + yield Usage(completion_tokens=len(chunks), total_tokens=len(chunks)) + chunks = "".join(chunks) if tools is not None: yield ToolCalls([{ @@ -72,5 +78,6 @@ class ToolSupportProvider(AsyncGeneratorProvider): } }]) yield chunks + if finish is not None: yield finish \ No newline at end of file diff --git a/g4f/tools/run_tools.py b/g4f/tools/run_tools.py index c96963ec..8ecff20c 100644 --- a/g4f/tools/run_tools.py +++ b/g4f/tools/run_tools.py @@ -59,7 +59,7 @@ class ToolHandler: def process_continue_tool(messages: Messages, tool: dict, provider: Any) -> Tuple[Messages, Dict[str, Any]]: """Process continue tool requests""" kwargs = {} - if provider not in ("OpenaiAccount", "HuggingFace"): + if provider not in ("OpenaiAccount", "HuggingFaceAPI"): messages = messages.copy() last_line = messages[-1]["content"].strip().splitlines()[-1] content = f"Carry on from this point:\n{last_line}" @@ -84,13 +84,11 @@ class ToolHandler: if new_message_content != message["content"]: has_bucket = True message["content"] = new_message_content - - if has_bucket and isinstance(messages[-1]["content"], str): - if "\nSource: " in messages[-1]["content"]: - if isinstance(messages[-1]["content"], dict): - messages[-1]["content"]["content"] += BUCKET_INSTRUCTIONS - else: - messages[-1]["content"] += BUCKET_INSTRUCTIONS + + last_message_content = messages[-1]["content"] + if has_bucket and isinstance(last_message_content, str): + if "\nSource: " in last_message_content: + messages[-1]["content"] = last_message_content + BUCKET_INSTRUCTIONS return messages @@ -309,9 +307,10 @@ def iter_run_tools( if new_message_content != message["content"]: has_bucket = True message["content"] = new_message_content - if has_bucket and isinstance(messages[-1]["content"], str): - if "\nSource: " in messages[-1]["content"]: - messages[-1]["content"] = messages[-1]["content"]["content"] + BUCKET_INSTRUCTIONS + last_message = messages[-1]["content"] + if has_bucket and isinstance(last_message, str): + if "\nSource: " in last_message: + messages[-1]["content"] = last_message + BUCKET_INSTRUCTIONS # Process response chunks thinking_start_time = 0