Add ARTA image provider

Add ToolSupport in PollinationsAI provider
Add default value for model in chat completions
Add Streaming Support for PollinationsAI provider
hlohaus 2025-03-11 02:49:24 +01:00
parent ad59df3011
commit 3e7af90949
14 changed files with 353 additions and 103 deletions

View file

@@ -109,15 +109,18 @@ This example shows how to initialize an agent with a specific model (`gpt-4o`) a
 from pydantic import BaseModel
 from pydantic_ai import Agent
 from pydantic_ai.models import ModelSettings
-from g4f.integration.pydantic_ai import patch_infer_model
+from g4f.integration.pydantic_ai import AIModel
+from g4f.Provider import PollinationsAI

-patch_infer_model("your_api_key")

 class MyModel(BaseModel):
     city: str
     country: str

-agent = Agent('g4f:Groq:llama3-70b-8192', result_type=MyModel, model_settings=ModelSettings(temperature=0))
+agent = Agent(AIModel(
+    "gpt-4o", # Specify the provider and model
+    PollinationsAI # Use a supported provider to handle tool-based response formatting
+), result_type=MyModel, model_settings=ModelSettings(temperature=0))

 if __name__ == '__main__':
     result = agent.run_sync('The windy city in the US of A.')
@@ -152,7 +155,7 @@ class MyModel(BaseModel):

 # Create the agent for a model with tool support (using one tool)
 agent = Agent(AIModel(
-    "PollinationsAI:openai", # Specify the provider and model
+    "OpenaiChat:gpt-4o", # Specify the provider and model
     ToolSupportProvider # Use ToolSupportProvider to handle tool-based response formatting
 ), result_type=MyModel, model_settings=ModelSettings(temperature=0))
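For reference, the updated docs assemble into the following self-contained sketch; the ToolSupportProvider import path is assumed to be g4f.providers.tool_support (the module edited later in this commit), and the printed result is illustrative:

from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai.models import ModelSettings
from g4f.integration.pydantic_ai import AIModel
from g4f.providers.tool_support import ToolSupportProvider  # assumed import path

class MyModel(BaseModel):
    city: str
    country: str

# "Provider:model" string plus the wrapper, as in the updated docs above
agent = Agent(AIModel(
    "OpenaiChat:gpt-4o",
    ToolSupportProvider
), result_type=MyModel, model_settings=ModelSettings(temperature=0))

if __name__ == '__main__':
    result = agent.run_sync('The windy city in the US of A.')
    print(result.data)  # expected shape: MyModel(city='Chicago', country='USA')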

g4f/Provider/ARTA.py Normal file
View file

@@ -0,0 +1,189 @@
from __future__ import annotations

import os
import time
import json
from pathlib import Path
from aiohttp import ClientSession
import asyncio

from ..typing import AsyncResult, Messages
from ..providers.response import ImageResponse, Reasoning
from ..errors import ResponseError
from ..cookies import get_cookies_dir
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_image_prompt

class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://img-gen-prod.ai-arta.com"
auth_url = "https://www.googleapis.com/identitytoolkit/v3/relyingparty/signupNewUser?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ"
token_refresh_url = "https://securetoken.googleapis.com/v1/token?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ"
image_generation_url = "https://img-gen-prod.ai-arta.com/api/v1/text2image"
status_check_url = "https://img-gen-prod.ai-arta.com/api/v1/text2image/{record_id}/status"
working = True
default_model = "Flux"
default_image_model = default_model
model_aliases = {
"flux": "Flux",
"medieval": "Medieval",
"vincent_van_gogh": "Vincent Van Gogh",
"f_dev": "F Dev",
"low_poly": "Low Poly",
"dreamshaper_xl": "Dreamshaper-xl",
"anima_pencil_xl": "Anima-pencil-xl",
"biomech": "Biomech",
"trash_polka": "Trash Polka",
"no_style": "No Style",
"cheyenne_xl": "Cheyenne-xl",
"chicano": "Chicano",
"embroidery_tattoo": "Embroidery tattoo",
"red_and_black": "Red and Black",
"fantasy_art": "Fantasy Art",
"watercolor": "Watercolor",
"dotwork": "Dotwork",
"old_school_colored": "Old school colored",
"realistic_tattoo": "Realistic tattoo",
"japanese_2": "Japanese_2",
"realistic_stock_xl": "Realistic-stock-xl",
"f_pro": "F Pro",
"revanimated": "RevAnimated",
"katayama_mix_xl": "Katayama-mix-xl",
"sdxl_l": "SDXL L",
"cor_epica_xl": "Cor-epica-xl",
"anime_tattoo": "Anime tattoo",
"new_school": "New School",
"death_metal": "Death metal",
"old_school": "Old School",
"juggernaut_xl": "Juggernaut-xl",
"photographic": "Photographic",
"sdxl_1_0": "SDXL 1.0",
"graffiti": "Graffiti",
"mini_tattoo": "Mini tattoo",
"surrealism": "Surrealism",
"neo_traditional": "Neo-traditional",
"on_limbs_black": "On limbs black",
"yamers_realistic_xl": "Yamers-realistic-xl",
"pony_xl": "Pony-xl",
"playground_xl": "Playground-xl",
"anything_xl": "Anything-xl",
"flame_design": "Flame design",
"kawaii": "Kawaii",
"cinematic_art": "Cinematic Art",
"professional": "Professional",
"flux_black_ink": "Flux Black Ink"
}
    image_models = [*model_aliases.keys()]
    models = image_models

    @classmethod
def get_auth_file(cls):
path = Path(get_cookies_dir())
path.mkdir(exist_ok=True)
filename = f"auth_{cls.__name__}.json"
        return path / filename

    @classmethod
async def create_token(cls, path: Path, proxy: str | None = None):
async with ClientSession() as session:
# Step 1: Generate Authentication Token
auth_payload = {"clientType": "CLIENT_TYPE_ANDROID"}
async with session.post(cls.auth_url, json=auth_payload, proxy=proxy) as auth_response:
auth_data = await auth_response.json()
auth_token = auth_data.get("idToken")
#refresh_token = auth_data.get("refreshToken")
if not auth_token:
raise ResponseError("Failed to obtain authentication token.")
json.dump(auth_data, path.open("w"))
                return auth_data

    @classmethod
async def refresh_token(cls, refresh_token: str, proxy: str = None) -> tuple[str, str]:
async with ClientSession() as session:
payload = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
}
async with session.post(cls.token_refresh_url, data=payload, proxy=proxy) as response:
response_data = await response.json()
                return response_data.get("id_token"), response_data.get("refresh_token")

    @classmethod
async def read_and_refresh_token(cls, proxy: str | None = None) -> str:
path = cls.get_auth_file()
if path.is_file():
auth_data = json.load(path.open("rb"))
diff = time.time() - os.path.getmtime(path)
expiresIn = int(auth_data.get("expiresIn"))
if diff < expiresIn:
if diff > expiresIn / 2:
auth_data["idToken"], auth_data["refreshToken"] = await cls.refresh_token(auth_data.get("refreshToken"), proxy)
json.dump(auth_data, path.open("w"))
return auth_data
        return await cls.create_token(path, proxy)

    @classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
proxy: str = None,
prompt: str = None,
negative_prompt: str = "blurry, deformed hands, ugly",
images_num: int = 1,
guidance_scale: int = 7,
num_inference_steps: int = 30,
aspect_ratio: str = "1:1",
**kwargs
) -> AsyncResult:
model = cls.get_model(model)
prompt = format_image_prompt(messages, prompt)
# Step 1: Get Authentication Token
auth_data = await cls.read_and_refresh_token(proxy)
async with ClientSession() as session:
# Step 2: Generate Images
image_payload = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"style": model,
"images_num": str(images_num),
"cfg_scale": str(guidance_scale),
"steps": str(num_inference_steps),
"aspect_ratio": aspect_ratio,
}
headers = {
"Authorization": auth_data.get("idToken"),
}
async with session.post(cls.image_generation_url, data=image_payload, headers=headers, proxy=proxy) as image_response:
image_data = await image_response.json()
record_id = image_data.get("record_id")
if not record_id:
raise ResponseError(f"Failed to initiate image generation: {image_data}")
# Step 3: Check Generation Status
status_url = cls.status_check_url.format(record_id=record_id)
counter = 0
while True:
async with session.get(status_url, headers=headers, proxy=proxy) as status_response:
status_data = await status_response.json()
status = status_data.get("status")
if status == "DONE":
image_urls = [image["url"] for image in status_data.get("response", [])]
yield Reasoning(status="Finished")
yield ImageResponse(images=image_urls, alt=prompt)
return
elif status in ("IN_QUEUE", "IN_PROGRESS"):
yield Reasoning(status=("Waiting" if status == "IN_QUEUE" else "Generating") + "." * counter)
await asyncio.sleep(5) # Poll every 5 seconds
counter += 1
if counter > 3:
counter = 0
else:
raise ResponseError(f"Image generation failed with status: {status}")

View file

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import json
 import random
 import requests
 from urllib.parse import quote_plus
@@ -13,7 +14,7 @@ from ..image import to_data_uri
 from ..errors import ModelNotFoundError
 from ..requests.raise_for_status import raise_for_status
 from ..requests.aiohttp import get_connector
-from ..providers.response import ImageResponse, ImagePreview, FinishReason, Usage, Audio
+from ..providers.response import ImageResponse, ImagePreview, FinishReason, Usage, Audio, ToolCalls
 from .. import debug

 DEFAULT_HEADERS = {
@@ -52,7 +53,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
     text_models = [default_model]
     image_models = [default_image_model]
     extra_image_models = ["flux-pro", "flux-dev", "flux-schnell", "midjourney", "dall-e-3"]
-    vision_models = [default_vision_model, "gpt-4o-mini", "o1-mini"]
+    vision_models = [default_vision_model, "gpt-4o-mini", "o1-mini", "openai", "openai-large"]
     extra_text_models = vision_models
     _models_loaded = False
     model_aliases = {
@@ -138,6 +139,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         cls,
         model: str,
         messages: Messages,
+        stream: bool = False,
         proxy: str = None,
         prompt: str = None,
         width: int = 1024,
@@ -154,6 +156,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         frequency_penalty: float = None,
         response_format: Optional[dict] = None,
         cache: bool = False,
+        extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias"],
         **kwargs
     ) -> AsyncResult:
         cls.get_models()
@@ -193,6 +196,9 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                 response_format=response_format,
                 seed=seed,
                 cache=cache,
+                stream=stream,
+                extra_parameters=extra_parameters,
+                **kwargs
             ):
                 yield result
@@ -246,7 +252,10 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         frequency_penalty: float,
         response_format: Optional[dict],
         seed: Optional[int],
-        cache: bool
+        cache: bool,
+        stream: bool,
+        extra_parameters: list[str],
+        **kwargs
     ) -> AsyncResult:
         if not cache and seed is None:
             seed = random.randint(9999, 99999999)
@@ -267,6 +276,13 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
             messages[-1] = last_message

         async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
+            if model in cls.audio_models or stream:
+                #data["voice"] = random.choice(cls.audio_models[model])
+                url = cls.text_api_endpoint
+                stream = False
+            else:
+                url = cls.openai_endpoint
+            extra_parameters = {param: kwargs[param] for param in extra_parameters if param in kwargs}
             data = filter_none(**{
                 "messages": messages,
                 "model": model,
@@ -275,17 +291,11 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                 "top_p": top_p,
                 "frequency_penalty": frequency_penalty,
                 "jsonMode": json_mode,
-                "stream": False,
+                "stream": stream,
                 "seed": seed,
-                "cache": cache
+                "cache": cache,
+                **extra_parameters
             })
-            if "gemini" in model:
-                data.pop("seed")
-            if model in cls.audio_models:
-                #data["voice"] = random.choice(cls.audio_models[model])
-                url = cls.text_api_endpoint
-            else:
-                url = cls.openai_endpoint
             async with session.post(url, json=data) as response:
                 await raise_for_status(response)
                 if response.headers["content-type"] == "audio/mpeg":
@@ -294,11 +304,31 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                 elif response.headers["content-type"].startswith("text/plain"):
                     yield await response.text()
                     return
+                elif response.headers["content-type"].startswith("text/event-stream"):
+                    async for line in response.content:
+                        if line.startswith(b"data: "):
+                            if line[6:].startswith(b"[DONE]"):
+                                break
+                            result = json.loads(line[6:])
+                            choice = result.get("choices", [{}])[0]
+                            content = choice.get("delta", {}).get("content")
+                            if content:
+                                yield content
+                            if "usage" in result:
+                                yield Usage(**result["usage"])
+                            finish_reason = choice.get("finish_reason")
+                            if finish_reason:
+                                yield FinishReason(finish_reason)
+                    return
                 result = await response.json()
                 choice = result["choices"][0]
                 message = choice.get("message", {})
                 content = message.get("content", "")
+                if "tool_calls" in message:
+                    yield ToolCalls(message["tool_calls"])
+                if content is not None:
                     if "</think>" in content and "<think>" not in content:
                         yield "<think>"

View file

@@ -15,6 +15,7 @@ from .mini_max import HailuoAI, MiniMax
 from .template import OpenaiTemplate, BackendApi

 from .AllenAI import AllenAI
+from .ARTA import ARTA
 from .Blackbox import Blackbox
 from .ChatGLM import ChatGLM
 from .ChatGpt import ChatGpt

View file

@@ -623,8 +623,9 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
             page.add_handler(nodriver.cdp.network.RequestWillBeSent, on_request)
             page = await browser.get(cls.url)
             user_agent = await page.evaluate("window.navigator.userAgent")
-            await page.select("textarea.text-token-text-primary", 240)
-            await page.evaluate("document.querySelector('textarea.text-token-text-primary').value = 'Hello'")
+            await page.select("#prompt-textarea", 240)
+            await page.evaluate("document.getElementById('prompt-textarea').innerText = 'Hello'")
+            await page.select("[data-testid=\"send-button\"]", 30)
             await page.evaluate("document.querySelector('[data-testid=\"send-button\"]').click()")
             while True:
                 body = await page.evaluate("JSON.stringify(window.__remixContext)")

View file

@@ -276,7 +276,7 @@ class Completions:
     def create(
         self,
         messages: Messages,
-        model: str,
+        model: str = "",
         provider: Optional[ProviderType] = None,
         stream: Optional[bool] = False,
         proxy: Optional[str] = None,
@@ -330,7 +330,7 @@ class Completions:
     def stream(
         self,
         messages: Messages,
-        model: str,
+        model: str = "",
         **kwargs
     ) -> IterResponse:
         return self.create(messages, model, stream=True, **kwargs)
@@ -564,7 +564,7 @@ class AsyncCompletions:
     def create(
         self,
         messages: Messages,
-        model: str,
+        model: str = "",
         provider: Optional[ProviderType] = None,
         stream: Optional[bool] = False,
         proxy: Optional[str] = None,
@@ -619,7 +619,7 @@ class AsyncCompletions:
     def stream(
         self,
         messages: Messages,
-        model: str,
+        model: str = "",
         **kwargs
     ) -> AsyncIterator[ChatCompletionChunk]:
         return self.create(messages, model, stream=True, **kwargs)
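Since model now defaults to an empty string, callers can omit it and let g4f fall back to its default model and provider selection; a minimal sketch:

from g4f.client import Client

client = Client()
response = client.chat.completions.create(
    messages=[{"role": "user", "content": "Hello"}],
    # no model argument needed anymore; "" triggers default selection
)
print(response.choices[0].message.content)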

View file

@@ -112,7 +112,7 @@ body:not(.white) a:visited{
 .new_version {
     position: absolute;
-    right: 0;
+    left: 0;
     top: 0;
     padding: 10px;
     font-weight: 500;
@@ -143,6 +143,7 @@ body:not(.white) a:visited{
 .conversation {
     width: 100%;
+    height: 100%;
     display: flex;
     flex-direction: column;
     gap: 5px;
@@ -238,8 +239,7 @@ body:not(.white) a:visited{
 #close_provider_forms {
     max-width: 210px;
-    margin-left: auto;
-    margin-right: 8px;
+    margin-left: 12px;
     margin-top: 12px;
 }
@@ -1584,19 +1584,6 @@ form .field.saved .fa-xmark {
     }
 }

-/* Basic adaptation */
-.row {
-    flex-wrap: wrap;
-    gap: 10px;
-}
-
-.conversations, .settings, .conversation {
-    flex: 1 1 300px;
-    min-width: 0;
-    height: 100%;
-}
-
 /* Media queries for mobile devices */
 @media (max-width: 768px) {
     .row {
@@ -1608,11 +1595,6 @@ form .field.saved .fa-xmark {
         max-width: 100%;
         margin: 0;
     }
-    .conversation {
-        order: -1;
-        min-height: 80vh;
-    }
 }

 @media (max-width: 480px) {

View file

@@ -259,6 +259,10 @@ function register_message_images() {

 const register_message_buttons = async () => {
     message_box.querySelectorAll(".message .content .provider").forEach(async (el) => {
+        if (el.dataset.click) {
+            return
+        }
+        el.dataset.click = true;
         const provider_forms = document.querySelector(".provider_forms");
         const provider_form = provider_forms.querySelector(`#${el.dataset.provider}-form`);
         const provider_link = el.querySelector("a");
@@ -279,6 +283,10 @@ const register_message_buttons = async () => {
     });
     message_box.querySelectorAll(".message .fa-xmark").forEach(async (el) => el.addEventListener("click", async () => {
+        if (el.dataset.click) {
+            return
+        }
+        el.dataset.click = true;
         const message_el = get_message_el(el);
         await remove_message(window.conversation_id, message_el.dataset.index);
         message_el.remove();
@@ -286,6 +294,10 @@ const register_message_buttons = async () => {
     }));
     message_box.querySelectorAll(".message .fa-clipboard").forEach(async (el) => el.addEventListener("click", async () => {
+        if (el.dataset.click) {
+            return
+        }
+        el.dataset.click = true;
         let message_el = get_message_el(el);
         let response = await fetch(message_el.dataset.object_url);
         let copyText = await response.text();
@@ -304,6 +316,10 @@ const register_message_buttons = async () => {
     }))
     message_box.querySelectorAll(".message .fa-file-export").forEach(async (el) => el.addEventListener("click", async () => {
+        if (el.dataset.click) {
+            return
+        }
+        el.dataset.click = true;
         const elem = window.document.createElement('a');
         let filename = `chat ${new Date().toLocaleString()}.txt`.replaceAll(":", "-");
         const conversation = await get_conversation(window.conversation_id);
@@ -323,6 +339,10 @@ const register_message_buttons = async () => {
     }))
     message_box.querySelectorAll(".message .fa-volume-high").forEach(async (el) => el.addEventListener("click", async () => {
+        if (el.dataset.click) {
+            return
+        }
+        el.dataset.click = true;
         const message_el = get_message_el(el);
         let audio;
         if (message_el.dataset.synthesize_url) {
@@ -344,6 +364,10 @@ const register_message_buttons = async () => {
     }));
     message_box.querySelectorAll(".message .regenerate_button").forEach(async (el) => el.addEventListener("click", async () => {
+        if (el.dataset.click) {
+            return
+        }
+        el.dataset.click = true;
         const message_el = get_message_el(el);
         el.classList.add("clicked");
         setTimeout(() => el.classList.remove("clicked"), 1000);
@@ -351,6 +375,10 @@ const register_message_buttons = async () => {
     }));
     message_box.querySelectorAll(".message .continue_button").forEach(async (el) => el.addEventListener("click", async () => {
+        if (el.dataset.click) {
+            return
+        }
+        el.dataset.click = true;
         if (!el.disabled) {
             el.disabled = true;
             const message_el = get_message_el(el);
@@ -361,11 +389,19 @@ const register_message_buttons = async () => {
     ));
     message_box.querySelectorAll(".message .fa-whatsapp").forEach(async (el) => el.addEventListener("click", async () => {
+        if (el.dataset.click) {
+            return
+        }
+        el.dataset.click = true;
         const text = get_message_el(el).innerText;
         window.open(`https://wa.me/?text=${encodeURIComponent(text)}`, '_blank');
     }));
     message_box.querySelectorAll(".message .fa-print").forEach(async (el) => el.addEventListener("click", async () => {
+        if (el.dataset.click) {
+            return
+        }
+        el.dataset.click = true;
         const message_el = get_message_el(el);
         el.classList.add("clicked");
         message_box.scrollTop = 0;
@@ -378,6 +414,10 @@ const register_message_buttons = async () => {
     }));
     message_box.querySelectorAll(".message .reasoning_title").forEach(async (el) => el.addEventListener("click", async () => {
+        if (el.dataset.click) {
+            return
+        }
+        el.dataset.click = true;
         let text_el = el.parentElement.querySelector(".reasoning_text");
         if (text_el) {
             text_el.classList[text_el.classList.contains("hidden") ? "remove" : "add"]("hidden");
@@ -569,9 +609,9 @@ const prepare_messages = (messages, message_index = -1, do_continue = false, do_
     }

     // Remove history, only add new user messages
+    let filtered_messages = [];
     // The message_index is null on count total tokens
-    if (document.getElementById('history')?.checked && do_filter && message_index != null) {
-        let filtered_messages = [];
+    if (!do_continue && document.getElementById('history')?.checked && do_filter && message_index != null) {
         while (last_message = messages.pop()) {
             if (last_message["role"] == "user") {
                 filtered_messages.push(last_message);
@@ -630,9 +670,9 @@ async function load_provider_parameters(provider) {
     form_el.id = form_id;
     form_el.classList.add("hidden");
     appStorage.setItem(form_el.id, JSON.stringify(parameters_storage[provider]));
-    let old_form = message_box.querySelector(`#${provider}-form`);
+    let old_form = document.getElementById(form_id);
     if (old_form) {
-        provider_forms.removeChild(old_form);
+        old_form.remove();
     }
     Object.entries(parameters_storage[provider]).forEach(([key, value]) => {
         let el_id = `${provider}-${key}`;
@@ -649,7 +689,7 @@ async function load_provider_parameters(provider) {
             saved_value = value;
         }
         field_el.innerHTML = `<span class="label">${key}:</span>
-        <input type="checkbox" id="${el_id}" name="${provider}[${key}]">
+        <input type="checkbox" id="${el_id}" name="${key}">
         <label for="${el_id}" class="toogle" title=""></label>
         <i class="fa-solid fa-xmark"></i>`;
         form_el.appendChild(field_el);
@@ -679,15 +719,15 @@ async function load_provider_parameters(provider) {
             placeholder = value == null ? "null" : value;
         }
         field_el.innerHTML = `<label for="${el_id}" title="">${key}:</label>`;
-        if (Number.isInteger(value) && value != 1) {
-            max = value >= 4096 ? 8192 : 4096;
-            field_el.innerHTML += `<input type="range" id="${el_id}" name="${provider}[${key}]" value="${escapeHtml(value)}" class="slider" min="0" max="${max}" step="1"/><output>${escapeHtml(value)}</output>`;
+        if (Number.isInteger(value)) {
+            max = value == 42 || value >= 4096 ? 8192 : value >= 100 ? 4096 : value == 1 ? 10 : 100;
+            field_el.innerHTML += `<input type="range" id="${el_id}" name="${key}" value="${escapeHtml(value)}" class="slider" min="0" max="${max}" step="1"/><output>${escapeHtml(value)}</output>`;
             field_el.innerHTML += `<i class="fa-solid fa-xmark"></i>`;
         } else if (typeof value == "number") {
-            field_el.innerHTML += `<input type="range" id="${el_id}" name="${provider}[${key}]" value="${escapeHtml(value)}" class="slider" min="0" max="2" step="0.1"/><output>${escapeHtml(value)}</output>`;
+            field_el.innerHTML += `<input type="range" id="${el_id}" name="${key}" value="${escapeHtml(value)}" class="slider" min="0" max="2" step="0.1"/><output>${escapeHtml(value)}</output>`;
             field_el.innerHTML += `<i class="fa-solid fa-xmark"></i>`;
         } else {
-            field_el.innerHTML += `<textarea id="${el_id}" name="${provider}[${key}]"></textarea>`;
+            field_el.innerHTML += `<textarea id="${el_id}" name="${key}"></textarea>`;
             field_el.innerHTML += `<i class="fa-solid fa-xmark"></i>`;
             input_el = field_el.querySelector("textarea");
             if (value != null) {
@@ -723,6 +763,7 @@ async function load_provider_parameters(provider) {
             input_el = field_el.querySelector("input");
             input_el.dataset.value = value;
             input_el.value = saved_value;
+            input_el.nextElementSibling.value = input_el.value;
             input_el.oninput = () => {
                 input_el.nextElementSibling.value = input_el.value;
                 field_el.classList.add("saved");
@@ -1008,6 +1049,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
         }
         await safe_remove_cancel_button();
         await register_message_images();
+        await register_message_buttons();
         await load_conversations();
         regenerate_button.classList.remove("regenerate-hidden");
     }
@@ -1035,6 +1077,18 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
            }
        }
        const ignored = Array.from(settings.querySelectorAll("input.provider:not(:checked)")).map((el)=>el.value);
+        let extra_parameters = {};
+        document.getElementById(`${provider}-form`)?.querySelectorAll(".saved input, .saved textarea").forEach(async (el) => {
+            let value = el.type == "checkbox" ? el.checked : el.value;
+            extra_parameters[el.name] = value;
+            if (el.type == "textarea") {
+                try {
+                    extra_parameters[el.name] = await JSON.parse(value);
+                } catch (e) {
+                }
+            }
+        });
+        console.log(extra_parameters);
        await api("conversation", {
            id: message_id,
            conversation_id: window.conversation_id,
@@ -1048,6 +1102,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
            api_key: api_key,
            api_base: api_base,
            ignored: ignored,
+            ...extra_parameters
        }, Object.values(image_storage), message_id, scroll, finish_message);
    } catch (e) {
        console.error(e);

View file

@@ -89,25 +89,17 @@ class Api:
         ensure_images_dir()
         return send_from_directory(os.path.abspath(images_dir), name)

-    def _prepare_conversation_kwargs(self, json_data: dict, kwargs: dict):
+    def _prepare_conversation_kwargs(self, json_data: dict):
+        kwargs = {**json_data}
         model = json_data.get('model')
         provider = json_data.get('provider')
         messages = json_data.get('messages')
-        api_key = json_data.get("api_key")
-        if api_key:
-            kwargs["api_key"] = api_key
-        api_base = json_data.get("api_base")
-        if api_base:
-            kwargs["api_base"] = api_base
         kwargs["tool_calls"] = [{
             "function": {
                 "name": "bucket_tool"
             },
             "type": "function"
         }]
-        web_search = json_data.get('web_search')
-        if web_search:
-            kwargs["web_search"] = web_search
         action = json_data.get('action')
         if action == "continue":
             kwargs["tool_calls"].append({
@@ -117,19 +109,13 @@ class Api:
                 "type": "function"
             })
         conversation = json_data.get("conversation")
-        if conversation is not None:
+        if isinstance(conversation, dict):
             kwargs["conversation"] = JsonConversation(**conversation)
         else:
             conversation_id = json_data.get("conversation_id")
             if conversation_id and provider:
                 if provider in conversations and conversation_id in conversations[provider]:
                     kwargs["conversation"] = conversations[provider][conversation_id]
-        if json_data.get("ignored"):
-            kwargs["ignored"] = json_data["ignored"]
-        if json_data.get("action"):
-            kwargs["action"] = json_data["action"]
         return {
             "model": model,
             "provider": provider,

View file

@@ -106,17 +106,16 @@ class Backend_Api(Api):
         Returns:
             Response: A Flask response object for streaming.
         """
-        kwargs = {}
+        if "json" in request.form:
+            json_data = json.loads(request.form['json'])
+        else:
+            json_data = request.json
         if "files" in request.files:
             images = []
             for file in request.files.getlist('files'):
                 if file.filename != '' and is_allowed_extension(file.filename):
                     images.append((to_image(file.stream, file.filename.endswith('.svg')), file.filename))
-            kwargs['images'] = images
-        if "json" in request.form:
-            json_data = json.loads(request.form['json'])
-        else:
-            json_data = request.json
+            json_data['images'] = images

         if app.demo and not json_data.get("provider"):
             model = json_data.get("model")
@@ -126,9 +125,7 @@ class Backend_Api(Api):
             if not model or model == "default":
                 json_data["model"] = models.demo_models["default"][0].name
                 json_data["provider"] = random.choice(models.demo_models["default"][1])
-        if "images" in json_data:
-            kwargs["images"] = json_data["images"]
-        kwargs = self._prepare_conversation_kwargs(json_data, kwargs)
+        kwargs = self._prepare_conversation_kwargs(json_data)
         return self.app.response_class(
             self._create_response_stream(
                 kwargs,
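For context, a hedged sketch of the request shape the reworked handler accepts: the JSON options ride in a "json" form field alongside uploaded files, which now land directly in json_data (the endpoint path is an assumption based on the GUI backend):

import json
import requests

payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "Describe this image"}],
}
with open("photo.jpg", "rb") as f:
    requests.post(
        "http://localhost:8080/backend-api/v2/conversation",  # assumed route
        data={"json": json.dumps(payload)},   # parsed by json.loads above
        files={"files": f},                   # merged into json_data["images"]
    )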

View file

@@ -21,12 +21,12 @@ from .api import Api

 class JsApi(Api):
-    def get_conversation(self, options: dict, message_id: str = None, scroll: bool = None, **kwargs) -> Iterator:
+    def get_conversation(self, options: dict, message_id: str = None, scroll: bool = None) -> Iterator:
         window = webview.windows[0]
         if hasattr(self, "image") and self.image is not None:
-            kwargs["image"] = open(self.image, "rb")
+            options["image"] = open(self.image, "rb")
         for message in self._create_response_stream(
-            self._prepare_conversation_kwargs(options, kwargs),
+            self._prepare_conversation_kwargs(options),
             options.get("conversation_id"),
             options.get('provider')
         ):

View file

@@ -34,7 +34,7 @@ SAFE_PARAMETERS = [
     "api_key", "api_base", "seed", "width", "height",
     "proof_token", "max_retries", "web_search",
     "guidance_scale", "num_inference_steps", "randomize_seed",
-    "safe", "enhance", "private",
+    "safe", "enhance", "private", "aspect_ratio", "images_num",
 ]

 BASIC_PARAMETERS = {

View file

@@ -3,11 +3,10 @@ from __future__ import annotations

 import json

 from ..typing import AsyncResult, Messages, ImagesType
-from ..providers.asyncio import to_async_iterator
 from ..client.service import get_model_and_provider
 from ..client.helper import filter_json
 from .base_provider import AsyncGeneratorProvider
-from .response import ToolCalls, FinishReason
+from .response import ToolCalls, FinishReason, Usage

 class ToolSupportProvider(AsyncGeneratorProvider):
     working = True
@@ -45,6 +44,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
         finish = None
         chunks = []
+        has_usage = False

         async for chunk in provider.get_async_create_function()(
             model,
             messages,
@@ -53,14 +53,20 @@ class ToolSupportProvider(AsyncGeneratorProvider):
             response_format=response_format,
             **kwargs
         ):
-            if isinstance(chunk, FinishReason):
+            if isinstance(chunk, str):
+                chunks.append(chunk)
+            elif isinstance(chunk, Usage):
+                yield chunk
+                has_usage = True
+            elif isinstance(chunk, FinishReason):
                 finish = chunk
                 break
-            elif isinstance(chunk, str):
-                chunks.append(chunk)
             else:
                 yield chunk

+        if not has_usage:
+            yield Usage(completion_tokens=len(chunks), total_tokens=len(chunks))
+
         chunks = "".join(chunks)
         if tools is not None:
             yield ToolCalls([{
@@ -72,5 +78,6 @@ class ToolSupportProvider(AsyncGeneratorProvider):
                 }
             }])
         yield chunks
+
         if finish is not None:
             yield finish
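A hedged sketch of the wrapper in use: when the downstream provider emits no Usage, the wrapper now reports a chunk-count approximation, and with tools or a JSON response_format it converts the final text into a ToolCalls chunk. It assumes the wrapper splits the "Provider:model" string from the docs example above via get_model_and_provider:

import g4f
from g4f.providers.tool_support import ToolSupportProvider  # assumed import path

response = g4f.ChatCompletion.create(
    model="OpenaiChat:gpt-4o",          # provider and model, split by the wrapper
    provider=ToolSupportProvider,
    messages=[{"role": "user", "content": "Name the windy city as JSON."}],
    response_format={"type": "json"},   # final text is parsed with filter_json
)
print(response)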

View file

@@ -59,7 +59,7 @@ class ToolHandler:
     def process_continue_tool(messages: Messages, tool: dict, provider: Any) -> Tuple[Messages, Dict[str, Any]]:
         """Process continue tool requests"""
         kwargs = {}
-        if provider not in ("OpenaiAccount", "HuggingFace"):
+        if provider not in ("OpenaiAccount", "HuggingFaceAPI"):
             messages = messages.copy()
             last_line = messages[-1]["content"].strip().splitlines()[-1]
             content = f"Carry on from this point:\n{last_line}"
@@ -85,12 +85,10 @@ class ToolHandler:
                 has_bucket = True
             message["content"] = new_message_content

-        if has_bucket and isinstance(messages[-1]["content"], str):
-            if "\nSource: " in messages[-1]["content"]:
-                if isinstance(messages[-1]["content"], dict):
-                    messages[-1]["content"]["content"] += BUCKET_INSTRUCTIONS
-                else:
-                    messages[-1]["content"] += BUCKET_INSTRUCTIONS
+        last_message_content = messages[-1]["content"]
+        if has_bucket and isinstance(last_message_content, str):
+            if "\nSource: " in last_message_content:
+                messages[-1]["content"] = last_message_content + BUCKET_INSTRUCTIONS

         return messages
@@ -309,9 +307,10 @@ def iter_run_tools(
             if new_message_content != message["content"]:
                 has_bucket = True
                 message["content"] = new_message_content
-    if has_bucket and isinstance(messages[-1]["content"], str):
-        if "\nSource: " in messages[-1]["content"]:
-            messages[-1]["content"] = messages[-1]["content"]["content"] + BUCKET_INSTRUCTIONS
+    last_message = messages[-1]["content"]
+    if has_bucket and isinstance(last_message, str):
+        if "\nSource: " in last_message:
+            messages[-1]["content"] = last_message + BUCKET_INSTRUCTIONS

     # Process response chunks
     thinking_start_time = 0