mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
Add ARTA image provider
Add ToolSupport in PollinationsAI provider Add default value for model in chat completions Add Streaming Support for PollinationsAI provider
This commit is contained in:
parent
ad59df3011
commit
3e7af90949
14 changed files with 353 additions and 103 deletions
|
|
@ -109,15 +109,18 @@ This example shows how to initialize an agent with a specific model (`gpt-4o`) a
|
|||
from pydantic import BaseModel
|
||||
from pydantic_ai import Agent
|
||||
from pydantic_ai.models import ModelSettings
|
||||
from g4f.integration.pydantic_ai import patch_infer_model
|
||||
from g4f.integration.pydantic_ai import AIModel
|
||||
from g4f.Provider import PollinationsAI
|
||||
|
||||
patch_infer_model("your_api_key")
|
||||
|
||||
class MyModel(BaseModel):
|
||||
city: str
|
||||
country: str
|
||||
|
||||
agent = Agent('g4f:Groq:llama3-70b-8192', result_type=MyModel, model_settings=ModelSettings(temperature=0))
|
||||
nt = Agent(AIModel(
|
||||
"gpt-4o", # Specify the provider and model
|
||||
PollinationsAI # Use a supported provider to handle tool-based response formatting
|
||||
), result_type=MyModel, model_settings=ModelSettings(temperature=0))
|
||||
|
||||
if __name__ == '__main__':
|
||||
result = agent.run_sync('The windy city in the US of A.')
|
||||
|
|
@ -152,7 +155,7 @@ class MyModel(BaseModel):
|
|||
|
||||
# Create the agent for a model with tool support (using one tool)
|
||||
agent = Agent(AIModel(
|
||||
"PollinationsAI:openai", # Specify the provider and model
|
||||
"OpenaiChat:gpt-4o", # Specify the provider and model
|
||||
ToolSupportProvider # Use ToolSupportProvider to handle tool-based response formatting
|
||||
), result_type=MyModel, model_settings=ModelSettings(temperature=0))
|
||||
|
||||
|
|
|
|||
189
g4f/Provider/ARTA.py
Normal file
189
g4f/Provider/ARTA.py
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
from pathlib import Path
|
||||
from aiohttp import ClientSession
|
||||
import asyncio
|
||||
|
||||
from ..typing import AsyncResult, Messages
|
||||
from ..providers.response import ImageResponse, Reasoning
|
||||
from ..errors import ResponseError
|
||||
from ..cookies import get_cookies_dir
|
||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from .helper import format_image_prompt
|
||||
|
||||
class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
url = "https://img-gen-prod.ai-arta.com"
|
||||
auth_url = "https://www.googleapis.com/identitytoolkit/v3/relyingparty/signupNewUser?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ"
|
||||
token_refresh_url = "https://securetoken.googleapis.com/v1/token?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ"
|
||||
image_generation_url = "https://img-gen-prod.ai-arta.com/api/v1/text2image"
|
||||
status_check_url = "https://img-gen-prod.ai-arta.com/api/v1/text2image/{record_id}/status"
|
||||
|
||||
working = True
|
||||
|
||||
default_model = "Flux"
|
||||
default_image_model = default_model
|
||||
model_aliases = {
|
||||
"flux": "Flux",
|
||||
"medieval": "Medieval",
|
||||
"vincent_van_gogh": "Vincent Van Gogh",
|
||||
"f_dev": "F Dev",
|
||||
"low_poly": "Low Poly",
|
||||
"dreamshaper_xl": "Dreamshaper-xl",
|
||||
"anima_pencil_xl": "Anima-pencil-xl",
|
||||
"biomech": "Biomech",
|
||||
"trash_polka": "Trash Polka",
|
||||
"no_style": "No Style",
|
||||
"cheyenne_xl": "Cheyenne-xl",
|
||||
"chicano": "Chicano",
|
||||
"embroidery_tattoo": "Embroidery tattoo",
|
||||
"red_and_black": "Red and Black",
|
||||
"fantasy_art": "Fantasy Art",
|
||||
"watercolor": "Watercolor",
|
||||
"dotwork": "Dotwork",
|
||||
"old_school_colored": "Old school colored",
|
||||
"realistic_tattoo": "Realistic tattoo",
|
||||
"japanese_2": "Japanese_2",
|
||||
"realistic_stock_xl": "Realistic-stock-xl",
|
||||
"f_pro": "F Pro",
|
||||
"revanimated": "RevAnimated",
|
||||
"katayama_mix_xl": "Katayama-mix-xl",
|
||||
"sdxl_l": "SDXL L",
|
||||
"cor_epica_xl": "Cor-epica-xl",
|
||||
"anime_tattoo": "Anime tattoo",
|
||||
"new_school": "New School",
|
||||
"death_metal": "Death metal",
|
||||
"old_school": "Old School",
|
||||
"juggernaut_xl": "Juggernaut-xl",
|
||||
"photographic": "Photographic",
|
||||
"sdxl_1_0": "SDXL 1.0",
|
||||
"graffiti": "Graffiti",
|
||||
"mini_tattoo": "Mini tattoo",
|
||||
"surrealism": "Surrealism",
|
||||
"neo_traditional": "Neo-traditional",
|
||||
"on_limbs_black": "On limbs black",
|
||||
"yamers_realistic_xl": "Yamers-realistic-xl",
|
||||
"pony_xl": "Pony-xl",
|
||||
"playground_xl": "Playground-xl",
|
||||
"anything_xl": "Anything-xl",
|
||||
"flame_design": "Flame design",
|
||||
"kawaii": "Kawaii",
|
||||
"cinematic_art": "Cinematic Art",
|
||||
"professional": "Professional",
|
||||
"flux_black_ink": "Flux Black Ink"
|
||||
}
|
||||
image_models = [*model_aliases.keys()]
|
||||
models = image_models
|
||||
|
||||
@classmethod
|
||||
def get_auth_file(cls):
|
||||
path = Path(get_cookies_dir())
|
||||
path.mkdir(exist_ok=True)
|
||||
filename = f"auth_{cls.__name__}.json"
|
||||
return path / filename
|
||||
|
||||
@classmethod
|
||||
async def create_token(cls, path: Path, proxy: str | None = None):
|
||||
async with ClientSession() as session:
|
||||
# Step 1: Generate Authentication Token
|
||||
auth_payload = {"clientType": "CLIENT_TYPE_ANDROID"}
|
||||
async with session.post(cls.auth_url, json=auth_payload, proxy=proxy) as auth_response:
|
||||
auth_data = await auth_response.json()
|
||||
auth_token = auth_data.get("idToken")
|
||||
#refresh_token = auth_data.get("refreshToken")
|
||||
if not auth_token:
|
||||
raise ResponseError("Failed to obtain authentication token.")
|
||||
json.dump(auth_data, path.open("w"))
|
||||
return auth_data
|
||||
|
||||
@classmethod
|
||||
async def refresh_token(cls, refresh_token: str, proxy: str = None) -> tuple[str, str]:
|
||||
async with ClientSession() as session:
|
||||
payload = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
async with session.post(cls.token_refresh_url, data=payload, proxy=proxy) as response:
|
||||
response_data = await response.json()
|
||||
return response_data.get("id_token"), response_data.get("refresh_token")
|
||||
|
||||
@classmethod
|
||||
async def read_and_refresh_token(cls, proxy: str | None = None) -> str:
|
||||
path = cls.get_auth_file()
|
||||
if path.is_file():
|
||||
auth_data = json.load(path.open("rb"))
|
||||
diff = time.time() - os.path.getmtime(path)
|
||||
expiresIn = int(auth_data.get("expiresIn"))
|
||||
if diff < expiresIn:
|
||||
if diff > expiresIn / 2:
|
||||
auth_data["idToken"], auth_data["refreshToken"] = await cls.refresh_token(auth_data.get("refreshToken"), proxy)
|
||||
json.dump(auth_data, path.open("w"))
|
||||
return auth_data
|
||||
return await cls.create_token(path, proxy)
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
proxy: str = None,
|
||||
prompt: str = None,
|
||||
negative_prompt: str = "blurry, deformed hands, ugly",
|
||||
images_num: int = 1,
|
||||
guidance_scale: int = 7,
|
||||
num_inference_steps: int = 30,
|
||||
aspect_ratio: str = "1:1",
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
model = cls.get_model(model)
|
||||
prompt = format_image_prompt(messages, prompt)
|
||||
|
||||
# Step 1: Get Authentication Token
|
||||
auth_data = await cls.read_and_refresh_token(proxy)
|
||||
|
||||
async with ClientSession() as session:
|
||||
# Step 2: Generate Images
|
||||
image_payload = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"style": model,
|
||||
"images_num": str(images_num),
|
||||
"cfg_scale": str(guidance_scale),
|
||||
"steps": str(num_inference_steps),
|
||||
"aspect_ratio": aspect_ratio,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": auth_data.get("idToken"),
|
||||
}
|
||||
|
||||
async with session.post(cls.image_generation_url, data=image_payload, headers=headers, proxy=proxy) as image_response:
|
||||
image_data = await image_response.json()
|
||||
record_id = image_data.get("record_id")
|
||||
|
||||
if not record_id:
|
||||
raise ResponseError(f"Failed to initiate image generation: {image_data}")
|
||||
|
||||
# Step 3: Check Generation Status
|
||||
status_url = cls.status_check_url.format(record_id=record_id)
|
||||
counter = 0
|
||||
while True:
|
||||
async with session.get(status_url, headers=headers, proxy=proxy) as status_response:
|
||||
status_data = await status_response.json()
|
||||
status = status_data.get("status")
|
||||
|
||||
if status == "DONE":
|
||||
image_urls = [image["url"] for image in status_data.get("response", [])]
|
||||
yield Reasoning(status="Finished")
|
||||
yield ImageResponse(images=image_urls, alt=prompt)
|
||||
return
|
||||
elif status in ("IN_QUEUE", "IN_PROGRESS"):
|
||||
yield Reasoning(status=("Waiting" if status == "IN_QUEUE" else "Generating") + "." * counter)
|
||||
await asyncio.sleep(5) # Poll every 5 seconds
|
||||
counter += 1
|
||||
if counter > 3:
|
||||
counter = 0
|
||||
else:
|
||||
raise ResponseError(f"Image generation failed with status: {status}")
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import random
|
||||
import requests
|
||||
from urllib.parse import quote_plus
|
||||
|
|
@ -13,7 +14,7 @@ from ..image import to_data_uri
|
|||
from ..errors import ModelNotFoundError
|
||||
from ..requests.raise_for_status import raise_for_status
|
||||
from ..requests.aiohttp import get_connector
|
||||
from ..providers.response import ImageResponse, ImagePreview, FinishReason, Usage, Audio
|
||||
from ..providers.response import ImageResponse, ImagePreview, FinishReason, Usage, Audio, ToolCalls
|
||||
from .. import debug
|
||||
|
||||
DEFAULT_HEADERS = {
|
||||
|
|
@ -52,7 +53,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
text_models = [default_model]
|
||||
image_models = [default_image_model]
|
||||
extra_image_models = ["flux-pro", "flux-dev", "flux-schnell", "midjourney", "dall-e-3"]
|
||||
vision_models = [default_vision_model, "gpt-4o-mini", "o1-mini"]
|
||||
vision_models = [default_vision_model, "gpt-4o-mini", "o1-mini", "openai", "openai-large"]
|
||||
extra_text_models = vision_models
|
||||
_models_loaded = False
|
||||
model_aliases = {
|
||||
|
|
@ -138,6 +139,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
stream: bool = False,
|
||||
proxy: str = None,
|
||||
prompt: str = None,
|
||||
width: int = 1024,
|
||||
|
|
@ -154,6 +156,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
frequency_penalty: float = None,
|
||||
response_format: Optional[dict] = None,
|
||||
cache: bool = False,
|
||||
extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias"],
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
cls.get_models()
|
||||
|
|
@ -193,6 +196,9 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
response_format=response_format,
|
||||
seed=seed,
|
||||
cache=cache,
|
||||
stream=stream,
|
||||
extra_parameters=extra_parameters,
|
||||
**kwargs
|
||||
):
|
||||
yield result
|
||||
|
||||
|
|
@ -246,7 +252,10 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
frequency_penalty: float,
|
||||
response_format: Optional[dict],
|
||||
seed: Optional[int],
|
||||
cache: bool
|
||||
cache: bool,
|
||||
stream: bool,
|
||||
extra_parameters: list[str],
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
if not cache and seed is None:
|
||||
seed = random.randint(9999, 99999999)
|
||||
|
|
@ -267,6 +276,13 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
messages[-1] = last_message
|
||||
|
||||
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
|
||||
if model in cls.audio_models or stream:
|
||||
#data["voice"] = random.choice(cls.audio_models[model])
|
||||
url = cls.text_api_endpoint
|
||||
stream = False
|
||||
else:
|
||||
url = cls.openai_endpoint
|
||||
extra_parameters = {param: kwargs[param] for param in extra_parameters if param in kwargs}
|
||||
data = filter_none(**{
|
||||
"messages": messages,
|
||||
"model": model,
|
||||
|
|
@ -275,17 +291,11 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
"top_p": top_p,
|
||||
"frequency_penalty": frequency_penalty,
|
||||
"jsonMode": json_mode,
|
||||
"stream": False,
|
||||
"stream": stream,
|
||||
"seed": seed,
|
||||
"cache": cache
|
||||
"cache": cache,
|
||||
**extra_parameters
|
||||
})
|
||||
if "gemini" in model:
|
||||
data.pop("seed")
|
||||
if model in cls.audio_models:
|
||||
#data["voice"] = random.choice(cls.audio_models[model])
|
||||
url = cls.text_api_endpoint
|
||||
else:
|
||||
url = cls.openai_endpoint
|
||||
async with session.post(url, json=data) as response:
|
||||
await raise_for_status(response)
|
||||
if response.headers["content-type"] == "audio/mpeg":
|
||||
|
|
@ -294,11 +304,31 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
elif response.headers["content-type"].startswith("text/plain"):
|
||||
yield await response.text()
|
||||
return
|
||||
elif response.headers["content-type"].startswith("text/event-stream"):
|
||||
async for line in response.content:
|
||||
if line.startswith(b"data: "):
|
||||
if line[6:].startswith(b"[DONE]"):
|
||||
break
|
||||
result = json.loads(line[6:])
|
||||
choice = result.get("choices", [{}])[0]
|
||||
content = choice.get("delta", {}).get("content")
|
||||
if content:
|
||||
yield content
|
||||
if "usage" in result:
|
||||
yield Usage(**result["usage"])
|
||||
finish_reason = choice.get("finish_reason")
|
||||
if finish_reason:
|
||||
yield FinishReason(finish_reason)
|
||||
return
|
||||
result = await response.json()
|
||||
choice = result["choices"][0]
|
||||
message = choice.get("message", {})
|
||||
content = message.get("content", "")
|
||||
|
||||
if "tool_calls" in message:
|
||||
yield ToolCalls(message["tool_calls"])
|
||||
|
||||
if content is not None:
|
||||
if "</think>" in content and "<think>" not in content:
|
||||
yield "<think>"
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from .mini_max import HailuoAI, MiniMax
|
|||
from .template import OpenaiTemplate, BackendApi
|
||||
|
||||
from .AllenAI import AllenAI
|
||||
from .ARTA import ARTA
|
||||
from .Blackbox import Blackbox
|
||||
from .ChatGLM import ChatGLM
|
||||
from .ChatGpt import ChatGpt
|
||||
|
|
|
|||
|
|
@ -623,8 +623,9 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
page.add_handler(nodriver.cdp.network.RequestWillBeSent, on_request)
|
||||
page = await browser.get(cls.url)
|
||||
user_agent = await page.evaluate("window.navigator.userAgent")
|
||||
await page.select("textarea.text-token-text-primary", 240)
|
||||
await page.evaluate("document.querySelector('textarea.text-token-text-primary').value = 'Hello'")
|
||||
await page.select("#prompt-textarea", 240)
|
||||
await page.evaluate("document.getElementById('prompt-textarea').innerText = 'Hello'")
|
||||
await page.select("[data-testid=\"send-button\"]", 30)
|
||||
await page.evaluate("document.querySelector('[data-testid=\"send-button\"]').click()")
|
||||
while True:
|
||||
body = await page.evaluate("JSON.stringify(window.__remixContext)")
|
||||
|
|
|
|||
|
|
@ -276,7 +276,7 @@ class Completions:
|
|||
def create(
|
||||
self,
|
||||
messages: Messages,
|
||||
model: str,
|
||||
model: str = "",
|
||||
provider: Optional[ProviderType] = None,
|
||||
stream: Optional[bool] = False,
|
||||
proxy: Optional[str] = None,
|
||||
|
|
@ -330,7 +330,7 @@ class Completions:
|
|||
def stream(
|
||||
self,
|
||||
messages: Messages,
|
||||
model: str,
|
||||
model: str = "",
|
||||
**kwargs
|
||||
) -> IterResponse:
|
||||
return self.create(messages, model, stream=True, **kwargs)
|
||||
|
|
@ -564,7 +564,7 @@ class AsyncCompletions:
|
|||
def create(
|
||||
self,
|
||||
messages: Messages,
|
||||
model: str,
|
||||
model: str = "",
|
||||
provider: Optional[ProviderType] = None,
|
||||
stream: Optional[bool] = False,
|
||||
proxy: Optional[str] = None,
|
||||
|
|
@ -619,7 +619,7 @@ class AsyncCompletions:
|
|||
def stream(
|
||||
self,
|
||||
messages: Messages,
|
||||
model: str,
|
||||
model: str = "",
|
||||
**kwargs
|
||||
) -> AsyncIterator[ChatCompletionChunk]:
|
||||
return self.create(messages, model, stream=True, **kwargs)
|
||||
|
|
|
|||
|
|
@ -112,7 +112,7 @@ body:not(.white) a:visited{
|
|||
|
||||
.new_version {
|
||||
position: absolute;
|
||||
right: 0;
|
||||
left: 0;
|
||||
top: 0;
|
||||
padding: 10px;
|
||||
font-weight: 500;
|
||||
|
|
@ -143,6 +143,7 @@ body:not(.white) a:visited{
|
|||
|
||||
.conversation {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 5px;
|
||||
|
|
@ -238,8 +239,7 @@ body:not(.white) a:visited{
|
|||
|
||||
#close_provider_forms {
|
||||
max-width: 210px;
|
||||
margin-left: auto;
|
||||
margin-right: 8px;
|
||||
margin-left: 12px;
|
||||
margin-top: 12px;
|
||||
}
|
||||
|
||||
|
|
@ -1584,19 +1584,6 @@ form .field.saved .fa-xmark {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/* Basic adaptation */
|
||||
.row {
|
||||
flex-wrap: wrap;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.conversations, .settings, .conversation {
|
||||
flex: 1 1 300px;
|
||||
min-width: 0;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
/* Media queries for mobile devices */
|
||||
@media (max-width: 768px) {
|
||||
.row {
|
||||
|
|
@ -1608,11 +1595,6 @@ form .field.saved .fa-xmark {
|
|||
max-width: 100%;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.conversation {
|
||||
order: -1;
|
||||
min-height: 80vh;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 480px) {
|
||||
|
|
|
|||
|
|
@ -259,6 +259,10 @@ function register_message_images() {
|
|||
|
||||
const register_message_buttons = async () => {
|
||||
message_box.querySelectorAll(".message .content .provider").forEach(async (el) => {
|
||||
if (el.dataset.click) {
|
||||
return
|
||||
}
|
||||
el.dataset.click = true;
|
||||
const provider_forms = document.querySelector(".provider_forms");
|
||||
const provider_form = provider_forms.querySelector(`#${el.dataset.provider}-form`);
|
||||
const provider_link = el.querySelector("a");
|
||||
|
|
@ -279,6 +283,10 @@ const register_message_buttons = async () => {
|
|||
});
|
||||
|
||||
message_box.querySelectorAll(".message .fa-xmark").forEach(async (el) => el.addEventListener("click", async () => {
|
||||
if (el.dataset.click) {
|
||||
return
|
||||
}
|
||||
el.dataset.click = true;
|
||||
const message_el = get_message_el(el);
|
||||
await remove_message(window.conversation_id, message_el.dataset.index);
|
||||
message_el.remove();
|
||||
|
|
@ -286,6 +294,10 @@ const register_message_buttons = async () => {
|
|||
}));
|
||||
|
||||
message_box.querySelectorAll(".message .fa-clipboard").forEach(async (el) => el.addEventListener("click", async () => {
|
||||
if (el.dataset.click) {
|
||||
return
|
||||
}
|
||||
el.dataset.click = true;
|
||||
let message_el = get_message_el(el);
|
||||
let response = await fetch(message_el.dataset.object_url);
|
||||
let copyText = await response.text();
|
||||
|
|
@ -304,6 +316,10 @@ const register_message_buttons = async () => {
|
|||
}))
|
||||
|
||||
message_box.querySelectorAll(".message .fa-file-export").forEach(async (el) => el.addEventListener("click", async () => {
|
||||
if (el.dataset.click) {
|
||||
return
|
||||
}
|
||||
el.dataset.click = true;
|
||||
const elem = window.document.createElement('a');
|
||||
let filename = `chat ${new Date().toLocaleString()}.txt`.replaceAll(":", "-");
|
||||
const conversation = await get_conversation(window.conversation_id);
|
||||
|
|
@ -323,6 +339,10 @@ const register_message_buttons = async () => {
|
|||
}))
|
||||
|
||||
message_box.querySelectorAll(".message .fa-volume-high").forEach(async (el) => el.addEventListener("click", async () => {
|
||||
if (el.dataset.click) {
|
||||
return
|
||||
}
|
||||
el.dataset.click = true;
|
||||
const message_el = get_message_el(el);
|
||||
let audio;
|
||||
if (message_el.dataset.synthesize_url) {
|
||||
|
|
@ -344,6 +364,10 @@ const register_message_buttons = async () => {
|
|||
}));
|
||||
|
||||
message_box.querySelectorAll(".message .regenerate_button").forEach(async (el) => el.addEventListener("click", async () => {
|
||||
if (el.dataset.click) {
|
||||
return
|
||||
}
|
||||
el.dataset.click = true;
|
||||
const message_el = get_message_el(el);
|
||||
el.classList.add("clicked");
|
||||
setTimeout(() => el.classList.remove("clicked"), 1000);
|
||||
|
|
@ -351,6 +375,10 @@ const register_message_buttons = async () => {
|
|||
}));
|
||||
|
||||
message_box.querySelectorAll(".message .continue_button").forEach(async (el) => el.addEventListener("click", async () => {
|
||||
if (el.dataset.click) {
|
||||
return
|
||||
}
|
||||
el.dataset.click = true;
|
||||
if (!el.disabled) {
|
||||
el.disabled = true;
|
||||
const message_el = get_message_el(el);
|
||||
|
|
@ -361,11 +389,19 @@ const register_message_buttons = async () => {
|
|||
));
|
||||
|
||||
message_box.querySelectorAll(".message .fa-whatsapp").forEach(async (el) => el.addEventListener("click", async () => {
|
||||
if (el.dataset.click) {
|
||||
return
|
||||
}
|
||||
el.dataset.click = true;
|
||||
const text = get_message_el(el).innerText;
|
||||
window.open(`https://wa.me/?text=${encodeURIComponent(text)}`, '_blank');
|
||||
}));
|
||||
|
||||
message_box.querySelectorAll(".message .fa-print").forEach(async (el) => el.addEventListener("click", async () => {
|
||||
if (el.dataset.click) {
|
||||
return
|
||||
}
|
||||
el.dataset.click = true;
|
||||
const message_el = get_message_el(el);
|
||||
el.classList.add("clicked");
|
||||
message_box.scrollTop = 0;
|
||||
|
|
@ -378,6 +414,10 @@ const register_message_buttons = async () => {
|
|||
}));
|
||||
|
||||
message_box.querySelectorAll(".message .reasoning_title").forEach(async (el) => el.addEventListener("click", async () => {
|
||||
if (el.dataset.click) {
|
||||
return
|
||||
}
|
||||
el.dataset.click = true;
|
||||
let text_el = el.parentElement.querySelector(".reasoning_text");
|
||||
if (text_el) {
|
||||
text_el.classList[text_el.classList.contains("hidden") ? "remove" : "add"]("hidden");
|
||||
|
|
@ -569,9 +609,9 @@ const prepare_messages = (messages, message_index = -1, do_continue = false, do_
|
|||
}
|
||||
|
||||
// Remove history, only add new user messages
|
||||
let filtered_messages = [];
|
||||
// The message_index is null on count total tokens
|
||||
if (document.getElementById('history')?.checked && do_filter && message_index != null) {
|
||||
if (!do_continue && document.getElementById('history')?.checked && do_filter && message_index != null) {
|
||||
let filtered_messages = [];
|
||||
while (last_message = messages.pop()) {
|
||||
if (last_message["role"] == "user") {
|
||||
filtered_messages.push(last_message);
|
||||
|
|
@ -630,9 +670,9 @@ async function load_provider_parameters(provider) {
|
|||
form_el.id = form_id;
|
||||
form_el.classList.add("hidden");
|
||||
appStorage.setItem(form_el.id, JSON.stringify(parameters_storage[provider]));
|
||||
let old_form = message_box.querySelector(`#${provider}-form`);
|
||||
let old_form = document.getElementById(form_id);
|
||||
if (old_form) {
|
||||
provider_forms.removeChild(old_form);
|
||||
old_form.remove();
|
||||
}
|
||||
Object.entries(parameters_storage[provider]).forEach(([key, value]) => {
|
||||
let el_id = `${provider}-${key}`;
|
||||
|
|
@ -649,7 +689,7 @@ async function load_provider_parameters(provider) {
|
|||
saved_value = value;
|
||||
}
|
||||
field_el.innerHTML = `<span class="label">${key}:</span>
|
||||
<input type="checkbox" id="${el_id}" name="${provider}[${key}]">
|
||||
<input type="checkbox" id="${el_id}" name="${key}">
|
||||
<label for="${el_id}" class="toogle" title=""></label>
|
||||
<i class="fa-solid fa-xmark"></i>`;
|
||||
form_el.appendChild(field_el);
|
||||
|
|
@ -679,15 +719,15 @@ async function load_provider_parameters(provider) {
|
|||
placeholder = value == null ? "null" : value;
|
||||
}
|
||||
field_el.innerHTML = `<label for="${el_id}" title="">${key}:</label>`;
|
||||
if (Number.isInteger(value) && value != 1) {
|
||||
max = value >= 4096 ? 8192 : 4096;
|
||||
field_el.innerHTML += `<input type="range" id="${el_id}" name="${provider}[${key}]" value="${escapeHtml(value)}" class="slider" min="0" max="${max}" step="1"/><output>${escapeHtml(value)}</output>`;
|
||||
if (Number.isInteger(value)) {
|
||||
max = value == 42 || value >= 4096 ? 8192 : value >= 100 ? 4096 : value == 1 ? 10 : 100;
|
||||
field_el.innerHTML += `<input type="range" id="${el_id}" name="${key}" value="${escapeHtml(value)}" class="slider" min="0" max="${max}" step="1"/><output>${escapeHtml(value)}</output>`;
|
||||
field_el.innerHTML += `<i class="fa-solid fa-xmark"></i>`;
|
||||
} else if (typeof value == "number") {
|
||||
field_el.innerHTML += `<input type="range" id="${el_id}" name="${provider}[${key}]" value="${escapeHtml(value)}" class="slider" min="0" max="2" step="0.1"/><output>${escapeHtml(value)}</output>`;
|
||||
field_el.innerHTML += `<input type="range" id="${el_id}" name="${key}" value="${escapeHtml(value)}" class="slider" min="0" max="2" step="0.1"/><output>${escapeHtml(value)}</output>`;
|
||||
field_el.innerHTML += `<i class="fa-solid fa-xmark"></i>`;
|
||||
} else {
|
||||
field_el.innerHTML += `<textarea id="${el_id}" name="${provider}[${key}]"></textarea>`;
|
||||
field_el.innerHTML += `<textarea id="${el_id}" name="${key}"></textarea>`;
|
||||
field_el.innerHTML += `<i class="fa-solid fa-xmark"></i>`;
|
||||
input_el = field_el.querySelector("textarea");
|
||||
if (value != null) {
|
||||
|
|
@ -723,6 +763,7 @@ async function load_provider_parameters(provider) {
|
|||
input_el = field_el.querySelector("input");
|
||||
input_el.dataset.value = value;
|
||||
input_el.value = saved_value;
|
||||
input_el.nextElementSibling.value = input_el.value;
|
||||
input_el.oninput = () => {
|
||||
input_el.nextElementSibling.value = input_el.value;
|
||||
field_el.classList.add("saved");
|
||||
|
|
@ -1008,6 +1049,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
|
|||
}
|
||||
await safe_remove_cancel_button();
|
||||
await register_message_images();
|
||||
await register_message_buttons();
|
||||
await load_conversations();
|
||||
regenerate_button.classList.remove("regenerate-hidden");
|
||||
}
|
||||
|
|
@ -1035,6 +1077,18 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
|
|||
}
|
||||
}
|
||||
const ignored = Array.from(settings.querySelectorAll("input.provider:not(:checked)")).map((el)=>el.value);
|
||||
let extra_parameters = {};
|
||||
document.getElementById(`${provider}-form`)?.querySelectorAll(".saved input, .saved textarea").forEach(async (el) => {
|
||||
let value = el.type == "checkbox" ? el.checked : el.value;
|
||||
extra_parameters[el.name] = value;
|
||||
if (el.type == "textarea") {
|
||||
try {
|
||||
extra_parameters[el.name] = await JSON.parse(value);
|
||||
} catch (e) {
|
||||
}
|
||||
}
|
||||
});
|
||||
console.log(extra_parameters);
|
||||
await api("conversation", {
|
||||
id: message_id,
|
||||
conversation_id: window.conversation_id,
|
||||
|
|
@ -1048,6 +1102,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
|
|||
api_key: api_key,
|
||||
api_base: api_base,
|
||||
ignored: ignored,
|
||||
...extra_parameters
|
||||
}, Object.values(image_storage), message_id, scroll, finish_message);
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
|
|
|
|||
|
|
@ -89,25 +89,17 @@ class Api:
|
|||
ensure_images_dir()
|
||||
return send_from_directory(os.path.abspath(images_dir), name)
|
||||
|
||||
def _prepare_conversation_kwargs(self, json_data: dict, kwargs: dict):
|
||||
def _prepare_conversation_kwargs(self, json_data: dict):
|
||||
kwargs = {**json_data}
|
||||
model = json_data.get('model')
|
||||
provider = json_data.get('provider')
|
||||
messages = json_data.get('messages')
|
||||
api_key = json_data.get("api_key")
|
||||
if api_key:
|
||||
kwargs["api_key"] = api_key
|
||||
api_base = json_data.get("api_base")
|
||||
if api_base:
|
||||
kwargs["api_base"] = api_base
|
||||
kwargs["tool_calls"] = [{
|
||||
"function": {
|
||||
"name": "bucket_tool"
|
||||
},
|
||||
"type": "function"
|
||||
}]
|
||||
web_search = json_data.get('web_search')
|
||||
if web_search:
|
||||
kwargs["web_search"] = web_search
|
||||
action = json_data.get('action')
|
||||
if action == "continue":
|
||||
kwargs["tool_calls"].append({
|
||||
|
|
@ -117,19 +109,13 @@ class Api:
|
|||
"type": "function"
|
||||
})
|
||||
conversation = json_data.get("conversation")
|
||||
if conversation is not None:
|
||||
if isinstance(conversation, dict):
|
||||
kwargs["conversation"] = JsonConversation(**conversation)
|
||||
else:
|
||||
conversation_id = json_data.get("conversation_id")
|
||||
if conversation_id and provider:
|
||||
if provider in conversations and conversation_id in conversations[provider]:
|
||||
kwargs["conversation"] = conversations[provider][conversation_id]
|
||||
|
||||
if json_data.get("ignored"):
|
||||
kwargs["ignored"] = json_data["ignored"]
|
||||
if json_data.get("action"):
|
||||
kwargs["action"] = json_data["action"]
|
||||
|
||||
return {
|
||||
"model": model,
|
||||
"provider": provider,
|
||||
|
|
|
|||
|
|
@ -106,17 +106,16 @@ class Backend_Api(Api):
|
|||
Returns:
|
||||
Response: A Flask response object for streaming.
|
||||
"""
|
||||
kwargs = {}
|
||||
if "json" in request.form:
|
||||
json_data = json.loads(request.form['json'])
|
||||
else:
|
||||
json_data = request.json
|
||||
if "files" in request.files:
|
||||
images = []
|
||||
for file in request.files.getlist('files'):
|
||||
if file.filename != '' and is_allowed_extension(file.filename):
|
||||
images.append((to_image(file.stream, file.filename.endswith('.svg')), file.filename))
|
||||
kwargs['images'] = images
|
||||
if "json" in request.form:
|
||||
json_data = json.loads(request.form['json'])
|
||||
else:
|
||||
json_data = request.json
|
||||
json_data['images'] = images
|
||||
|
||||
if app.demo and not json_data.get("provider"):
|
||||
model = json_data.get("model")
|
||||
|
|
@ -126,9 +125,7 @@ class Backend_Api(Api):
|
|||
if not model or model == "default":
|
||||
json_data["model"] = models.demo_models["default"][0].name
|
||||
json_data["provider"] = random.choice(models.demo_models["default"][1])
|
||||
if "images" in json_data:
|
||||
kwargs["images"] = json_data["images"]
|
||||
kwargs = self._prepare_conversation_kwargs(json_data, kwargs)
|
||||
kwargs = self._prepare_conversation_kwargs(json_data)
|
||||
return self.app.response_class(
|
||||
self._create_response_stream(
|
||||
kwargs,
|
||||
|
|
|
|||
|
|
@ -21,12 +21,12 @@ from .api import Api
|
|||
|
||||
class JsApi(Api):
|
||||
|
||||
def get_conversation(self, options: dict, message_id: str = None, scroll: bool = None, **kwargs) -> Iterator:
|
||||
def get_conversation(self, options: dict, message_id: str = None, scroll: bool = None) -> Iterator:
|
||||
window = webview.windows[0]
|
||||
if hasattr(self, "image") and self.image is not None:
|
||||
kwargs["image"] = open(self.image, "rb")
|
||||
options["image"] = open(self.image, "rb")
|
||||
for message in self._create_response_stream(
|
||||
self._prepare_conversation_kwargs(options, kwargs),
|
||||
self._prepare_conversation_kwargs(options),
|
||||
options.get("conversation_id"),
|
||||
options.get('provider')
|
||||
):
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ SAFE_PARAMETERS = [
|
|||
"api_key", "api_base", "seed", "width", "height",
|
||||
"proof_token", "max_retries", "web_search",
|
||||
"guidance_scale", "num_inference_steps", "randomize_seed",
|
||||
"safe", "enhance", "private",
|
||||
"safe", "enhance", "private", "aspect_ratio", "images_num",
|
||||
]
|
||||
|
||||
BASIC_PARAMETERS = {
|
||||
|
|
|
|||
|
|
@ -3,11 +3,10 @@ from __future__ import annotations
|
|||
import json
|
||||
|
||||
from ..typing import AsyncResult, Messages, ImagesType
|
||||
from ..providers.asyncio import to_async_iterator
|
||||
from ..client.service import get_model_and_provider
|
||||
from ..client.helper import filter_json
|
||||
from .base_provider import AsyncGeneratorProvider
|
||||
from .response import ToolCalls, FinishReason
|
||||
from .response import ToolCalls, FinishReason, Usage
|
||||
|
||||
class ToolSupportProvider(AsyncGeneratorProvider):
|
||||
working = True
|
||||
|
|
@ -45,6 +44,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
|
|||
|
||||
finish = None
|
||||
chunks = []
|
||||
has_usage = False
|
||||
async for chunk in provider.get_async_create_function()(
|
||||
model,
|
||||
messages,
|
||||
|
|
@ -53,14 +53,20 @@ class ToolSupportProvider(AsyncGeneratorProvider):
|
|||
response_format=response_format,
|
||||
**kwargs
|
||||
):
|
||||
if isinstance(chunk, FinishReason):
|
||||
if isinstance(chunk, str):
|
||||
chunks.append(chunk)
|
||||
elif isinstance(chunk, Usage):
|
||||
yield chunk
|
||||
has_usage = True
|
||||
elif isinstance(chunk, FinishReason):
|
||||
finish = chunk
|
||||
break
|
||||
elif isinstance(chunk, str):
|
||||
chunks.append(chunk)
|
||||
else:
|
||||
yield chunk
|
||||
|
||||
if not has_usage:
|
||||
yield Usage(completion_tokens=len(chunks), total_tokens=len(chunks))
|
||||
|
||||
chunks = "".join(chunks)
|
||||
if tools is not None:
|
||||
yield ToolCalls([{
|
||||
|
|
@ -72,5 +78,6 @@ class ToolSupportProvider(AsyncGeneratorProvider):
|
|||
}
|
||||
}])
|
||||
yield chunks
|
||||
|
||||
if finish is not None:
|
||||
yield finish
|
||||
|
|
@ -59,7 +59,7 @@ class ToolHandler:
|
|||
def process_continue_tool(messages: Messages, tool: dict, provider: Any) -> Tuple[Messages, Dict[str, Any]]:
|
||||
"""Process continue tool requests"""
|
||||
kwargs = {}
|
||||
if provider not in ("OpenaiAccount", "HuggingFace"):
|
||||
if provider not in ("OpenaiAccount", "HuggingFaceAPI"):
|
||||
messages = messages.copy()
|
||||
last_line = messages[-1]["content"].strip().splitlines()[-1]
|
||||
content = f"Carry on from this point:\n{last_line}"
|
||||
|
|
@ -85,12 +85,10 @@ class ToolHandler:
|
|||
has_bucket = True
|
||||
message["content"] = new_message_content
|
||||
|
||||
if has_bucket and isinstance(messages[-1]["content"], str):
|
||||
if "\nSource: " in messages[-1]["content"]:
|
||||
if isinstance(messages[-1]["content"], dict):
|
||||
messages[-1]["content"]["content"] += BUCKET_INSTRUCTIONS
|
||||
else:
|
||||
messages[-1]["content"] += BUCKET_INSTRUCTIONS
|
||||
last_message_content = messages[-1]["content"]
|
||||
if has_bucket and isinstance(last_message_content, str):
|
||||
if "\nSource: " in last_message_content:
|
||||
messages[-1]["content"] = last_message_content + BUCKET_INSTRUCTIONS
|
||||
|
||||
return messages
|
||||
|
||||
|
|
@ -309,9 +307,10 @@ def iter_run_tools(
|
|||
if new_message_content != message["content"]:
|
||||
has_bucket = True
|
||||
message["content"] = new_message_content
|
||||
if has_bucket and isinstance(messages[-1]["content"], str):
|
||||
if "\nSource: " in messages[-1]["content"]:
|
||||
messages[-1]["content"] = messages[-1]["content"]["content"] + BUCKET_INSTRUCTIONS
|
||||
last_message = messages[-1]["content"]
|
||||
if has_bucket and isinstance(last_message, str):
|
||||
if "\nSource: " in last_message:
|
||||
messages[-1]["content"] = last_message + BUCKET_INSTRUCTIONS
|
||||
|
||||
# Process response chunks
|
||||
thinking_start_time = 0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue