From ce500f0d49bacc2874b498b67fc4942d26daddc0 Mon Sep 17 00:00:00 2001 From: hlohaus <983577+hlohaus@users.noreply.github.com> Date: Wed, 26 Mar 2025 01:32:05 +0100 Subject: [PATCH] Set default model in HuggingFaceMedia Improve handling of shared chats Show api_key input if required --- etc/examples/video.py | 19 ++++++++++++ g4f/Provider/hf/HuggingFaceMedia.py | 12 +++++--- g4f/Provider/needs_auth/Grok.py | 3 +- g4f/Provider/template/OpenaiTemplate.py | 4 +-- g4f/api/__init__.py | 5 ++++ g4f/client/__init__.py | 24 +++++++++------ g4f/client/image_models.py | 38 +++++++++++++++++++---- g4f/gui/client/index.html | 9 ++++-- g4f/gui/client/qrcode.html | 4 +-- g4f/gui/client/static/css/style.css | 16 +++++----- g4f/gui/client/static/js/chat.v1.js | 39 ++++++++++++++++++------ g4f/gui/server/api.py | 7 +++-- g4f/gui/server/backend_api.py | 40 ++++++++++++++++++++----- g4f/image/__init__.py | 18 ++++------- g4f/image/copy_images.py | 6 ++-- g4f/models.py | 39 ++++++------------------ g4f/providers/base_provider.py | 6 +++- g4f/requests/raise_for_status.py | 28 +++++++++-------- 18 files changed, 206 insertions(+), 111 deletions(-) create mode 100644 etc/examples/video.py diff --git a/etc/examples/video.py b/etc/examples/video.py new file mode 100644 index 00000000..0d5e787b --- /dev/null +++ b/etc/examples/video.py @@ -0,0 +1,19 @@ +import g4f.Provider +from g4f.client import Client + +client = Client( + provider=g4f.Provider.HuggingFaceMedia, + api_key="hf_***" # Your API key here +) + +video_models = client.models.get_video() + +print(video_models) + +result = client.media.generate( + model=video_models[0], + prompt="G4F AI technology is the best in the world.", + response_format="url" +) + +print(result.data[0].url) \ No newline at end of file diff --git a/g4f/Provider/hf/HuggingFaceMedia.py b/g4f/Provider/hf/HuggingFaceMedia.py index e965e4d5..be6788b6 100644 --- a/g4f/Provider/hf/HuggingFaceMedia.py +++ b/g4f/Provider/hf/HuggingFaceMedia.py @@ -66,6 +66,8 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin): for provider_data in provider_keys: prepend_models.append(f"{model}:{provider_data.get('provider')}") cls.models = prepend_models + [model for model in new_models if model not in prepend_models] + cls.image_models = [model for model, task in cls.task_mapping.items() if task == "text-to-image"] + cls.video_models = [model for model, task in cls.task_mapping.items() if task == "text-to-video"] else: cls.models = [] return cls.models @@ -99,12 +101,14 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin): prompt: str = None, proxy: str = None, timeout: int = 0, - aspect_ratio: str = "1:1", + aspect_ratio: str = None, **kwargs ): selected_provider = None - if ":" in model: + if model and ":" in model: model, selected_provider = model.split(":", 1) + elif not model: + model = cls.get_models()[0] provider_mapping = await cls.get_mapping(model, api_key) headers = { 'Accept-Encoding': 'gzip, deflate', @@ -133,11 +137,11 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin): extra_data = { "num_inference_steps": 20, "resolution": "480p", - "aspect_ratio": aspect_ratio, + "aspect_ratio": "16:9" if aspect_ratio is None else aspect_ratio, **extra_data } else: - extra_data = use_aspect_ratio(extra_data, aspect_ratio) + extra_data = use_aspect_ratio(extra_data, "1:1" if aspect_ratio is None else aspect_ratio) if provider_key == "fal-ai": url = f"{api_base}/{provider_id}" data = { diff --git a/g4f/Provider/needs_auth/Grok.py b/g4f/Provider/needs_auth/Grok.py index 46f1a24e..53c921da 100644 --- a/g4f/Provider/needs_auth/Grok.py +++ b/g4f/Provider/needs_auth/Grok.py @@ -30,6 +30,7 @@ class Grok(AsyncAuthedProvider, ProviderModelMixin): default_model = "grok-3" models = [default_model, "grok-3-thinking", "grok-2"] + model_aliases = {"grok-3-r1": "grok-3-thinking"} @classmethod async def on_auth_async(cls, cookies: Cookies = None, proxy: str = None, **kwargs) -> AsyncIterator: @@ -73,7 +74,7 @@ class Grok(AsyncAuthedProvider, ProviderModelMixin): "sendFinalMetadata": True, "customInstructions": "", "deepsearchPreset": "", - "isReasoning": model.endswith("-thinking"), + "isReasoning": model.endswith("-thinking") or model.endswith("-r1"), } @classmethod diff --git a/g4f/Provider/template/OpenaiTemplate.py b/g4f/Provider/template/OpenaiTemplate.py index 2b5ca544..eb253bc6 100644 --- a/g4f/Provider/template/OpenaiTemplate.py +++ b/g4f/Provider/template/OpenaiTemplate.py @@ -92,7 +92,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin } async with session.post(f"{api_base.rstrip('/')}/images/generations", json=data, ssl=cls.ssl) as response: data = await response.json() - cls.raise_error(data) + cls.raise_error(data, response.status) await raise_for_status(response) yield ImageResponse([image["url"] for image in data["data"]], prompt) return @@ -135,7 +135,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin content_type = response.headers.get("content-type", "text/event-stream" if stream else "application/json") if content_type.startswith("application/json"): data = await response.json() - cls.raise_error(data) + cls.raise_error(data, response.status) await raise_for_status(response) choice = data["choices"][0] if "content" in choice["message"] and choice["message"]["content"]: diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py index 2e978101..87e7a1fd 100644 --- a/g4f/api/__init__.py +++ b/g4f/api/__init__.py @@ -10,6 +10,7 @@ from email.utils import formatdate import os.path import hashlib import asyncio +from urllib.parse import quote_plus from fastapi import FastAPI, Response, Request, UploadFile, Depends from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse @@ -562,6 +563,10 @@ class Api: }) async def get_media(filename, request: Request): target = os.path.join(images_dir, os.path.basename(filename)) + if not os.path.isfile(target): + other_name = os.path.join(images_dir, os.path.basename(quote_plus(filename))) + if os.path.isfile(other_name): + target = other_name ext = os.path.splitext(filename)[1][1:] mime_type = EXTENSIONS_MAP.get(ext) stat_result = SimpleNamespace() diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py index 8454d3db..33b795a9 100644 --- a/g4f/client/__init__.py +++ b/g4f/client/__init__.py @@ -19,7 +19,7 @@ from ..providers.asyncio import to_sync_generator from ..Provider.needs_auth import BingCreateImages, OpenaiAccount from ..tools.run_tools import async_iter_run_tools, iter_run_tools from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse, UsageModel, ToolCallModel -from .image_models import ImageModels +from .image_models import MediaModels from .types import IterResponse, ImageProvider, Client as BaseClient from .service import get_model_and_provider, convert_to_provider from .helper import find_stop, filter_json, filter_none, safe_aclose @@ -267,8 +267,11 @@ class Client(BaseClient): ) -> None: super().__init__(**kwargs) self.chat: Chat = Chat(self, provider) + if image_provider is None: + image_provider = provider + self.models: MediaModels = MediaModels(self, image_provider) self.images: Images = Images(self, image_provider) - self.media: Images = Images(self, image_provider) + self.media: Images = self.images class Completions: def __init__(self, client: Client, provider: Optional[ProviderType] = None): @@ -349,7 +352,6 @@ class Images: def __init__(self, client: Client, provider: Optional[ProviderType] = None): self.client: Client = client self.provider: Optional[ProviderType] = provider - self.models: ImageModels = ImageModels(client) def generate( self, @@ -369,7 +371,7 @@ class Images: if provider is None: provider_handler = self.provider if provider_handler is None: - provider_handler = self.models.get(model, default) + provider_handler = self.client.models.get(model, default) elif isinstance(provider, str): provider_handler = convert_to_provider(provider) else: @@ -385,19 +387,21 @@ class Images: provider: Optional[ProviderType] = None, response_format: Optional[str] = None, proxy: Optional[str] = None, + api_key: Optional[str] = None, **kwargs ) -> ImagesResponse: provider_handler = await self.get_provider_handler(model, provider, BingCreateImages) provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__ if proxy is None: proxy = self.client.proxy - + if api_key is None: + api_key = self.client.api_key error = None response = None if isinstance(provider_handler, IterListProvider): for provider in provider_handler.providers: try: - response = await self._generate_image_response(provider, provider.__name__, model, prompt, **kwargs) + response = await self._generate_image_response(provider, provider.__name__, model, prompt, proxy=proxy, **kwargs) if response is not None: provider_name = provider.__name__ break @@ -405,7 +409,7 @@ class Images: error = e debug.error(f"{provider.__name__} {type(e).__name__}: {e}") else: - response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs) + response = await self._generate_image_response(provider_handler, provider_name, model, prompt, proxy=proxy, api_key=api_key, **kwargs) if isinstance(response, MediaResponse): return await self._process_image_response( @@ -534,7 +538,7 @@ class Images: else: # Save locally for None (default) case images = await copy_media(response.get_list(), response.get("cookies"), proxy) - images = [Image.model_construct(url=f"/media/{os.path.basename(image)}", revised_prompt=response.alt) for image in images] + images = [Image.model_construct(url=image, revised_prompt=response.alt) for image in images] return ImagesResponse.model_construct( created=int(time.time()), @@ -552,6 +556,9 @@ class AsyncClient(BaseClient): ) -> None: super().__init__(**kwargs) self.chat: AsyncChat = AsyncChat(self, provider) + if image_provider is None: + image_provider = provider + self.models: MediaModels = MediaModels(self, image_provider) self.images: AsyncImages = AsyncImages(self, image_provider) self.media: AsyncImages = self.images @@ -635,7 +642,6 @@ class AsyncImages(Images): def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None): self.client: AsyncClient = client self.provider: Optional[ProviderType] = provider - self.models: ImageModels = ImageModels(client) async def generate( self, diff --git a/g4f/client/image_models.py b/g4f/client/image_models.py index 6795a83b..7146f5f2 100644 --- a/g4f/client/image_models.py +++ b/g4f/client/image_models.py @@ -1,15 +1,43 @@ from __future__ import annotations -from ..models import ModelUtils +from ..models import ModelUtils, ImageModel from ..Provider import ProviderUtils +from ..providers.types import ProviderType -class ImageModels(): - def __init__(self, client): +class MediaModels(): + def __init__(self, client, provider: ProviderType = None): self.client = client + self.provider = provider - def get(self, name, default=None): + def get(self, name, default=None) -> ProviderType: if name in ModelUtils.convert: return ModelUtils.convert[name].best_provider if name in ProviderUtils.convert: return ProviderUtils.convert[name] - return default \ No newline at end of file + return default + + def get_all(self, api_key: str = None, **kwargs) -> list[str]: + if self.provider is None: + return [] + if api_key is None: + api_key = self.client.api_key + return self.provider.get_models( + **kwargs, + **{} if api_key is None else {"api_key": api_key} + ) + + def get_image(self, **kwargs) -> list[str]: + if self.provider is None: + return [model_id for model_id, model in ModelUtils.convert.items() if isinstance(model, ImageModel)] + self.get_all(**kwargs) + if hasattr(self.provider, "image_models"): + return self.provider.image_models + return [] + + def get_video(self, **kwargs) -> list[str]: + if self.provider is None: + return [] + self.get_all(**kwargs) + if hasattr(self.provider, "video_models"): + return self.provider.video_models + return [] \ No newline at end of file diff --git a/g4f/gui/client/index.html b/g4f/gui/client/index.html index fb53e023..b85af2e5 100644 --- a/g4f/gui/client/index.html +++ b/g4f/gui/client/index.html @@ -61,9 +61,12 @@ const gpt_image = 'your avatar'; - - - + G4F Chat diff --git a/g4f/gui/client/qrcode.html b/g4f/gui/client/qrcode.html index 74eefb5a..30754771 100644 --- a/g4f/gui/client/qrcode.html +++ b/g4f/gui/client/qrcode.html @@ -7,9 +7,9 @@ diff --git a/g4f/gui/client/static/css/style.css b/g4f/gui/client/static/css/style.css index 358493bc..384115a8 100644 --- a/g4f/gui/client/static/css/style.css +++ b/g4f/gui/client/static/css/style.css @@ -881,7 +881,7 @@ input.model:hover padding: var(--inner-gap) 28px; } -#systemPrompt, #chatPrompt, .settings textarea, form textarea { +#systemPrompt, #chatPrompt, .settings textarea, form textarea, .chat-body textarea { font-size: 15px; color: var(--colour-3); outline: none; @@ -1305,7 +1305,7 @@ form textarea { padding: 0; } -.settings textarea { +.settings textarea, .chat-body textarea { height: 30px; min-height: 30px; padding: 6px; @@ -1315,7 +1315,7 @@ form textarea { text-wrap: nowrap; } -form .field .fa-xmark { +.field .fa-xmark { line-height: 20px; cursor: pointer; margin-left: auto; @@ -1323,11 +1323,11 @@ form .field .fa-xmark { margin-top: 0; } -form .field.saved .fa-xmark { +.field.saved .fa-xmark { color: var(--accent) } -.settings .field, form .field { +.settings .field, form .field, .chat-body .field { padding: var(--inner-gap) var(--inner-gap) var(--inner-gap) 0; } @@ -1359,7 +1359,7 @@ form .field.saved .fa-xmark { border: none; } -.settings input, form input { +.settings input, form input, .chat-body input { background-color: transparent; padding: 2px; border: none; @@ -1368,11 +1368,11 @@ form .field.saved .fa-xmark { color: var(--colour-3); } -.settings input:focus, form input:focus { +.settings input:focus, form input:focus, .chat-body input:focus { outline: none; } -.settings .label, form .label, .settings label, form label { +.settings .label, form .label, .settings label, form label, .chat-body label { font-size: 15px; margin-left: var(--inner-gap); } diff --git a/g4f/gui/client/static/js/chat.v1.js b/g4f/gui/client/static/js/chat.v1.js index a2ea836e..9d6422cb 100644 --- a/g4f/gui/client/static/js/chat.v1.js +++ b/g4f/gui/client/static/js/chat.v1.js @@ -28,7 +28,7 @@ const switchInput = document.getElementById("switch"); const searchButton = document.getElementById("search"); const paperclip = document.querySelector(".user-input .fa-paperclip"); -const optionElementsSelector = ".settings input, .settings textarea, #model, #model2, #provider"; +const optionElementsSelector = ".settings input, .settings textarea, .chat-body input, #model, #model2, #provider"; let provider_storage = {}; let message_storage = {}; @@ -153,7 +153,7 @@ const iframe_close = Object.assign(document.createElement("button"), { }); iframe_close.onclick = () => iframe_container.classList.add("hidden"); iframe_container.appendChild(iframe_close); -chat.appendChild(iframe_container); +document.body.appendChild(iframe_container); class HtmlRenderPlugin { constructor(options = {}) { @@ -843,6 +843,16 @@ async function add_message_chunk(message, message_id, provider, scroll, finish_m conversation.data[key] = value; } await save_conversation(conversation_id, conversation); + } else if (message.type == "auth") { + error_storage[message_id] = message.message + content_map.inner.innerHTML += markdown_render(`**An error occured:** ${message.message}`); + let provider = provider_storage[message_id]?.name; + let configEl = document.querySelector(`.settings .${provider}-api_key`); + if (configEl) { + configEl = configEl.parentElement.cloneNode(true); + content_map.content.appendChild(configEl); + await register_settings_storage(); + } } else if (message.type == "provider") { provider_storage[message_id] = message.provider; let provider_el = content_map.content.querySelector('.provider'); @@ -1122,10 +1132,6 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi let api_key; if (is_demo && !provider) { api_key = localStorage.getItem("HuggingFace-api_key"); - if (!api_key) { - location.href = "/"; - return; - } } else { api_key = get_api_key_by_provider(provider); } @@ -1221,6 +1227,7 @@ function sanitize(input, replacement) { } async function set_conversation_title(conversation_id, title) { + window.chat_id = null; conversation = await get_conversation(conversation_id) conversation.new_title = title; const new_id = sanitize(title, " "); @@ -1742,12 +1749,22 @@ const load_conversations = async () => { let html = []; conversations.forEach((conversation) => { + // const length = conversation.items.map((item) => ( + // !item.content.toLowerCase().includes("hello") && + // !item.content.toLowerCase().includes("hi") && + // item.content + // ) ? 1 : 0).reduce((a,b)=>a+b, 0); + // if (!length) { + // appStorage.removeItem(`conversation:${conversation.id}`); + // return; + // } + const shareIcon = (conversation.id == window.start_id && window.chat_id) ? '': ''; html.push(`
${conversation.updated ? toLocaleDateString(conversation.updated) : ""} - ${escapeHtml(conversation.new_title ? conversation.new_title : conversation.title)} + ${shareIcon} ${escapeHtml(conversation.new_title ? conversation.new_title : conversation.title)}