diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 22d3177f1..7bcba15e3 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -1,41 +1,36 @@ import functools +import html +import json import os.path +import re import urllib.parse from base64 import b64decode from io import BytesIO from pathlib import Path -from typing import Optional, Callable -from dataclasses import dataclass -import re -from starlette.responses import Response, FileResponse, JSONResponse +from typing import Callable, Optional -from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks, util -from modules.images import read_info_from_image, save_image_with_geninfo import gradio as gr -import json -import html from fastapi.exceptions import HTTPException from PIL import Image +from starlette.responses import FileResponse, JSONResponse, Response +from modules import (errors, extra_networks, shared, + ui_extra_networks_user_metadata, util) +from modules.images import read_info_from_image, save_image_with_geninfo from modules.infotext_utils import image_from_url_text extra_pages = [] allowed_dirs = set() default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] + @functools.cache def allowed_preview_extensions_with_extra(extra_extensions=None): return set(default_allowed_preview_extensions) | set(extra_extensions or []) def allowed_preview_extensions(): - return allowed_preview_extensions_with_extra((shared.opts.samples_format, )) - - -@dataclass -class ExtraNetworksItem: - """Wrapper for dictionaries representing ExtraNetworks items.""" - item: dict + return allowed_preview_extensions_with_extra((shared.opts.samples_format,)) class ListItem: @@ -44,6 +39,7 @@ class ListItem: id [str]: The ID of this list item. html [str]: The HTML string for this item. """ + def __init__(self, _id: str, _html: str) -> None: self.id = _id self.html = _html @@ -56,6 +52,7 @@ class CardListItem(ListItem): sort_keys [dict]: Nested dict where keys are sort modes and values are sort keys. search_terms [str]: String containing multiple search terms joined with spaces. """ + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -70,6 +67,7 @@ class TreeListItem(ListItem): visible [bool]: Whether the item should be shown in the list. expanded [bool]: Whether the item children should be shown. """ + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -120,7 +118,6 @@ class DirectoryTreeNode: else: self.item = items.get(self.abspath, None) - def flatten(self, res: dict, dirs_only: bool = False) -> None: """Flattens the keys/values of the tree nodes into a dictionary. @@ -147,6 +144,7 @@ class DirectoryTreeNode: for child in self.children: child.apply(fn) + def register_page(page): """registers extra networks page for the UI @@ -156,6 +154,7 @@ def register_page(page): allowed_dirs.clear() allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], []))) + def get_page_by_name(extra_networks_tabname: str = "") -> "ExtraNetworksPage": for page in extra_pages: if page.extra_networks_tabname == extra_networks_tabname: @@ -186,7 +185,7 @@ def fetch_cover_images(extra_networks_tabname: str = "", item: str = "", index: if metadata is None: raise HTTPException(status_code=404, detail="File not found") - cover_images = json.loads(metadata.get('ssmd_cover_images', {})) + cover_images = json.loads(metadata.get("ssmd_cover_images", {})) image = cover_images[index] if index < len(cover_images) else None if not image: raise HTTPException(status_code=404, detail="File not found") @@ -199,6 +198,7 @@ def fetch_cover_images(extra_networks_tabname: str = "", item: str = "", index: except Exception as err: raise ValueError(f"File cannot be fetched: {item}. Failed to load cover image.") from err + def init_tree_data(tabname: str = "", extra_networks_tabname: str = "") -> JSONResponse: page = get_page_by_name(extra_networks_tabname) @@ -209,6 +209,7 @@ def init_tree_data(tabname: str = "", extra_networks_tabname: str = "") -> JSONR return JSONResponse(data, status_code=200) + def fetch_tree_data( extra_networks_tabname: str = "", div_ids: str = "", @@ -247,6 +248,7 @@ def init_cards_data(tabname: str = "", extra_networks_tabname: str = "") -> JSON return JSONResponse(data, status_code=200) + def page_is_ready(extra_networks_tabname: str = "") -> JSONResponse: page = get_page_by_name(extra_networks_tabname) @@ -259,6 +261,7 @@ def page_is_ready(extra_networks_tabname: str = "") -> JSONResponse: except Exception as exc: return JSONResponse({"error": str(exc)}, status_code=500) + def get_metadata(extra_networks_tabname: str = "", item: str = "") -> JSONResponse: try: page = get_page_by_name(extra_networks_tabname) @@ -271,7 +274,7 @@ def get_metadata(extra_networks_tabname: str = "", item: str = "") -> JSONRespon # those are cover images, and they are too big to display in UI as text # FIXME: WHY WAS THIS HERE? - #metadata = {i: metadata[i] for i in metadata if i != 'ssmd_cover_images'} + # metadata = {i: metadata[i] for i in metadata if i != 'ssmd_cover_images'} return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)}) @@ -294,6 +297,7 @@ def get_single_card(tabname: str = "", extra_networks_tabname: str = "", name: s return JSONResponse({"html": item_html}) + def add_pages_to_demo(app): app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"]) app.add_api_route("/sd_extra_networks/cover-images", fetch_cover_images, methods=["GET"]) @@ -305,8 +309,9 @@ def add_pages_to_demo(app): app.add_api_route("/sd_extra_networks/fetch-cards-data", fetch_cards_data, methods=["GET"]) app.add_api_route("/sd_extra_networks/page-is-ready", page_is_ready, methods=["GET"]) + def quote_js(s): - s = s.replace('\\', '\\\\') + s = s.replace("\\", "\\\\") s = s.replace('"', '\\"') return f'"{s}"' @@ -359,13 +364,13 @@ class ExtraNetworksPage: item["user_metadata"] = metadata def link_preview(self, filename): - quoted_filename = urllib.parse.quote(filename.replace('\\', '/')) + quoted_filename = urllib.parse.quote(filename.replace("\\", "/")) mtime, _ = self.lister.mctime(filename) return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}" def search_terms_from_path(self, filename, possible_directories=None): abspath = os.path.abspath(filename) - for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()): + for parentdir in possible_directories if possible_directories is not None else self.allowed_directories_for_previews(): parentdir = os.path.dirname(os.path.abspath(parentdir)) if abspath.startswith(parentdir): return os.path.relpath(abspath, parentdir) @@ -413,7 +418,7 @@ class ExtraNetworksPage: # Boolean data attributes only need a key when true. if v: data_attributes_str += f"{k} " - elif v not in [None, "", "\'\'", "\"\""]: + elif v not in [None, "", "''", '""']: data_attributes_str += f"{k}={v} " res = self.tree_row_tpl.format( @@ -462,7 +467,6 @@ class ExtraNetworksPage: btn_metadata=btn_metadata, ) - def create_card_html( self, tabname: str, @@ -545,7 +549,7 @@ class ExtraNetworksPage: # Boolean data attributes only need a key when true. if v: data_attributes_str += f"{k} " - elif v not in [None, "", "\'\'", "\"\""]: + elif v not in [None, "", "''", '""']: data_attributes_str += f"{k}={v} " return self.card_tpl.format( @@ -563,10 +567,7 @@ class ExtraNetworksPage: for i, item in enumerate(self.items.values()): div_id = str(i) card_html = self.create_card_html(tabname=tabname, item=item, div_id=div_id) - sort_keys = { - k.strip().lower().replace(" ", "_"): html.escape(str(v)) - for k, v in item.get("sort_keys", {}).items() - } + sort_keys = {k.strip().lower().replace(" ", "_"): html.escape(str(v)) for k, v in item.get("sort_keys", {}).items()} search_terms = item.get("search_terms", []) self.cards[div_id] = CardListItem(div_id, card_html) self.cards[div_id].sort_keys = sort_keys @@ -698,7 +699,6 @@ class ExtraNetworksPage: } return res - def create_dirs_view_html(self, tabname: str) -> str: """Generates HTML for displaying folders.""" # Flatten each root into a single dict. Only get the directories for buttons. @@ -713,13 +713,18 @@ class ExtraNetworksPage: tree.values(), key=lambda x: shared.natural_sort_key(x.relpath), ) - dirs_html = "".join([ - self.btn_dirs_view_item_tpl.format(**{ - "extra_class": "search-all" if node.relpath == "" else "", - "tabname_full": f"{tabname}_{self.extra_networks_tabname}", - "path": html.escape(node.relpath), - }) for node in dir_nodes - ]) + dirs_html = "".join( + [ + self.btn_dirs_view_item_tpl.format( + **{ + "extra_class": "search-all" if node.relpath == "" else "", + "tabname_full": f"{tabname}_{self.extra_networks_tabname}", + "path": html.escape(node.relpath), + } + ) + for node in dir_nodes + ] + ) return dirs_html @@ -769,22 +774,24 @@ class ExtraNetworksPage: dirs_view_en = shared.opts.extra_networks_dirs_view_default_enabled tree_view_en = shared.opts.extra_networks_tree_view_default_enabled - return self.pane_tpl.format(**{ - "tabname": tabname, - "extra_networks_tabname": self.extra_networks_tabname, - "data_sort_dir": sort_dir, - "btn_sort_mode_path_data_attributes": "data-selected" if sort_mode == "path" else "", - "btn_sort_mode_name_data_attributes": "data-selected" if sort_mode == "name" else "", - "btn_sort_mode_date_created_data_attributes": "data-selected" if sort_mode == "date_created" else "", - "btn_sort_mode_date_modified_data_attributes": "data-selected" if sort_mode == "date_modified" else "", - "btn_dirs_view_data_attributes": "data-selected" if dirs_view_en else "", - "btn_tree_view_data_attributes": "data-selected" if tree_view_en else "", - "dirs_view_hidden_cls": "" if dirs_view_en else "hidden", - "tree_view_hidden_cls": "" if tree_view_en else "hidden", - "tree_view_style": f"flex-basis: {shared.opts.extra_networks_tree_view_default_width}px;", - "cards_view_style": "flex-grow: 1;", - "dirs_html": dirs_html, - }) + return self.pane_tpl.format( + **{ + "tabname": tabname, + "extra_networks_tabname": self.extra_networks_tabname, + "data_sort_dir": sort_dir, + "btn_sort_mode_path_data_attributes": "data-selected" if sort_mode == "path" else "", + "btn_sort_mode_name_data_attributes": "data-selected" if sort_mode == "name" else "", + "btn_sort_mode_date_created_data_attributes": "data-selected" if sort_mode == "date_created" else "", + "btn_sort_mode_date_modified_data_attributes": "data-selected" if sort_mode == "date_modified" else "", + "btn_dirs_view_data_attributes": "data-selected" if dirs_view_en else "", + "btn_tree_view_data_attributes": "data-selected" if tree_view_en else "", + "dirs_view_hidden_cls": "" if dirs_view_en else "hidden", + "tree_view_hidden_cls": "" if tree_view_en else "hidden", + "tree_view_style": f"flex-basis: {shared.opts.extra_networks_tree_view_default_width}px;", + "cards_view_style": "flex-grow: 1;", + "dirs_html": dirs_html, + } + ) def create_item(self, name, index=None): raise NotImplementedError() @@ -827,7 +834,11 @@ class ExtraNetworksPage: """ file = f"{path}.safetensors" - if self.lister.exists(file) and 'ssmd_cover_images' in metadata and len(list(filter(None, json.loads(metadata['ssmd_cover_images'])))) > 0: + if ( + self.lister.exists(file) + and "ssmd_cover_images" in metadata + and len(list(filter(None, json.loads(metadata["ssmd_cover_images"])))) > 0 + ): return f"./sd_extra_networks/cover-images?extra_networks_tabname={self.extra_networks_tabname}&item={name}" return None @@ -856,9 +867,13 @@ def initialize(): def register_default_pages(): - from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion - from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks - from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints + from modules.ui_extra_networks_checkpoints import \ + ExtraNetworksPageCheckpoints + from modules.ui_extra_networks_hypernets import \ + ExtraNetworksPageHypernetworks + from modules.ui_extra_networks_textual_inversion import \ + ExtraNetworksPageTextualInversion + register_page(ExtraNetworksPageTextualInversion()) register_page(ExtraNetworksPageHypernetworks()) register_page(ExtraNetworksPageCheckpoints()) @@ -919,8 +934,8 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): ui.user_metadata_editors.append(editor) related_tabs.append(tab) - ui.button_save_preview = gr.Button('Save preview', elem_id=f"{tabname}_save_preview", visible=False) - ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=f"{tabname}_preview_filename", visible=False) + ui.button_save_preview = gr.Button("Save preview", elem_id=f"{tabname}_save_preview", visible=False) + ui.preview_target_filename = gr.Textbox("Preview save filename", elem_id=f"{tabname}_preview_filename", visible=False) for tab in unrelated_tabs: tab.select( @@ -953,14 +968,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): return ui.pages_contents button_refresh = gr.Button("Refresh", elem_id=f"{tabname}_{page.extra_networks_tabname}_extra_refresh_internal", visible=False) - button_refresh.click( - fn=refresh, - inputs=[], - outputs=ui.pages, - ).then( - fn=lambda: None, - _js='setupAllResizeHandles' - ).then( + button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages,).then(fn=lambda: None, _js="setupAllResizeHandles").then( fn=lambda: None, _js=f"function(){{extraNetworksRefreshTab('{tabname}_{page.extra_networks_tabname}');}}", ) @@ -973,11 +981,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): create_html() return ui.pages_contents - interface.load( - fn=pages_html, - inputs=[], - outputs=ui.pages, - ).then( + interface.load(fn=pages_html, inputs=[], outputs=ui.pages,).then( fn=lambda: None, _js="setupAllResizeHandles", ) @@ -1014,7 +1018,7 @@ def setup_ui(ui, gallery): is_allowed = True break - assert is_allowed, f'writing to {filename} is not allowed' + assert is_allowed, f"writing to {filename} is not allowed" save_image_with_geninfo(image, geninfo, filename) @@ -1024,7 +1028,7 @@ def setup_ui(ui, gallery): fn=save_preview, _js="function(x, y, z){return [selected_gallery_index(), y, z]}", inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename], - outputs=[*ui.pages] + outputs=[*ui.pages], ) for editor in ui.user_metadata_editors: