run formatter

This commit is contained in:
Sj-Si 2024-04-15 13:17:04 -04:00
parent 634f7bc920
commit fc20d1df0d

View file

@ -1,41 +1,36 @@
import functools
import html
import json
import os.path
import re
import urllib.parse
from base64 import b64decode
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable
from dataclasses import dataclass
import re
from starlette.responses import Response, FileResponse, JSONResponse
from typing import Callable, Optional
from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks, util
from modules.images import read_info_from_image, save_image_with_geninfo
import gradio as gr
import json
import html
from fastapi.exceptions import HTTPException
from PIL import Image
from starlette.responses import FileResponse, JSONResponse, Response
from modules import (errors, extra_networks, shared,
ui_extra_networks_user_metadata, util)
from modules.images import read_info_from_image, save_image_with_geninfo
from modules.infotext_utils import image_from_url_text
extra_pages = []
allowed_dirs = set()
default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"]
@functools.cache
def allowed_preview_extensions_with_extra(extra_extensions=None):
return set(default_allowed_preview_extensions) | set(extra_extensions or [])
def allowed_preview_extensions():
return allowed_preview_extensions_with_extra((shared.opts.samples_format, ))
@dataclass
class ExtraNetworksItem:
"""Wrapper for dictionaries representing ExtraNetworks items."""
item: dict
return allowed_preview_extensions_with_extra((shared.opts.samples_format,))
class ListItem:
@ -44,6 +39,7 @@ class ListItem:
id [str]: The ID of this list item.
html [str]: The HTML string for this item.
"""
def __init__(self, _id: str, _html: str) -> None:
self.id = _id
self.html = _html
@ -56,6 +52,7 @@ class CardListItem(ListItem):
sort_keys [dict]: Nested dict where keys are sort modes and values are sort keys.
search_terms [str]: String containing multiple search terms joined with spaces.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@ -70,6 +67,7 @@ class TreeListItem(ListItem):
visible [bool]: Whether the item should be shown in the list.
expanded [bool]: Whether the item children should be shown.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@ -120,7 +118,6 @@ class DirectoryTreeNode:
else:
self.item = items.get(self.abspath, None)
def flatten(self, res: dict, dirs_only: bool = False) -> None:
"""Flattens the keys/values of the tree nodes into a dictionary.
@ -147,6 +144,7 @@ class DirectoryTreeNode:
for child in self.children:
child.apply(fn)
def register_page(page):
"""registers extra networks page for the UI
@ -156,6 +154,7 @@ def register_page(page):
allowed_dirs.clear()
allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
def get_page_by_name(extra_networks_tabname: str = "") -> "ExtraNetworksPage":
for page in extra_pages:
if page.extra_networks_tabname == extra_networks_tabname:
@ -186,7 +185,7 @@ def fetch_cover_images(extra_networks_tabname: str = "", item: str = "", index:
if metadata is None:
raise HTTPException(status_code=404, detail="File not found")
cover_images = json.loads(metadata.get('ssmd_cover_images', {}))
cover_images = json.loads(metadata.get("ssmd_cover_images", {}))
image = cover_images[index] if index < len(cover_images) else None
if not image:
raise HTTPException(status_code=404, detail="File not found")
@ -199,6 +198,7 @@ def fetch_cover_images(extra_networks_tabname: str = "", item: str = "", index:
except Exception as err:
raise ValueError(f"File cannot be fetched: {item}. Failed to load cover image.") from err
def init_tree_data(tabname: str = "", extra_networks_tabname: str = "") -> JSONResponse:
page = get_page_by_name(extra_networks_tabname)
@ -209,6 +209,7 @@ def init_tree_data(tabname: str = "", extra_networks_tabname: str = "") -> JSONR
return JSONResponse(data, status_code=200)
def fetch_tree_data(
extra_networks_tabname: str = "",
div_ids: str = "",
@ -247,6 +248,7 @@ def init_cards_data(tabname: str = "", extra_networks_tabname: str = "") -> JSON
return JSONResponse(data, status_code=200)
def page_is_ready(extra_networks_tabname: str = "") -> JSONResponse:
page = get_page_by_name(extra_networks_tabname)
@ -259,6 +261,7 @@ def page_is_ready(extra_networks_tabname: str = "") -> JSONResponse:
except Exception as exc:
return JSONResponse({"error": str(exc)}, status_code=500)
def get_metadata(extra_networks_tabname: str = "", item: str = "") -> JSONResponse:
try:
page = get_page_by_name(extra_networks_tabname)
@ -271,7 +274,7 @@ def get_metadata(extra_networks_tabname: str = "", item: str = "") -> JSONRespon
# those are cover images, and they are too big to display in UI as text
# FIXME: WHY WAS THIS HERE?
#metadata = {i: metadata[i] for i in metadata if i != 'ssmd_cover_images'}
# metadata = {i: metadata[i] for i in metadata if i != 'ssmd_cover_images'}
return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)})
@ -294,6 +297,7 @@ def get_single_card(tabname: str = "", extra_networks_tabname: str = "", name: s
return JSONResponse({"html": item_html})
def add_pages_to_demo(app):
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
app.add_api_route("/sd_extra_networks/cover-images", fetch_cover_images, methods=["GET"])
@ -305,8 +309,9 @@ def add_pages_to_demo(app):
app.add_api_route("/sd_extra_networks/fetch-cards-data", fetch_cards_data, methods=["GET"])
app.add_api_route("/sd_extra_networks/page-is-ready", page_is_ready, methods=["GET"])
def quote_js(s):
s = s.replace('\\', '\\\\')
s = s.replace("\\", "\\\\")
s = s.replace('"', '\\"')
return f'"{s}"'
@ -359,13 +364,13 @@ class ExtraNetworksPage:
item["user_metadata"] = metadata
def link_preview(self, filename):
quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
quoted_filename = urllib.parse.quote(filename.replace("\\", "/"))
mtime, _ = self.lister.mctime(filename)
return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
def search_terms_from_path(self, filename, possible_directories=None):
abspath = os.path.abspath(filename)
for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
for parentdir in possible_directories if possible_directories is not None else self.allowed_directories_for_previews():
parentdir = os.path.dirname(os.path.abspath(parentdir))
if abspath.startswith(parentdir):
return os.path.relpath(abspath, parentdir)
@ -413,7 +418,7 @@ class ExtraNetworksPage:
# Boolean data attributes only need a key when true.
if v:
data_attributes_str += f"{k} "
elif v not in [None, "", "\'\'", "\"\""]:
elif v not in [None, "", "''", '""']:
data_attributes_str += f"{k}={v} "
res = self.tree_row_tpl.format(
@ -462,7 +467,6 @@ class ExtraNetworksPage:
btn_metadata=btn_metadata,
)
def create_card_html(
self,
tabname: str,
@ -545,7 +549,7 @@ class ExtraNetworksPage:
# Boolean data attributes only need a key when true.
if v:
data_attributes_str += f"{k} "
elif v not in [None, "", "\'\'", "\"\""]:
elif v not in [None, "", "''", '""']:
data_attributes_str += f"{k}={v} "
return self.card_tpl.format(
@ -563,10 +567,7 @@ class ExtraNetworksPage:
for i, item in enumerate(self.items.values()):
div_id = str(i)
card_html = self.create_card_html(tabname=tabname, item=item, div_id=div_id)
sort_keys = {
k.strip().lower().replace(" ", "_"): html.escape(str(v))
for k, v in item.get("sort_keys", {}).items()
}
sort_keys = {k.strip().lower().replace(" ", "_"): html.escape(str(v)) for k, v in item.get("sort_keys", {}).items()}
search_terms = item.get("search_terms", [])
self.cards[div_id] = CardListItem(div_id, card_html)
self.cards[div_id].sort_keys = sort_keys
@ -698,7 +699,6 @@ class ExtraNetworksPage:
}
return res
def create_dirs_view_html(self, tabname: str) -> str:
"""Generates HTML for displaying folders."""
# Flatten each root into a single dict. Only get the directories for buttons.
@ -713,13 +713,18 @@ class ExtraNetworksPage:
tree.values(),
key=lambda x: shared.natural_sort_key(x.relpath),
)
dirs_html = "".join([
self.btn_dirs_view_item_tpl.format(**{
"extra_class": "search-all" if node.relpath == "" else "",
"tabname_full": f"{tabname}_{self.extra_networks_tabname}",
"path": html.escape(node.relpath),
}) for node in dir_nodes
])
dirs_html = "".join(
[
self.btn_dirs_view_item_tpl.format(
**{
"extra_class": "search-all" if node.relpath == "" else "",
"tabname_full": f"{tabname}_{self.extra_networks_tabname}",
"path": html.escape(node.relpath),
}
)
for node in dir_nodes
]
)
return dirs_html
@ -769,22 +774,24 @@ class ExtraNetworksPage:
dirs_view_en = shared.opts.extra_networks_dirs_view_default_enabled
tree_view_en = shared.opts.extra_networks_tree_view_default_enabled
return self.pane_tpl.format(**{
"tabname": tabname,
"extra_networks_tabname": self.extra_networks_tabname,
"data_sort_dir": sort_dir,
"btn_sort_mode_path_data_attributes": "data-selected" if sort_mode == "path" else "",
"btn_sort_mode_name_data_attributes": "data-selected" if sort_mode == "name" else "",
"btn_sort_mode_date_created_data_attributes": "data-selected" if sort_mode == "date_created" else "",
"btn_sort_mode_date_modified_data_attributes": "data-selected" if sort_mode == "date_modified" else "",
"btn_dirs_view_data_attributes": "data-selected" if dirs_view_en else "",
"btn_tree_view_data_attributes": "data-selected" if tree_view_en else "",
"dirs_view_hidden_cls": "" if dirs_view_en else "hidden",
"tree_view_hidden_cls": "" if tree_view_en else "hidden",
"tree_view_style": f"flex-basis: {shared.opts.extra_networks_tree_view_default_width}px;",
"cards_view_style": "flex-grow: 1;",
"dirs_html": dirs_html,
})
return self.pane_tpl.format(
**{
"tabname": tabname,
"extra_networks_tabname": self.extra_networks_tabname,
"data_sort_dir": sort_dir,
"btn_sort_mode_path_data_attributes": "data-selected" if sort_mode == "path" else "",
"btn_sort_mode_name_data_attributes": "data-selected" if sort_mode == "name" else "",
"btn_sort_mode_date_created_data_attributes": "data-selected" if sort_mode == "date_created" else "",
"btn_sort_mode_date_modified_data_attributes": "data-selected" if sort_mode == "date_modified" else "",
"btn_dirs_view_data_attributes": "data-selected" if dirs_view_en else "",
"btn_tree_view_data_attributes": "data-selected" if tree_view_en else "",
"dirs_view_hidden_cls": "" if dirs_view_en else "hidden",
"tree_view_hidden_cls": "" if tree_view_en else "hidden",
"tree_view_style": f"flex-basis: {shared.opts.extra_networks_tree_view_default_width}px;",
"cards_view_style": "flex-grow: 1;",
"dirs_html": dirs_html,
}
)
def create_item(self, name, index=None):
raise NotImplementedError()
@ -827,7 +834,11 @@ class ExtraNetworksPage:
"""
file = f"{path}.safetensors"
if self.lister.exists(file) and 'ssmd_cover_images' in metadata and len(list(filter(None, json.loads(metadata['ssmd_cover_images'])))) > 0:
if (
self.lister.exists(file)
and "ssmd_cover_images" in metadata
and len(list(filter(None, json.loads(metadata["ssmd_cover_images"])))) > 0
):
return f"./sd_extra_networks/cover-images?extra_networks_tabname={self.extra_networks_tabname}&item={name}"
return None
@ -856,9 +867,13 @@ def initialize():
def register_default_pages():
from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints
from modules.ui_extra_networks_checkpoints import \
ExtraNetworksPageCheckpoints
from modules.ui_extra_networks_hypernets import \
ExtraNetworksPageHypernetworks
from modules.ui_extra_networks_textual_inversion import \
ExtraNetworksPageTextualInversion
register_page(ExtraNetworksPageTextualInversion())
register_page(ExtraNetworksPageHypernetworks())
register_page(ExtraNetworksPageCheckpoints())
@ -919,8 +934,8 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
ui.user_metadata_editors.append(editor)
related_tabs.append(tab)
ui.button_save_preview = gr.Button('Save preview', elem_id=f"{tabname}_save_preview", visible=False)
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=f"{tabname}_preview_filename", visible=False)
ui.button_save_preview = gr.Button("Save preview", elem_id=f"{tabname}_save_preview", visible=False)
ui.preview_target_filename = gr.Textbox("Preview save filename", elem_id=f"{tabname}_preview_filename", visible=False)
for tab in unrelated_tabs:
tab.select(
@ -953,14 +968,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
return ui.pages_contents
button_refresh = gr.Button("Refresh", elem_id=f"{tabname}_{page.extra_networks_tabname}_extra_refresh_internal", visible=False)
button_refresh.click(
fn=refresh,
inputs=[],
outputs=ui.pages,
).then(
fn=lambda: None,
_js='setupAllResizeHandles'
).then(
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages,).then(fn=lambda: None, _js="setupAllResizeHandles").then(
fn=lambda: None,
_js=f"function(){{extraNetworksRefreshTab('{tabname}_{page.extra_networks_tabname}');}}",
)
@ -973,11 +981,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
create_html()
return ui.pages_contents
interface.load(
fn=pages_html,
inputs=[],
outputs=ui.pages,
).then(
interface.load(fn=pages_html, inputs=[], outputs=ui.pages,).then(
fn=lambda: None,
_js="setupAllResizeHandles",
)
@ -1014,7 +1018,7 @@ def setup_ui(ui, gallery):
is_allowed = True
break
assert is_allowed, f'writing to {filename} is not allowed'
assert is_allowed, f"writing to {filename} is not allowed"
save_image_with_geninfo(image, geninfo, filename)
@ -1024,7 +1028,7 @@ def setup_ui(ui, gallery):
fn=save_preview,
_js="function(x, y, z){return [selected_gallery_index(), y, z]}",
inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
outputs=[*ui.pages]
outputs=[*ui.pages],
)
for editor in ui.user_metadata_editors: