Add ARTA image provider

Add ToolSupport in PollinationsAI provider
Add default value for model in chat completions
Add Streaming Support for PollinationsAI provider
This commit is contained in:
hlohaus 2025-03-11 02:49:24 +01:00
parent ad59df3011
commit 3e7af90949
14 changed files with 353 additions and 103 deletions

View file

@ -109,15 +109,18 @@ This example shows how to initialize an agent with a specific model (`gpt-4o`) a
from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai.models import ModelSettings
from g4f.integration.pydantic_ai import patch_infer_model
from g4f.integration.pydantic_ai import AIModel
from g4f.Provider import PollinationsAI
patch_infer_model("your_api_key")
class MyModel(BaseModel):
city: str
country: str
agent = Agent('g4f:Groq:llama3-70b-8192', result_type=MyModel, model_settings=ModelSettings(temperature=0))
nt = Agent(AIModel(
"gpt-4o", # Specify the provider and model
PollinationsAI # Use a supported provider to handle tool-based response formatting
), result_type=MyModel, model_settings=ModelSettings(temperature=0))
if __name__ == '__main__':
result = agent.run_sync('The windy city in the US of A.')
@ -152,7 +155,7 @@ class MyModel(BaseModel):
# Create the agent for a model with tool support (using one tool)
agent = Agent(AIModel(
"PollinationsAI:openai", # Specify the provider and model
"OpenaiChat:gpt-4o", # Specify the provider and model
ToolSupportProvider # Use ToolSupportProvider to handle tool-based response formatting
), result_type=MyModel, model_settings=ModelSettings(temperature=0))

189
g4f/Provider/ARTA.py Normal file
View file

@ -0,0 +1,189 @@
from __future__ import annotations
import os
import time
import json
from pathlib import Path
from aiohttp import ClientSession
import asyncio
from ..typing import AsyncResult, Messages
from ..providers.response import ImageResponse, Reasoning
from ..errors import ResponseError
from ..cookies import get_cookies_dir
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_image_prompt
class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://img-gen-prod.ai-arta.com"
auth_url = "https://www.googleapis.com/identitytoolkit/v3/relyingparty/signupNewUser?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ"
token_refresh_url = "https://securetoken.googleapis.com/v1/token?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ"
image_generation_url = "https://img-gen-prod.ai-arta.com/api/v1/text2image"
status_check_url = "https://img-gen-prod.ai-arta.com/api/v1/text2image/{record_id}/status"
working = True
default_model = "Flux"
default_image_model = default_model
model_aliases = {
"flux": "Flux",
"medieval": "Medieval",
"vincent_van_gogh": "Vincent Van Gogh",
"f_dev": "F Dev",
"low_poly": "Low Poly",
"dreamshaper_xl": "Dreamshaper-xl",
"anima_pencil_xl": "Anima-pencil-xl",
"biomech": "Biomech",
"trash_polka": "Trash Polka",
"no_style": "No Style",
"cheyenne_xl": "Cheyenne-xl",
"chicano": "Chicano",
"embroidery_tattoo": "Embroidery tattoo",
"red_and_black": "Red and Black",
"fantasy_art": "Fantasy Art",
"watercolor": "Watercolor",
"dotwork": "Dotwork",
"old_school_colored": "Old school colored",
"realistic_tattoo": "Realistic tattoo",
"japanese_2": "Japanese_2",
"realistic_stock_xl": "Realistic-stock-xl",
"f_pro": "F Pro",
"revanimated": "RevAnimated",
"katayama_mix_xl": "Katayama-mix-xl",
"sdxl_l": "SDXL L",
"cor_epica_xl": "Cor-epica-xl",
"anime_tattoo": "Anime tattoo",
"new_school": "New School",
"death_metal": "Death metal",
"old_school": "Old School",
"juggernaut_xl": "Juggernaut-xl",
"photographic": "Photographic",
"sdxl_1_0": "SDXL 1.0",
"graffiti": "Graffiti",
"mini_tattoo": "Mini tattoo",
"surrealism": "Surrealism",
"neo_traditional": "Neo-traditional",
"on_limbs_black": "On limbs black",
"yamers_realistic_xl": "Yamers-realistic-xl",
"pony_xl": "Pony-xl",
"playground_xl": "Playground-xl",
"anything_xl": "Anything-xl",
"flame_design": "Flame design",
"kawaii": "Kawaii",
"cinematic_art": "Cinematic Art",
"professional": "Professional",
"flux_black_ink": "Flux Black Ink"
}
image_models = [*model_aliases.keys()]
models = image_models
@classmethod
def get_auth_file(cls):
path = Path(get_cookies_dir())
path.mkdir(exist_ok=True)
filename = f"auth_{cls.__name__}.json"
return path / filename
@classmethod
async def create_token(cls, path: Path, proxy: str | None = None):
async with ClientSession() as session:
# Step 1: Generate Authentication Token
auth_payload = {"clientType": "CLIENT_TYPE_ANDROID"}
async with session.post(cls.auth_url, json=auth_payload, proxy=proxy) as auth_response:
auth_data = await auth_response.json()
auth_token = auth_data.get("idToken")
#refresh_token = auth_data.get("refreshToken")
if not auth_token:
raise ResponseError("Failed to obtain authentication token.")
json.dump(auth_data, path.open("w"))
return auth_data
@classmethod
async def refresh_token(cls, refresh_token: str, proxy: str = None) -> tuple[str, str]:
async with ClientSession() as session:
payload = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
}
async with session.post(cls.token_refresh_url, data=payload, proxy=proxy) as response:
response_data = await response.json()
return response_data.get("id_token"), response_data.get("refresh_token")
@classmethod
async def read_and_refresh_token(cls, proxy: str | None = None) -> str:
path = cls.get_auth_file()
if path.is_file():
auth_data = json.load(path.open("rb"))
diff = time.time() - os.path.getmtime(path)
expiresIn = int(auth_data.get("expiresIn"))
if diff < expiresIn:
if diff > expiresIn / 2:
auth_data["idToken"], auth_data["refreshToken"] = await cls.refresh_token(auth_data.get("refreshToken"), proxy)
json.dump(auth_data, path.open("w"))
return auth_data
return await cls.create_token(path, proxy)
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
proxy: str = None,
prompt: str = None,
negative_prompt: str = "blurry, deformed hands, ugly",
images_num: int = 1,
guidance_scale: int = 7,
num_inference_steps: int = 30,
aspect_ratio: str = "1:1",
**kwargs
) -> AsyncResult:
model = cls.get_model(model)
prompt = format_image_prompt(messages, prompt)
# Step 1: Get Authentication Token
auth_data = await cls.read_and_refresh_token(proxy)
async with ClientSession() as session:
# Step 2: Generate Images
image_payload = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"style": model,
"images_num": str(images_num),
"cfg_scale": str(guidance_scale),
"steps": str(num_inference_steps),
"aspect_ratio": aspect_ratio,
}
headers = {
"Authorization": auth_data.get("idToken"),
}
async with session.post(cls.image_generation_url, data=image_payload, headers=headers, proxy=proxy) as image_response:
image_data = await image_response.json()
record_id = image_data.get("record_id")
if not record_id:
raise ResponseError(f"Failed to initiate image generation: {image_data}")
# Step 3: Check Generation Status
status_url = cls.status_check_url.format(record_id=record_id)
counter = 0
while True:
async with session.get(status_url, headers=headers, proxy=proxy) as status_response:
status_data = await status_response.json()
status = status_data.get("status")
if status == "DONE":
image_urls = [image["url"] for image in status_data.get("response", [])]
yield Reasoning(status="Finished")
yield ImageResponse(images=image_urls, alt=prompt)
return
elif status in ("IN_QUEUE", "IN_PROGRESS"):
yield Reasoning(status=("Waiting" if status == "IN_QUEUE" else "Generating") + "." * counter)
await asyncio.sleep(5) # Poll every 5 seconds
counter += 1
if counter > 3:
counter = 0
else:
raise ResponseError(f"Image generation failed with status: {status}")

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import json
import random
import requests
from urllib.parse import quote_plus
@ -13,7 +14,7 @@ from ..image import to_data_uri
from ..errors import ModelNotFoundError
from ..requests.raise_for_status import raise_for_status
from ..requests.aiohttp import get_connector
from ..providers.response import ImageResponse, ImagePreview, FinishReason, Usage, Audio
from ..providers.response import ImageResponse, ImagePreview, FinishReason, Usage, Audio, ToolCalls
from .. import debug
DEFAULT_HEADERS = {
@ -52,7 +53,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
text_models = [default_model]
image_models = [default_image_model]
extra_image_models = ["flux-pro", "flux-dev", "flux-schnell", "midjourney", "dall-e-3"]
vision_models = [default_vision_model, "gpt-4o-mini", "o1-mini"]
vision_models = [default_vision_model, "gpt-4o-mini", "o1-mini", "openai", "openai-large"]
extra_text_models = vision_models
_models_loaded = False
model_aliases = {
@ -138,6 +139,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
cls,
model: str,
messages: Messages,
stream: bool = False,
proxy: str = None,
prompt: str = None,
width: int = 1024,
@ -154,6 +156,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
frequency_penalty: float = None,
response_format: Optional[dict] = None,
cache: bool = False,
extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias"],
**kwargs
) -> AsyncResult:
cls.get_models()
@ -193,6 +196,9 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
response_format=response_format,
seed=seed,
cache=cache,
stream=stream,
extra_parameters=extra_parameters,
**kwargs
):
yield result
@ -246,7 +252,10 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
frequency_penalty: float,
response_format: Optional[dict],
seed: Optional[int],
cache: bool
cache: bool,
stream: bool,
extra_parameters: list[str],
**kwargs
) -> AsyncResult:
if not cache and seed is None:
seed = random.randint(9999, 99999999)
@ -267,6 +276,13 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
messages[-1] = last_message
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
if model in cls.audio_models or stream:
#data["voice"] = random.choice(cls.audio_models[model])
url = cls.text_api_endpoint
stream = False
else:
url = cls.openai_endpoint
extra_parameters = {param: kwargs[param] for param in extra_parameters if param in kwargs}
data = filter_none(**{
"messages": messages,
"model": model,
@ -275,17 +291,11 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
"top_p": top_p,
"frequency_penalty": frequency_penalty,
"jsonMode": json_mode,
"stream": False,
"stream": stream,
"seed": seed,
"cache": cache
"cache": cache,
**extra_parameters
})
if "gemini" in model:
data.pop("seed")
if model in cls.audio_models:
#data["voice"] = random.choice(cls.audio_models[model])
url = cls.text_api_endpoint
else:
url = cls.openai_endpoint
async with session.post(url, json=data) as response:
await raise_for_status(response)
if response.headers["content-type"] == "audio/mpeg":
@ -294,11 +304,31 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
elif response.headers["content-type"].startswith("text/plain"):
yield await response.text()
return
elif response.headers["content-type"].startswith("text/event-stream"):
async for line in response.content:
if line.startswith(b"data: "):
if line[6:].startswith(b"[DONE]"):
break
result = json.loads(line[6:])
choice = result.get("choices", [{}])[0]
content = choice.get("delta", {}).get("content")
if content:
yield content
if "usage" in result:
yield Usage(**result["usage"])
finish_reason = choice.get("finish_reason")
if finish_reason:
yield FinishReason(finish_reason)
return
result = await response.json()
choice = result["choices"][0]
message = choice.get("message", {})
content = message.get("content", "")
if "tool_calls" in message:
yield ToolCalls(message["tool_calls"])
if content is not None:
if "</think>" in content and "<think>" not in content:
yield "<think>"

View file

@ -15,6 +15,7 @@ from .mini_max import HailuoAI, MiniMax
from .template import OpenaiTemplate, BackendApi
from .AllenAI import AllenAI
from .ARTA import ARTA
from .Blackbox import Blackbox
from .ChatGLM import ChatGLM
from .ChatGpt import ChatGpt

View file

@ -623,8 +623,9 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
page.add_handler(nodriver.cdp.network.RequestWillBeSent, on_request)
page = await browser.get(cls.url)
user_agent = await page.evaluate("window.navigator.userAgent")
await page.select("textarea.text-token-text-primary", 240)
await page.evaluate("document.querySelector('textarea.text-token-text-primary').value = 'Hello'")
await page.select("#prompt-textarea", 240)
await page.evaluate("document.getElementById('prompt-textarea').innerText = 'Hello'")
await page.select("[data-testid=\"send-button\"]", 30)
await page.evaluate("document.querySelector('[data-testid=\"send-button\"]').click()")
while True:
body = await page.evaluate("JSON.stringify(window.__remixContext)")

View file

@ -276,7 +276,7 @@ class Completions:
def create(
self,
messages: Messages,
model: str,
model: str = "",
provider: Optional[ProviderType] = None,
stream: Optional[bool] = False,
proxy: Optional[str] = None,
@ -330,7 +330,7 @@ class Completions:
def stream(
self,
messages: Messages,
model: str,
model: str = "",
**kwargs
) -> IterResponse:
return self.create(messages, model, stream=True, **kwargs)
@ -564,7 +564,7 @@ class AsyncCompletions:
def create(
self,
messages: Messages,
model: str,
model: str = "",
provider: Optional[ProviderType] = None,
stream: Optional[bool] = False,
proxy: Optional[str] = None,
@ -619,7 +619,7 @@ class AsyncCompletions:
def stream(
self,
messages: Messages,
model: str,
model: str = "",
**kwargs
) -> AsyncIterator[ChatCompletionChunk]:
return self.create(messages, model, stream=True, **kwargs)

View file

@ -112,7 +112,7 @@ body:not(.white) a:visited{
.new_version {
position: absolute;
right: 0;
left: 0;
top: 0;
padding: 10px;
font-weight: 500;
@ -143,6 +143,7 @@ body:not(.white) a:visited{
.conversation {
width: 100%;
height: 100%;
display: flex;
flex-direction: column;
gap: 5px;
@ -238,8 +239,7 @@ body:not(.white) a:visited{
#close_provider_forms {
max-width: 210px;
margin-left: auto;
margin-right: 8px;
margin-left: 12px;
margin-top: 12px;
}
@ -1584,19 +1584,6 @@ form .field.saved .fa-xmark {
}
}
/* Basic adaptation */
.row {
flex-wrap: wrap;
gap: 10px;
}
.conversations, .settings, .conversation {
flex: 1 1 300px;
min-width: 0;
height: 100%;
}
/* Media queries for mobile devices */
@media (max-width: 768px) {
.row {
@ -1608,11 +1595,6 @@ form .field.saved .fa-xmark {
max-width: 100%;
margin: 0;
}
.conversation {
order: -1;
min-height: 80vh;
}
}
@media (max-width: 480px) {

View file

@ -259,6 +259,10 @@ function register_message_images() {
const register_message_buttons = async () => {
message_box.querySelectorAll(".message .content .provider").forEach(async (el) => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
const provider_forms = document.querySelector(".provider_forms");
const provider_form = provider_forms.querySelector(`#${el.dataset.provider}-form`);
const provider_link = el.querySelector("a");
@ -279,6 +283,10 @@ const register_message_buttons = async () => {
});
message_box.querySelectorAll(".message .fa-xmark").forEach(async (el) => el.addEventListener("click", async () => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
const message_el = get_message_el(el);
await remove_message(window.conversation_id, message_el.dataset.index);
message_el.remove();
@ -286,6 +294,10 @@ const register_message_buttons = async () => {
}));
message_box.querySelectorAll(".message .fa-clipboard").forEach(async (el) => el.addEventListener("click", async () => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
let message_el = get_message_el(el);
let response = await fetch(message_el.dataset.object_url);
let copyText = await response.text();
@ -304,6 +316,10 @@ const register_message_buttons = async () => {
}))
message_box.querySelectorAll(".message .fa-file-export").forEach(async (el) => el.addEventListener("click", async () => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
const elem = window.document.createElement('a');
let filename = `chat ${new Date().toLocaleString()}.txt`.replaceAll(":", "-");
const conversation = await get_conversation(window.conversation_id);
@ -323,6 +339,10 @@ const register_message_buttons = async () => {
}))
message_box.querySelectorAll(".message .fa-volume-high").forEach(async (el) => el.addEventListener("click", async () => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
const message_el = get_message_el(el);
let audio;
if (message_el.dataset.synthesize_url) {
@ -344,6 +364,10 @@ const register_message_buttons = async () => {
}));
message_box.querySelectorAll(".message .regenerate_button").forEach(async (el) => el.addEventListener("click", async () => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
const message_el = get_message_el(el);
el.classList.add("clicked");
setTimeout(() => el.classList.remove("clicked"), 1000);
@ -351,6 +375,10 @@ const register_message_buttons = async () => {
}));
message_box.querySelectorAll(".message .continue_button").forEach(async (el) => el.addEventListener("click", async () => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
if (!el.disabled) {
el.disabled = true;
const message_el = get_message_el(el);
@ -361,11 +389,19 @@ const register_message_buttons = async () => {
));
message_box.querySelectorAll(".message .fa-whatsapp").forEach(async (el) => el.addEventListener("click", async () => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
const text = get_message_el(el).innerText;
window.open(`https://wa.me/?text=${encodeURIComponent(text)}`, '_blank');
}));
message_box.querySelectorAll(".message .fa-print").forEach(async (el) => el.addEventListener("click", async () => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
const message_el = get_message_el(el);
el.classList.add("clicked");
message_box.scrollTop = 0;
@ -378,6 +414,10 @@ const register_message_buttons = async () => {
}));
message_box.querySelectorAll(".message .reasoning_title").forEach(async (el) => el.addEventListener("click", async () => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
let text_el = el.parentElement.querySelector(".reasoning_text");
if (text_el) {
text_el.classList[text_el.classList.contains("hidden") ? "remove" : "add"]("hidden");
@ -569,9 +609,9 @@ const prepare_messages = (messages, message_index = -1, do_continue = false, do_
}
// Remove history, only add new user messages
let filtered_messages = [];
// The message_index is null on count total tokens
if (document.getElementById('history')?.checked && do_filter && message_index != null) {
if (!do_continue && document.getElementById('history')?.checked && do_filter && message_index != null) {
let filtered_messages = [];
while (last_message = messages.pop()) {
if (last_message["role"] == "user") {
filtered_messages.push(last_message);
@ -630,9 +670,9 @@ async function load_provider_parameters(provider) {
form_el.id = form_id;
form_el.classList.add("hidden");
appStorage.setItem(form_el.id, JSON.stringify(parameters_storage[provider]));
let old_form = message_box.querySelector(`#${provider}-form`);
let old_form = document.getElementById(form_id);
if (old_form) {
provider_forms.removeChild(old_form);
old_form.remove();
}
Object.entries(parameters_storage[provider]).forEach(([key, value]) => {
let el_id = `${provider}-${key}`;
@ -649,7 +689,7 @@ async function load_provider_parameters(provider) {
saved_value = value;
}
field_el.innerHTML = `<span class="label">${key}:</span>
<input type="checkbox" id="${el_id}" name="${provider}[${key}]">
<input type="checkbox" id="${el_id}" name="${key}">
<label for="${el_id}" class="toogle" title=""></label>
<i class="fa-solid fa-xmark"></i>`;
form_el.appendChild(field_el);
@ -679,15 +719,15 @@ async function load_provider_parameters(provider) {
placeholder = value == null ? "null" : value;
}
field_el.innerHTML = `<label for="${el_id}" title="">${key}:</label>`;
if (Number.isInteger(value) && value != 1) {
max = value >= 4096 ? 8192 : 4096;
field_el.innerHTML += `<input type="range" id="${el_id}" name="${provider}[${key}]" value="${escapeHtml(value)}" class="slider" min="0" max="${max}" step="1"/><output>${escapeHtml(value)}</output>`;
if (Number.isInteger(value)) {
max = value == 42 || value >= 4096 ? 8192 : value >= 100 ? 4096 : value == 1 ? 10 : 100;
field_el.innerHTML += `<input type="range" id="${el_id}" name="${key}" value="${escapeHtml(value)}" class="slider" min="0" max="${max}" step="1"/><output>${escapeHtml(value)}</output>`;
field_el.innerHTML += `<i class="fa-solid fa-xmark"></i>`;
} else if (typeof value == "number") {
field_el.innerHTML += `<input type="range" id="${el_id}" name="${provider}[${key}]" value="${escapeHtml(value)}" class="slider" min="0" max="2" step="0.1"/><output>${escapeHtml(value)}</output>`;
field_el.innerHTML += `<input type="range" id="${el_id}" name="${key}" value="${escapeHtml(value)}" class="slider" min="0" max="2" step="0.1"/><output>${escapeHtml(value)}</output>`;
field_el.innerHTML += `<i class="fa-solid fa-xmark"></i>`;
} else {
field_el.innerHTML += `<textarea id="${el_id}" name="${provider}[${key}]"></textarea>`;
field_el.innerHTML += `<textarea id="${el_id}" name="${key}"></textarea>`;
field_el.innerHTML += `<i class="fa-solid fa-xmark"></i>`;
input_el = field_el.querySelector("textarea");
if (value != null) {
@ -723,6 +763,7 @@ async function load_provider_parameters(provider) {
input_el = field_el.querySelector("input");
input_el.dataset.value = value;
input_el.value = saved_value;
input_el.nextElementSibling.value = input_el.value;
input_el.oninput = () => {
input_el.nextElementSibling.value = input_el.value;
field_el.classList.add("saved");
@ -1008,6 +1049,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
}
await safe_remove_cancel_button();
await register_message_images();
await register_message_buttons();
await load_conversations();
regenerate_button.classList.remove("regenerate-hidden");
}
@ -1035,6 +1077,18 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
}
}
const ignored = Array.from(settings.querySelectorAll("input.provider:not(:checked)")).map((el)=>el.value);
let extra_parameters = {};
document.getElementById(`${provider}-form`)?.querySelectorAll(".saved input, .saved textarea").forEach(async (el) => {
let value = el.type == "checkbox" ? el.checked : el.value;
extra_parameters[el.name] = value;
if (el.type == "textarea") {
try {
extra_parameters[el.name] = await JSON.parse(value);
} catch (e) {
}
}
});
console.log(extra_parameters);
await api("conversation", {
id: message_id,
conversation_id: window.conversation_id,
@ -1048,6 +1102,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
api_key: api_key,
api_base: api_base,
ignored: ignored,
...extra_parameters
}, Object.values(image_storage), message_id, scroll, finish_message);
} catch (e) {
console.error(e);

View file

@ -89,25 +89,17 @@ class Api:
ensure_images_dir()
return send_from_directory(os.path.abspath(images_dir), name)
def _prepare_conversation_kwargs(self, json_data: dict, kwargs: dict):
def _prepare_conversation_kwargs(self, json_data: dict):
kwargs = {**json_data}
model = json_data.get('model')
provider = json_data.get('provider')
messages = json_data.get('messages')
api_key = json_data.get("api_key")
if api_key:
kwargs["api_key"] = api_key
api_base = json_data.get("api_base")
if api_base:
kwargs["api_base"] = api_base
kwargs["tool_calls"] = [{
"function": {
"name": "bucket_tool"
},
"type": "function"
}]
web_search = json_data.get('web_search')
if web_search:
kwargs["web_search"] = web_search
action = json_data.get('action')
if action == "continue":
kwargs["tool_calls"].append({
@ -117,19 +109,13 @@ class Api:
"type": "function"
})
conversation = json_data.get("conversation")
if conversation is not None:
if isinstance(conversation, dict):
kwargs["conversation"] = JsonConversation(**conversation)
else:
conversation_id = json_data.get("conversation_id")
if conversation_id and provider:
if provider in conversations and conversation_id in conversations[provider]:
kwargs["conversation"] = conversations[provider][conversation_id]
if json_data.get("ignored"):
kwargs["ignored"] = json_data["ignored"]
if json_data.get("action"):
kwargs["action"] = json_data["action"]
return {
"model": model,
"provider": provider,

View file

@ -106,17 +106,16 @@ class Backend_Api(Api):
Returns:
Response: A Flask response object for streaming.
"""
kwargs = {}
if "json" in request.form:
json_data = json.loads(request.form['json'])
else:
json_data = request.json
if "files" in request.files:
images = []
for file in request.files.getlist('files'):
if file.filename != '' and is_allowed_extension(file.filename):
images.append((to_image(file.stream, file.filename.endswith('.svg')), file.filename))
kwargs['images'] = images
if "json" in request.form:
json_data = json.loads(request.form['json'])
else:
json_data = request.json
json_data['images'] = images
if app.demo and not json_data.get("provider"):
model = json_data.get("model")
@ -126,9 +125,7 @@ class Backend_Api(Api):
if not model or model == "default":
json_data["model"] = models.demo_models["default"][0].name
json_data["provider"] = random.choice(models.demo_models["default"][1])
if "images" in json_data:
kwargs["images"] = json_data["images"]
kwargs = self._prepare_conversation_kwargs(json_data, kwargs)
kwargs = self._prepare_conversation_kwargs(json_data)
return self.app.response_class(
self._create_response_stream(
kwargs,

View file

@ -21,12 +21,12 @@ from .api import Api
class JsApi(Api):
def get_conversation(self, options: dict, message_id: str = None, scroll: bool = None, **kwargs) -> Iterator:
def get_conversation(self, options: dict, message_id: str = None, scroll: bool = None) -> Iterator:
window = webview.windows[0]
if hasattr(self, "image") and self.image is not None:
kwargs["image"] = open(self.image, "rb")
options["image"] = open(self.image, "rb")
for message in self._create_response_stream(
self._prepare_conversation_kwargs(options, kwargs),
self._prepare_conversation_kwargs(options),
options.get("conversation_id"),
options.get('provider')
):

View file

@ -34,7 +34,7 @@ SAFE_PARAMETERS = [
"api_key", "api_base", "seed", "width", "height",
"proof_token", "max_retries", "web_search",
"guidance_scale", "num_inference_steps", "randomize_seed",
"safe", "enhance", "private",
"safe", "enhance", "private", "aspect_ratio", "images_num",
]
BASIC_PARAMETERS = {

View file

@ -3,11 +3,10 @@ from __future__ import annotations
import json
from ..typing import AsyncResult, Messages, ImagesType
from ..providers.asyncio import to_async_iterator
from ..client.service import get_model_and_provider
from ..client.helper import filter_json
from .base_provider import AsyncGeneratorProvider
from .response import ToolCalls, FinishReason
from .response import ToolCalls, FinishReason, Usage
class ToolSupportProvider(AsyncGeneratorProvider):
working = True
@ -45,6 +44,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
finish = None
chunks = []
has_usage = False
async for chunk in provider.get_async_create_function()(
model,
messages,
@ -53,14 +53,20 @@ class ToolSupportProvider(AsyncGeneratorProvider):
response_format=response_format,
**kwargs
):
if isinstance(chunk, FinishReason):
if isinstance(chunk, str):
chunks.append(chunk)
elif isinstance(chunk, Usage):
yield chunk
has_usage = True
elif isinstance(chunk, FinishReason):
finish = chunk
break
elif isinstance(chunk, str):
chunks.append(chunk)
else:
yield chunk
if not has_usage:
yield Usage(completion_tokens=len(chunks), total_tokens=len(chunks))
chunks = "".join(chunks)
if tools is not None:
yield ToolCalls([{
@ -72,5 +78,6 @@ class ToolSupportProvider(AsyncGeneratorProvider):
}
}])
yield chunks
if finish is not None:
yield finish

View file

@ -59,7 +59,7 @@ class ToolHandler:
def process_continue_tool(messages: Messages, tool: dict, provider: Any) -> Tuple[Messages, Dict[str, Any]]:
"""Process continue tool requests"""
kwargs = {}
if provider not in ("OpenaiAccount", "HuggingFace"):
if provider not in ("OpenaiAccount", "HuggingFaceAPI"):
messages = messages.copy()
last_line = messages[-1]["content"].strip().splitlines()[-1]
content = f"Carry on from this point:\n{last_line}"
@ -85,12 +85,10 @@ class ToolHandler:
has_bucket = True
message["content"] = new_message_content
if has_bucket and isinstance(messages[-1]["content"], str):
if "\nSource: " in messages[-1]["content"]:
if isinstance(messages[-1]["content"], dict):
messages[-1]["content"]["content"] += BUCKET_INSTRUCTIONS
else:
messages[-1]["content"] += BUCKET_INSTRUCTIONS
last_message_content = messages[-1]["content"]
if has_bucket and isinstance(last_message_content, str):
if "\nSource: " in last_message_content:
messages[-1]["content"] = last_message_content + BUCKET_INSTRUCTIONS
return messages
@ -309,9 +307,10 @@ def iter_run_tools(
if new_message_content != message["content"]:
has_bucket = True
message["content"] = new_message_content
if has_bucket and isinstance(messages[-1]["content"], str):
if "\nSource: " in messages[-1]["content"]:
messages[-1]["content"] = messages[-1]["content"]["content"] + BUCKET_INSTRUCTIONS
last_message = messages[-1]["content"]
if has_bucket and isinstance(last_message, str):
if "\nSource: " in last_message:
messages[-1]["content"] = last_message + BUCKET_INSTRUCTIONS
# Process response chunks
thinking_start_time = 0