diff --git a/.github/workflows/on_pull_request.yaml b/.github/workflows/on_pull_request.yaml index 9e44c806a..9326c6a45 100644 --- a/.github/workflows/on_pull_request.yaml +++ b/.github/workflows/on_pull_request.yaml @@ -11,8 +11,8 @@ jobs: if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name steps: - name: Checkout Code - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: python-version: 3.11 # NB: there's no cache: pip here since we're not installing anything @@ -20,7 +20,7 @@ jobs: # not to have GHA download an (at the time of writing) 4 GB cache # of PyTorch and other dependencies. - name: Install Ruff - run: pip install ruff==0.1.6 + run: pip install ruff==0.3.3 - name: Run Ruff run: ruff . lint-js: @@ -29,9 +29,9 @@ jobs: if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name steps: - name: Checkout Code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Node.js - uses: actions/setup-node@v3 + uses: actions/setup-node@v4 with: node-version: 18 - run: npm i --ci diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index f42e4758e..0610f4f54 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -11,9 +11,9 @@ jobs: if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name steps: - name: Checkout Code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.10.6 cache: pip @@ -22,7 +22,7 @@ jobs: launch.py - name: Cache models id: cache-models - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: models key: "2023-12-30" @@ -68,13 +68,13 @@ jobs: python -m coverage report -i python -m coverage html -i - name: Upload main app output - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 if: always() with: name: output path: output.txt - name: Upload coverage HTML - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 if: always() with: name: htmlcov diff --git a/.gitignore b/.gitignore index 6790e9ee7..519b4a53d 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,4 @@ notification.mp3 /package-lock.json /.coverage* /test/test_outputs +/cache diff --git a/README.md b/README.md index f4cfcf290..bc08e7ad1 100644 --- a/README.md +++ b/README.md @@ -98,6 +98,7 @@ Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-di - [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) - [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. - [Intel CPUs, Intel GPUs (both integrated and discrete)](https://github.com/openvinotoolkit/stable-diffusion-webui/wiki/Installation-on-Intel-Silicon) (external wiki page) +- [Ascend NPUs](https://github.com/wangshuai09/stable-diffusion-webui/wiki/Install-and-run-on-Ascend-NPUs) (external wiki page) Alternatively, use online services (like Google Colab): diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 183f8bd7c..20f8df3d4 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -29,7 +29,6 @@ class NetworkOnDisk: def read_metadata(): metadata = sd_models.read_metadata_from_safetensors(filename) - metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text return metadata @@ -153,7 +152,7 @@ class NetworkModule: self.scale = weights.w["scale"].item() if "scale" in weights.w else None self.dora_scale = weights.w.get("dora_scale", None) - self.dora_mean_dim = tuple(i for i in range(len(self.shape)) if i != 1) + self.dora_norm_dims = len(self.shape) - 1 def multiplier(self): if 'transformer' in self.sd_key[:20]: @@ -170,10 +169,22 @@ class NetworkModule: return 1.0 def apply_weight_decompose(self, updown, orig_weight): - orig_weight = orig_weight.to(updown) + # Match the device/dtype + orig_weight = orig_weight.to(updown.dtype) + dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype) + updown = updown.to(orig_weight.device) + merged_scale1 = updown + orig_weight + merged_scale1_norm = ( + merged_scale1.transpose(0, 1) + .reshape(merged_scale1.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims) + .transpose(0, 1) + ) + dora_merged = ( - merged_scale1 / merged_scale1(dim=self.dora_mean_dim, keepdim=True) * self.dora_scale + merged_scale1 * (dora_scale / merged_scale1_norm) ) final_updown = dora_merged - orig_weight return final_updown diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 7821a8a7d..1c515ebb7 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -36,13 +36,6 @@ class NetworkModuleOFT(network.NetworkModule): # self.alpha is unused self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) - # LyCORIS BOFT - if self.oft_blocks.dim() == 4: - self.is_boft = True - self.rescale = weights.w.get('rescale', None) - if self.rescale is not None: - self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1)) - is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] is_conv = type(self.sd_module) in [torch.nn.Conv2d] is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported @@ -54,6 +47,13 @@ class NetworkModuleOFT(network.NetworkModule): elif is_other_linear: self.out_dim = self.sd_module.embed_dim + # LyCORIS BOFT + if self.oft_blocks.dim() == 4: + self.is_boft = True + self.rescale = weights.w.get('rescale', None) + if self.rescale is not None and not is_other_linear: + self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1)) + self.num_blocks = self.dim self.block_size = self.out_dim // self.dim self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index 3160aecfa..7a07a544e 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -149,6 +149,8 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) v = random.random() * max_count if count > v: + for x in "({[]})": + tag = tag.replace(x, '\\' + x) res.append(tag) return ", ".join(sorted(res)) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 66b7cc06a..5f2b5dc05 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -31,7 +31,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "name": name, "filename": lora_on_disk.filename, "shorthash": lora_on_disk.shorthash, - "preview": self.find_preview(path), + "preview": self.find_preview(path) or self.find_embedded_preview(path, name, lora_on_disk.metadata), "description": self.find_description(path), "search_terms": search_terms, "local_preview": f"{path}.{shared.opts.samples_format}", diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js index 64e7a638a..7807f7f61 100644 --- a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js +++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js @@ -29,6 +29,7 @@ onUiLoaded(async() => { }); function getActiveTab(elements, all = false) { + if (!elements.img2imgTabs) return null; const tabs = elements.img2imgTabs.querySelectorAll("button"); if (all) return tabs; @@ -43,6 +44,7 @@ onUiLoaded(async() => { // Get tab ID function getTabId(elements) { const activeTab = getActiveTab(elements); + if (!activeTab) return null; return tabNameToElementId[activeTab.innerText]; } @@ -252,6 +254,7 @@ onUiLoaded(async() => { let isMoving = false; let mouseX, mouseY; let activeElement; + let interactedWithAltKey = false; const elements = Object.fromEntries( Object.keys(elementIDs).map(id => [ @@ -277,7 +280,7 @@ onUiLoaded(async() => { const targetElement = gradioApp().querySelector(elemId); if (!targetElement) { - console.log("Element not found"); + console.log("Element not found", elemId); return; } @@ -365,9 +368,9 @@ onUiLoaded(async() => { // In the course of research, it was found that the tag img is very harmful when zooming and creates white canvases. This hack allows you to almost never think about this problem, it has no effect on webui. function fixCanvas() { - const activeTab = getActiveTab(elements).textContent.trim(); + const activeTab = getActiveTab(elements)?.textContent.trim(); - if (activeTab !== "img2img") { + if (activeTab && activeTab !== "img2img") { const img = targetElement.querySelector(`${elemId} img`); if (img && img.style.display !== "none") { @@ -508,6 +511,10 @@ onUiLoaded(async() => { if (isModifierKey(e, hotkeysConfig.canvas_hotkey_zoom)) { e.preventDefault(); + if (hotkeysConfig.canvas_hotkey_zoom === "Alt") { + interactedWithAltKey = true; + } + let zoomPosX, zoomPosY; let delta = 0.2; if (elemData[elemId].zoomLevel > 7) { @@ -783,23 +790,29 @@ onUiLoaded(async() => { targetElement.addEventListener("mouseleave", handleMouseLeave); // Reset zoom when click on another tab - elements.img2imgTabs.addEventListener("click", resetZoom); - elements.img2imgTabs.addEventListener("click", () => { - // targetElement.style.width = ""; - if (parseInt(targetElement.style.width) > 865) { - setTimeout(fitToElement, 0); - } - }); + if (elements.img2imgTabs) { + elements.img2imgTabs.addEventListener("click", resetZoom); + elements.img2imgTabs.addEventListener("click", () => { + // targetElement.style.width = ""; + if (parseInt(targetElement.style.width) > 865) { + setTimeout(fitToElement, 0); + } + }); + } targetElement.addEventListener("wheel", e => { // change zoom level - const operation = e.deltaY > 0 ? "-" : "+"; + const operation = (e.deltaY || -e.wheelDelta) > 0 ? "-" : "+"; changeZoomLevel(operation, e); // Handle brush size adjustment with ctrl key pressed if (isModifierKey(e, hotkeysConfig.canvas_hotkey_adjust)) { e.preventDefault(); + if (hotkeysConfig.canvas_hotkey_adjust === "Alt") { + interactedWithAltKey = true; + } + // Increase or decrease brush size based on scroll direction adjustBrushSize(elemId, e.deltaY); } @@ -839,6 +852,20 @@ onUiLoaded(async() => { document.addEventListener("keydown", handleMoveKeyDown); document.addEventListener("keyup", handleMoveKeyUp); + + // Prevent firefox from opening main menu when alt is used as a hotkey for zoom or brush size + function handleAltKeyUp(e) { + if (e.key !== "Alt" || !interactedWithAltKey) { + return; + } + + e.preventDefault(); + interactedWithAltKey = false; + } + + document.addEventListener("keyup", handleAltKeyUp); + + // Detect zoom level and update the pan speed. function updatePanPosition(movementX, movementY) { let panSpeed = 2; diff --git a/scripts/processing_autosized_crop.py b/extensions-builtin/postprocessing-for-training/scripts/postprocessing_autosized_crop.py similarity index 100% rename from scripts/processing_autosized_crop.py rename to extensions-builtin/postprocessing-for-training/scripts/postprocessing_autosized_crop.py diff --git a/scripts/postprocessing_caption.py b/extensions-builtin/postprocessing-for-training/scripts/postprocessing_caption.py similarity index 100% rename from scripts/postprocessing_caption.py rename to extensions-builtin/postprocessing-for-training/scripts/postprocessing_caption.py diff --git a/scripts/postprocessing_create_flipped_copies.py b/extensions-builtin/postprocessing-for-training/scripts/postprocessing_create_flipped_copies.py similarity index 100% rename from scripts/postprocessing_create_flipped_copies.py rename to extensions-builtin/postprocessing-for-training/scripts/postprocessing_create_flipped_copies.py diff --git a/scripts/postprocessing_focal_crop.py b/extensions-builtin/postprocessing-for-training/scripts/postprocessing_focal_crop.py similarity index 100% rename from scripts/postprocessing_focal_crop.py rename to extensions-builtin/postprocessing-for-training/scripts/postprocessing_focal_crop.py diff --git a/scripts/postprocessing_split_oversized.py b/extensions-builtin/postprocessing-for-training/scripts/postprocessing_split_oversized.py similarity index 100% rename from scripts/postprocessing_split_oversized.py rename to extensions-builtin/postprocessing-for-training/scripts/postprocessing_split_oversized.py diff --git a/html/extra-networks-edit-item-button.html b/html/extra-networks-edit-item-button.html index 0fe43082a..fd728600f 100644 --- a/html/extra-networks-edit-item-button.html +++ b/html/extra-networks-edit-item-button.html @@ -1,4 +1,4 @@
\ No newline at end of file diff --git a/html/extra-networks-metadata-button.html b/html/extra-networks-metadata-button.html index 285b5b3b6..4ef013bc0 100644 --- a/html/extra-networks-metadata-button.html +++ b/html/extra-networks-metadata-button.html @@ -1,4 +1,4 @@ \ No newline at end of file diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js index d680daf52..0c0183564 100644 --- a/javascript/dragdrop.js +++ b/javascript/dragdrop.js @@ -74,22 +74,39 @@ window.document.addEventListener('dragover', e => { e.dataTransfer.dropEffect = 'copy'; }); -window.document.addEventListener('drop', e => { +window.document.addEventListener('drop', async e => { const target = e.composedPath()[0]; - if (!eventHasFiles(e)) return; + const url = e.dataTransfer.getData('text/uri-list') || e.dataTransfer.getData('text/plain'); + if (!eventHasFiles(e) && !url) return; if (dragDropTargetIsPrompt(target)) { e.stopPropagation(); e.preventDefault(); - let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image"; + const isImg2img = get_tab_index('tabs') == 1; + let prompt_image_target = isImg2img ? "img2img_prompt_image" : "txt2img_prompt_image"; - const imgParent = gradioApp().getElementById(prompt_target); + const imgParent = gradioApp().getElementById(prompt_image_target); const files = e.dataTransfer.files; const fileInput = imgParent.querySelector('input[type="file"]'); - if (fileInput) { + if (eventHasFiles(e) && fileInput) { fileInput.files = files; fileInput.dispatchEvent(new Event('change')); + } else if (url) { + try { + const request = await fetch(url); + if (!request.ok) { + console.error('Error fetching URL:', url, request.status); + return; + } + const data = new DataTransfer(); + data.items.add(new File([await request.blob()], 'image.png')); + fileInput.files = data.files; + fileInput.dispatchEvent(new Event('change')); + } catch (error) { + console.error('Error fetching URL:', url, error); + return; + } } } diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 4b6559822..c210319b0 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -845,11 +845,13 @@ function extraNetworksCopyPathToClipboard(event, path) { event.stopPropagation(); } -function extraNetworksRequestMetadata(event, extraPage, cardName) { +function extraNetworksRequestMetadata(event, extraPage) { var showError = function() { extraNetworksShowMetadata("there was an error getting metadata"); }; + var cardName = event.target.parentElement.parentElement.getAttribute("data-name"); + requestGet("./sd_extra_networks/metadata", {page: extraPage, item: cardName}, function(data) { if (data && data.metadata) { extraNetworksShowMetadata(data.metadata); @@ -873,6 +875,7 @@ function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) { extraPageUserMetadataEditors[id] = editor; } + var cardName = event.target.parentElement.parentElement.getAttribute("data-name"); editor.nameTextarea.value = cardName; updateInput(editor.nameTextarea); diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 625c5d148..d4d4f016d 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -131,19 +131,15 @@ function setupImageForLightbox(e) { e.style.cursor = 'pointer'; e.style.userSelect = 'none'; - var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1; - - // For Firefox, listening on click first switched to next image then shows the lightbox. - // If you know how to fix this without switching to mousedown event, please. - // For other browsers the event is click to make it possiblr to drag picture. - var event = isFirefox ? 'mousedown' : 'click'; - - e.addEventListener(event, function(evt) { + e.addEventListener('mousedown', function(evt) { if (evt.button == 1) { open(evt.target.src); evt.preventDefault(); return; } + }, true); + + e.addEventListener('click', function(evt) { if (!opts.js_modal_lightbox || evt.button != 0) return; modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed); diff --git a/javascript/ui.js b/javascript/ui.js index 1eef6d337..e0f5feebd 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -136,8 +136,7 @@ function showSubmitInterruptingPlaceholder(tabname) { function showRestoreProgressButton(tabname, show) { var button = gradioApp().getElementById(tabname + "_restore_progress"); if (!button) return; - - button.style.display = show ? "flex" : "none"; + button.style.setProperty('display', show ? 'flex' : 'none', 'important'); } function submit() { @@ -209,6 +208,7 @@ function restoreProgressTxt2img() { var id = localGet("txt2img_task_id"); if (id) { + showSubmitInterruptingPlaceholder('txt2img'); requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() { showSubmitButtons('txt2img', true); }, null, 0); @@ -223,6 +223,7 @@ function restoreProgressImg2img() { var id = localGet("img2img_task_id"); if (id) { + showSubmitInterruptingPlaceholder('img2img'); requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() { showSubmitButtons('img2img', true); }, null, 0); diff --git a/modules/cache.py b/modules/cache.py index a9822a0eb..f4e5f702b 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -2,48 +2,55 @@ import json import os import os.path import threading -import time + +import diskcache +import tqdm from modules.paths import data_path, script_path cache_filename = os.environ.get('SD_WEBUI_CACHE_FILE', os.path.join(data_path, "cache.json")) -cache_data = None +cache_dir = os.environ.get('SD_WEBUI_CACHE_DIR', os.path.join(data_path, "cache")) +caches = {} cache_lock = threading.Lock() -dump_cache_after = None -dump_cache_thread = None - def dump_cache(): - """ - Marks cache for writing to disk. 5 seconds after no one else flags the cache for writing, it is written. - """ + """old function for dumping cache to disk; does nothing since diskcache.""" - global dump_cache_after - global dump_cache_thread + pass - def thread_func(): - global dump_cache_after - global dump_cache_thread - while dump_cache_after is not None and time.time() < dump_cache_after: - time.sleep(1) +def make_cache(subsection: str) -> diskcache.Cache: + return diskcache.Cache( + os.path.join(cache_dir, subsection), + size_limit=2**32, # 4 GB, culling oldest first + disk_min_file_size=2**18, # keep up to 256KB in Sqlite + ) - with cache_lock: - cache_filename_tmp = cache_filename + "-" - with open(cache_filename_tmp, "w", encoding="utf8") as file: - json.dump(cache_data, file, indent=4, ensure_ascii=False) - os.replace(cache_filename_tmp, cache_filename) +def convert_old_cached_data(): + try: + with open(cache_filename, "r", encoding="utf8") as file: + data = json.load(file) + except FileNotFoundError: + return + except Exception: + os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) + print('[ERROR] issue occurred while trying to read cache.json; old cache has been moved to tmp/cache.json') + return - dump_cache_after = None - dump_cache_thread = None + total_count = sum(len(keyvalues) for keyvalues in data.values()) - with cache_lock: - dump_cache_after = time.time() + 5 - if dump_cache_thread is None: - dump_cache_thread = threading.Thread(name='cache-writer', target=thread_func) - dump_cache_thread.start() + with tqdm.tqdm(total=total_count, desc="converting cache") as progress: + for subsection, keyvalues in data.items(): + cache_obj = caches.get(subsection) + if cache_obj is None: + cache_obj = make_cache(subsection) + caches[subsection] = cache_obj + + for key, value in keyvalues.items(): + cache_obj[key] = value + progress.update(1) def cache(subsection): @@ -54,28 +61,21 @@ def cache(subsection): subsection (str): The subsection identifier for the cache. Returns: - dict: The cache data for the specified subsection. + diskcache.Cache: The cache data for the specified subsection. """ - global cache_data - - if cache_data is None: + cache_obj = caches.get(subsection) + if not cache_obj: with cache_lock: - if cache_data is None: - try: - with open(cache_filename, "r", encoding="utf8") as file: - cache_data = json.load(file) - except FileNotFoundError: - cache_data = {} - except Exception: - os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) - print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') - cache_data = {} + if not os.path.exists(cache_dir) and os.path.isfile(cache_filename): + convert_old_cached_data() - s = cache_data.get(subsection, {}) - cache_data[subsection] = s + cache_obj = caches.get(subsection) + if not cache_obj: + cache_obj = make_cache(subsection) + caches[subsection] = cache_obj - return s + return cache_obj def cached_data_for_file(subsection, title, filename, func): diff --git a/modules/extensions.py b/modules/extensions.py index 04bda297e..5ad934b4d 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,6 +1,7 @@ from __future__ import annotations import configparser +import dataclasses import os import threading import re @@ -9,6 +10,10 @@ from modules import shared, errors, cache, scripts from modules.gitpython_hack import Repo from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 +extensions: list[Extension] = [] +extension_paths: dict[str, Extension] = {} +loaded_extensions: dict[str, Exception] = {} + os.makedirs(extensions_dir, exist_ok=True) @@ -22,6 +27,13 @@ def active(): return [x for x in extensions if x.enabled] +@dataclasses.dataclass +class CallbackOrderInfo: + name: str + before: list + after: list + + class ExtensionMetadata: filename = "metadata.ini" config: configparser.ConfigParser @@ -42,7 +54,7 @@ class ExtensionMetadata: self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name) self.canonical_name = canonical_name.lower().strip() - self.requires = self.get_script_requirements("Requires", "Extension") + self.requires = None def get_script_requirements(self, field, section, extra_section=None): """reads a list of requirements from the config; field is the name of the field in the ini file, @@ -54,7 +66,15 @@ class ExtensionMetadata: if extra_section: x = x + ', ' + self.config.get(extra_section, field, fallback='') - return self.parse_list(x.lower()) + listed_requirements = self.parse_list(x.lower()) + res = [] + + for requirement in listed_requirements: + loaded_requirements = (x for x in requirement.split("|") if x in loaded_extensions) + relevant_requirement = next(loaded_requirements, requirement) + res.append(relevant_requirement) + + return res def parse_list(self, text): """converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])""" @@ -65,6 +85,22 @@ class ExtensionMetadata: # both "," and " " are accepted as separator return [x for x in re.split(r"[,\s]+", text.strip()) if x] + def list_callback_order_instructions(self): + for section in self.config.sections(): + if not section.startswith("callbacks/"): + continue + + callback_name = section[10:] + + if not callback_name.startswith(self.canonical_name): + errors.report(f"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}") + continue + + before = self.parse_list(self.config.get(section, 'Before', fallback='')) + after = self.parse_list(self.config.get(section, 'After', fallback='')) + + yield CallbackOrderInfo(callback_name, before, after) + class Extension: lock = threading.Lock() @@ -156,6 +192,8 @@ class Extension: def check_updates(self): repo = Repo(self.path) for fetch in repo.remote().fetch(dry_run=True): + if self.branch and fetch.name != f'{repo.remote().name}/{self.branch}': + continue if fetch.flags != fetch.HEAD_UPTODATE: self.can_update = True self.status = "new commits" @@ -186,6 +224,8 @@ class Extension: def list_extensions(): extensions.clear() + extension_paths.clear() + loaded_extensions.clear() if shared.cmd_opts.disable_all_extensions: print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") @@ -196,7 +236,6 @@ def list_extensions(): elif shared.opts.disable_all_extensions == "extra": print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") - loaded_extensions = {} # scan through extensions directory and load metadata for dirname in [extensions_builtin_dir, extensions_dir]: @@ -220,8 +259,12 @@ def list_extensions(): is_builtin = dirname == extensions_builtin_dir extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata) extensions.append(extension) + extension_paths[extension.path] = extension loaded_extensions[canonical_name] = extension + for extension in extensions: + extension.metadata.requires = extension.metadata.get_script_requirements("Requires", "Extension") + # check for requirements for extension in extensions: if not extension.enabled: @@ -238,4 +281,16 @@ def list_extensions(): continue -extensions: list[Extension] = [] +def find_extension(filename): + parentdir = os.path.dirname(os.path.realpath(filename)) + + while parentdir != filename: + extension = extension_paths.get(parentdir) + if extension is not None: + return extension + + filename = parentdir + parentdir = os.path.dirname(filename) + + return None + diff --git a/modules/img2img.py b/modules/img2img.py index e7fb3ea3c..a1d042c21 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -146,7 +146,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal return batch_results -def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): +def img2img(id_task: str, request: gr.Request, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, *args): override_settings = create_override_settings_dict(override_settings_texts) is_batch = mode == 5 @@ -193,10 +193,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s prompt=prompt, negative_prompt=negative_prompt, styles=prompt_styles, - sampler_name=sampler_name, batch_size=batch_size, n_iter=n_iter, - steps=steps, cfg_scale=cfg_scale, width=width, height=height, diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index a1cbfb17d..1c91d076d 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -265,17 +265,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model else: prompt += ("" if prompt == "" else "\n") + line - if shared.opts.infotext_styles != "Ignore": - found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt) - - if shared.opts.infotext_styles == "Apply": - res["Styles array"] = found_styles - elif shared.opts.infotext_styles == "Apply if any" and found_styles: - res["Styles array"] = found_styles - - res["Prompt"] = prompt - res["Negative prompt"] = negative_prompt - for k, v in re_param.findall(lastline): try: if v[0] == '"' and v[-1] == '"': @@ -290,6 +279,26 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model except Exception: print(f"Error parsing \"{k}: {v}\"") + # Extract styles from prompt + if shared.opts.infotext_styles != "Ignore": + found_styles, prompt_no_styles, negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt) + + same_hr_styles = True + if ("Hires prompt" in res or "Hires negative prompt" in res) and (infotext_ver > infotext_versions.v180_hr_styles if (infotext_ver := infotext_versions.parse_version(res.get("Version"))) else True): + hr_prompt, hr_negative_prompt = res.get("Hires prompt", prompt), res.get("Hires negative prompt", negative_prompt) + hr_found_styles, hr_prompt_no_styles, hr_negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(hr_prompt, hr_negative_prompt) + if same_hr_styles := found_styles == hr_found_styles: + res["Hires prompt"] = '' if hr_prompt_no_styles == prompt_no_styles else hr_prompt_no_styles + res['Hires negative prompt'] = '' if hr_negative_prompt_no_styles == negative_prompt_no_styles else hr_negative_prompt_no_styles + + if same_hr_styles: + prompt, negative_prompt = prompt_no_styles, negative_prompt_no_styles + if (shared.opts.infotext_styles == "Apply if any" and found_styles) or shared.opts.infotext_styles == "Apply": + res['Styles array'] = found_styles + + res["Prompt"] = prompt + res["Negative prompt"] = negative_prompt + # Missing CLIP skip means it was set to 1 (the default) if "Clip skip" not in res: res["Clip skip"] = "1" @@ -305,6 +314,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "Hires sampler" not in res: res["Hires sampler"] = "Use same sampler" + if "Hires schedule type" not in res: + res["Hires schedule type"] = "Use same scheduler" + if "Hires checkpoint" not in res: res["Hires checkpoint"] = "Use same checkpoint" diff --git a/modules/infotext_versions.py b/modules/infotext_versions.py index b5552a312..cea676cda 100644 --- a/modules/infotext_versions.py +++ b/modules/infotext_versions.py @@ -6,6 +6,7 @@ import re v160 = version.parse("1.6.0") v170_tsnr = version.parse("v1.7.0-225") v180 = version.parse("1.8.0") +v180_hr_styles = version.parse("1.8.0-139") def parse_version(text): diff --git a/modules/initialize.py b/modules/initialize.py index 08ad4c0b0..0365bbb30 100644 --- a/modules/initialize.py +++ b/modules/initialize.py @@ -51,6 +51,7 @@ def check_versions(): def initialize(): from modules import initialize_util initialize_util.fix_torch_version() + initialize_util.fix_pytorch_lightning() initialize_util.fix_asyncio_event_loop_policy() initialize_util.validate_tls_options() initialize_util.configure_sigint_handler() @@ -109,7 +110,7 @@ def initialize_rest(*, reload_script_modules=False): with startup_timer.subcategory("load scripts"): scripts.load_scripts() - if reload_script_modules: + if reload_script_modules and shared.opts.enable_reloading_ui_scripts: for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: importlib.reload(module) startup_timer.record("reload script modules") diff --git a/modules/initialize_util.py b/modules/initialize_util.py index b6767138d..79a72cb3a 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -24,6 +24,13 @@ def fix_torch_version(): torch.__long_version__ = torch.__version__ torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) +def fix_pytorch_lightning(): + # Checks if pytorch_lightning.utilities.distributed already exists in the sys.modules cache + if 'pytorch_lightning.utilities.distributed' not in sys.modules: + import pytorch_lightning + # Lets the user know that the library was not found and then will set it to pytorch_lightning.utilities.rank_zero + print("Pytorch_lightning.distributed not found, attempting pytorch_lightning.rank_zero") + sys.modules["pytorch_lightning.utilities.distributed"] = pytorch_lightning.utilities.rank_zero def fix_asyncio_event_loop_policy(): """ diff --git a/modules/paths_internal.py b/modules/paths_internal.py index 6058b0cde..cf9da45ab 100644 --- a/modules/paths_internal.py +++ b/modules/paths_internal.py @@ -32,6 +32,6 @@ models_path = os.path.join(data_path, "models") extensions_dir = os.path.join(data_path, "extensions") extensions_builtin_dir = os.path.join(script_path, "extensions-builtin") config_states_dir = os.path.join(script_path, "config_states") -default_output_dir = os.path.join(data_path, "output") +default_output_dir = os.path.join(data_path, "outputs") roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf') diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 754cc9e3a..5a4e693a8 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -66,7 +66,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, if parameters: existing_pnginfo["parameters"] = parameters - initial_pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB")) + initial_pp = scripts_postprocessing.PostprocessedImage(image_data if image_data.mode in ("RGBA", "RGB") else image_data.convert("RGB")) scripts.scripts_postproc.run(initial_pp, args) @@ -122,8 +122,6 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, if extras_mode != 2 or show_extras_results: outputs.append(pp.image) - image_data.close() - devices.torch_gc() shared.state.end() return outputs, ui_common.plaintext_to_html(infotext), '' diff --git a/modules/processing.py b/modules/processing.py index 86194b057..2baca4f5f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -152,6 +152,7 @@ class StableDiffusionProcessing: seed_resize_from_w: int = -1 seed_enable_extras: bool = True sampler_name: str = None + scheduler: str = None batch_size: int = 1 n_iter: int = 1 steps: int = 50 @@ -702,7 +703,7 @@ def program_version(): return res -def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None): +def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None, all_hr_prompts=None, all_hr_negative_prompts=None): if index is None: index = position_in_batch + iteration * p.batch_size @@ -721,6 +722,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter generation_params = { "Steps": p.steps, "Sampler": p.sampler_name, + "Schedule type": p.scheduler, "CFG scale": p.cfg_scale, "Image CFG scale": getattr(p, 'image_cfg_scale', None), "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index], @@ -745,11 +747,18 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "RNG": opts.randn_source if opts.randn_source != "GPU" else None, "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond, "Tiling": "True" if p.tiling else None, + "Hires prompt": None, # This is set later, insert here to keep order + "Hires negative prompt": None, # This is set later, insert here to keep order **p.extra_generation_params, "Version": program_version() if opts.add_version_to_infotext else None, "User": p.user if opts.add_user_name_to_info else None, } + if all_hr_prompts := all_hr_prompts or getattr(p, 'all_hr_prompts', None): + generation_params['Hires prompt'] = all_hr_prompts[index] if all_hr_prompts[index] != all_prompts[index] else None + if all_hr_negative_prompts := all_hr_negative_prompts or getattr(p, 'all_hr_negative_prompts', None): + generation_params['Hires negative prompt'] = all_hr_negative_prompts[index] if all_hr_negative_prompts[index] != all_negative_prompts[index] else None + generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None]) prompt_text = p.main_prompt if use_main_prompt else all_prompts[index] @@ -1106,6 +1115,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): hr_resize_y: int = 0 hr_checkpoint_name: str = None hr_sampler_name: str = None + hr_scheduler: str = None hr_prompt: str = '' hr_negative_prompt: str = '' force_task_id: str = None @@ -1194,11 +1204,10 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name: self.extra_generation_params["Hires sampler"] = self.hr_sampler_name - if tuple(self.hr_prompt) != tuple(self.prompt): - self.extra_generation_params["Hires prompt"] = self.hr_prompt + self.extra_generation_params["Hires schedule type"] = None # to be set in sd_samplers_kdiffusion.py - if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt): - self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt + if self.hr_scheduler is None: + self.hr_scheduler = self.scheduler self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") if self.enable_hr and self.latent_scale_mode is None: diff --git a/modules/processing_scripts/comments.py b/modules/processing_scripts/comments.py index 638e39f29..cf81dfd8b 100644 --- a/modules/processing_scripts/comments.py +++ b/modules/processing_scripts/comments.py @@ -26,6 +26,13 @@ class ScriptStripComments(scripts.Script): p.main_prompt = strip_comments(p.main_prompt) p.main_negative_prompt = strip_comments(p.main_negative_prompt) + if getattr(p, 'enable_hr', False): + p.all_hr_prompts = [strip_comments(x) for x in p.all_hr_prompts] + p.all_hr_negative_prompts = [strip_comments(x) for x in p.all_hr_negative_prompts] + + p.hr_prompt = strip_comments(p.hr_prompt) + p.hr_negative_prompt = strip_comments(p.hr_negative_prompt) + def before_token_counter(params: script_callbacks.BeforeTokenCounterParams): if not shared.opts.enable_prompt_comments: diff --git a/modules/processing_scripts/sampler.py b/modules/processing_scripts/sampler.py new file mode 100644 index 000000000..5d50a162c --- /dev/null +++ b/modules/processing_scripts/sampler.py @@ -0,0 +1,45 @@ +import gradio as gr + +from modules import scripts, sd_samplers, sd_schedulers, shared +from modules.infotext_utils import PasteField +from modules.ui_components import FormRow, FormGroup + + +class ScriptSampler(scripts.ScriptBuiltinUI): + section = "sampler" + + def __init__(self): + self.steps = None + self.sampler_name = None + self.scheduler = None + + def title(self): + return "Sampler" + + def ui(self, is_img2img): + sampler_names = [x.name for x in sd_samplers.visible_samplers()] + scheduler_names = [x.label for x in sd_schedulers.schedulers] + + if shared.opts.samplers_in_dropdown: + with FormRow(elem_id=f"sampler_selection_{self.tabname}"): + self.sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{self.tabname}_sampling", choices=sampler_names, value=sampler_names[0]) + self.scheduler = gr.Dropdown(label='Schedule type', elem_id=f"{self.tabname}_scheduler", choices=scheduler_names, value=scheduler_names[0]) + self.steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{self.tabname}_steps", label="Sampling steps", value=20) + else: + with FormGroup(elem_id=f"sampler_selection_{self.tabname}"): + self.steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{self.tabname}_steps", label="Sampling steps", value=20) + self.sampler_name = gr.Radio(label='Sampling method', elem_id=f"{self.tabname}_sampling", choices=sampler_names, value=sampler_names[0]) + self.scheduler = gr.Dropdown(label='Schedule type', elem_id=f"{self.tabname}_scheduler", choices=scheduler_names, value=scheduler_names[0]) + + self.infotext_fields = [ + PasteField(self.steps, "Steps", api="steps"), + PasteField(self.sampler_name, sd_samplers.get_sampler_from_infotext, api="sampler_name"), + PasteField(self.scheduler, sd_samplers.get_scheduler_from_infotext, api="scheduler"), + ] + + return self.steps, self.sampler_name, self.scheduler + + def setup(self, p, steps, sampler_name, scheduler): + p.steps = steps + p.sampler_name = sampler_name + p.scheduler = scheduler diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 08bc52564..d5a97ecff 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,13 +1,14 @@ +from __future__ import annotations + import dataclasses import inspect import os -from collections import namedtuple from typing import Optional, Any from fastapi import FastAPI from gradio import Blocks -from modules import errors, timer +from modules import errors, timer, extensions, shared, util def report_exception(c, job): @@ -116,7 +117,105 @@ class BeforeTokenCounterParams: is_positive: bool = True -ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) +@dataclasses.dataclass +class ScriptCallback: + script: str + callback: any + name: str = "unnamed" + + +def add_callback(callbacks, fun, *, name=None, category='unknown', filename=None): + if filename is None: + stack = [x for x in inspect.stack() if x.filename != __file__] + filename = stack[0].filename if stack else 'unknown file' + + extension = extensions.find_extension(filename) + extension_name = extension.canonical_name if extension else 'base' + + callback_name = f"{extension_name}/{os.path.basename(filename)}/{category}" + if name is not None: + callback_name += f'/{name}' + + unique_callback_name = callback_name + for index in range(1000): + existing = any(x.name == unique_callback_name for x in callbacks) + if not existing: + break + + unique_callback_name = f'{callback_name}-{index+1}' + + callbacks.append(ScriptCallback(filename, fun, unique_callback_name)) + + +def sort_callbacks(category, unordered_callbacks, *, enable_user_sort=True): + callbacks = unordered_callbacks.copy() + callback_lookup = {x.name: x for x in callbacks} + dependencies = {} + + order_instructions = {} + for extension in extensions.extensions: + for order_instruction in extension.metadata.list_callback_order_instructions(): + if order_instruction.name in callback_lookup: + if order_instruction.name not in order_instructions: + order_instructions[order_instruction.name] = [] + + order_instructions[order_instruction.name].append(order_instruction) + + if order_instructions: + for callback in callbacks: + dependencies[callback.name] = [] + + for callback in callbacks: + for order_instruction in order_instructions.get(callback.name, []): + for after in order_instruction.after: + if after not in callback_lookup: + continue + + dependencies[callback.name].append(after) + + for before in order_instruction.before: + if before not in callback_lookup: + continue + + dependencies[before].append(callback.name) + + sorted_names = util.topological_sort(dependencies) + callbacks = [callback_lookup[x] for x in sorted_names] + + if enable_user_sort: + for name in reversed(getattr(shared.opts, 'prioritized_callbacks_' + category, [])): + index = next((i for i, callback in enumerate(callbacks) if callback.name == name), None) + if index is not None: + callbacks.insert(0, callbacks.pop(index)) + + return callbacks + + +def ordered_callbacks(category, unordered_callbacks=None, *, enable_user_sort=True): + if unordered_callbacks is None: + unordered_callbacks = callback_map.get('callbacks_' + category, []) + + if not enable_user_sort: + return sort_callbacks(category, unordered_callbacks, enable_user_sort=False) + + callbacks = ordered_callbacks_map.get(category) + if callbacks is not None and len(callbacks) == len(unordered_callbacks): + return callbacks + + callbacks = sort_callbacks(category, unordered_callbacks) + + ordered_callbacks_map[category] = callbacks + return callbacks + + +def enumerate_callbacks(): + for category, callbacks in callback_map.items(): + if category.startswith('callbacks_'): + category = category[10:] + + yield category, callbacks + + callback_map = dict( callbacks_app_started=[], callbacks_model_loaded=[], @@ -141,14 +240,18 @@ callback_map = dict( callbacks_before_token_counter=[], ) +ordered_callbacks_map = {} + def clear_callbacks(): for callback_list in callback_map.values(): callback_list.clear() + ordered_callbacks_map.clear() + def app_started_callback(demo: Optional[Blocks], app: FastAPI): - for c in callback_map['callbacks_app_started']: + for c in ordered_callbacks('app_started'): try: c.callback(demo, app) timer.startup_timer.record(os.path.basename(c.script)) @@ -157,7 +260,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI): def app_reload_callback(): - for c in callback_map['callbacks_on_reload']: + for c in ordered_callbacks('on_reload'): try: c.callback() except Exception: @@ -165,7 +268,7 @@ def app_reload_callback(): def model_loaded_callback(sd_model): - for c in callback_map['callbacks_model_loaded']: + for c in ordered_callbacks('model_loaded'): try: c.callback(sd_model) except Exception: @@ -175,7 +278,7 @@ def model_loaded_callback(sd_model): def ui_tabs_callback(): res = [] - for c in callback_map['callbacks_ui_tabs']: + for c in ordered_callbacks('ui_tabs'): try: res += c.callback() or [] except Exception: @@ -185,7 +288,7 @@ def ui_tabs_callback(): def ui_train_tabs_callback(params: UiTrainTabParams): - for c in callback_map['callbacks_ui_train_tabs']: + for c in ordered_callbacks('ui_train_tabs'): try: c.callback(params) except Exception: @@ -193,7 +296,7 @@ def ui_train_tabs_callback(params: UiTrainTabParams): def ui_settings_callback(): - for c in callback_map['callbacks_ui_settings']: + for c in ordered_callbacks('ui_settings'): try: c.callback() except Exception: @@ -201,7 +304,7 @@ def ui_settings_callback(): def before_image_saved_callback(params: ImageSaveParams): - for c in callback_map['callbacks_before_image_saved']: + for c in ordered_callbacks('before_image_saved'): try: c.callback(params) except Exception: @@ -209,7 +312,7 @@ def before_image_saved_callback(params: ImageSaveParams): def image_saved_callback(params: ImageSaveParams): - for c in callback_map['callbacks_image_saved']: + for c in ordered_callbacks('image_saved'): try: c.callback(params) except Exception: @@ -217,7 +320,7 @@ def image_saved_callback(params: ImageSaveParams): def extra_noise_callback(params: ExtraNoiseParams): - for c in callback_map['callbacks_extra_noise']: + for c in ordered_callbacks('extra_noise'): try: c.callback(params) except Exception: @@ -225,7 +328,7 @@ def extra_noise_callback(params: ExtraNoiseParams): def cfg_denoiser_callback(params: CFGDenoiserParams): - for c in callback_map['callbacks_cfg_denoiser']: + for c in ordered_callbacks('cfg_denoiser'): try: c.callback(params) except Exception: @@ -233,7 +336,7 @@ def cfg_denoiser_callback(params: CFGDenoiserParams): def cfg_denoised_callback(params: CFGDenoisedParams): - for c in callback_map['callbacks_cfg_denoised']: + for c in ordered_callbacks('cfg_denoised'): try: c.callback(params) except Exception: @@ -241,7 +344,7 @@ def cfg_denoised_callback(params: CFGDenoisedParams): def cfg_after_cfg_callback(params: AfterCFGCallbackParams): - for c in callback_map['callbacks_cfg_after_cfg']: + for c in ordered_callbacks('cfg_after_cfg'): try: c.callback(params) except Exception: @@ -249,7 +352,7 @@ def cfg_after_cfg_callback(params: AfterCFGCallbackParams): def before_component_callback(component, **kwargs): - for c in callback_map['callbacks_before_component']: + for c in ordered_callbacks('before_component'): try: c.callback(component, **kwargs) except Exception: @@ -257,7 +360,7 @@ def before_component_callback(component, **kwargs): def after_component_callback(component, **kwargs): - for c in callback_map['callbacks_after_component']: + for c in ordered_callbacks('after_component'): try: c.callback(component, **kwargs) except Exception: @@ -265,7 +368,7 @@ def after_component_callback(component, **kwargs): def image_grid_callback(params: ImageGridLoopParams): - for c in callback_map['callbacks_image_grid']: + for c in ordered_callbacks('image_grid'): try: c.callback(params) except Exception: @@ -273,7 +376,7 @@ def image_grid_callback(params: ImageGridLoopParams): def infotext_pasted_callback(infotext: str, params: dict[str, Any]): - for c in callback_map['callbacks_infotext_pasted']: + for c in ordered_callbacks('infotext_pasted'): try: c.callback(infotext, params) except Exception: @@ -281,7 +384,7 @@ def infotext_pasted_callback(infotext: str, params: dict[str, Any]): def script_unloaded_callback(): - for c in reversed(callback_map['callbacks_script_unloaded']): + for c in reversed(ordered_callbacks('script_unloaded')): try: c.callback() except Exception: @@ -289,7 +392,7 @@ def script_unloaded_callback(): def before_ui_callback(): - for c in reversed(callback_map['callbacks_before_ui']): + for c in reversed(ordered_callbacks('before_ui')): try: c.callback() except Exception: @@ -299,7 +402,7 @@ def before_ui_callback(): def list_optimizers_callback(): res = [] - for c in callback_map['callbacks_list_optimizers']: + for c in ordered_callbacks('list_optimizers'): try: c.callback(res) except Exception: @@ -311,7 +414,7 @@ def list_optimizers_callback(): def list_unets_callback(): res = [] - for c in callback_map['callbacks_list_unets']: + for c in ordered_callbacks('list_unets'): try: c.callback(res) except Exception: @@ -321,20 +424,13 @@ def list_unets_callback(): def before_token_counter_callback(params: BeforeTokenCounterParams): - for c in callback_map['callbacks_before_token_counter']: + for c in ordered_callbacks('before_token_counter'): try: c.callback(params) except Exception: report_exception(c, 'before_token_counter') -def add_callback(callbacks, fun): - stack = [x for x in inspect.stack() if x.filename != __file__] - filename = stack[0].filename if stack else 'unknown file' - - callbacks.append(ScriptCallback(filename, fun)) - - def remove_current_script_callbacks(): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if stack else 'unknown file' @@ -351,24 +447,24 @@ def remove_callbacks_for_function(callback_func): callback_list.remove(callback_to_remove) -def on_app_started(callback): +def on_app_started(callback, *, name=None): """register a function to be called when the webui started, the gradio `Block` component and fastapi `FastAPI` object are passed as the arguments""" - add_callback(callback_map['callbacks_app_started'], callback) + add_callback(callback_map['callbacks_app_started'], callback, name=name, category='app_started') -def on_before_reload(callback): +def on_before_reload(callback, *, name=None): """register a function to be called just before the server reloads.""" - add_callback(callback_map['callbacks_on_reload'], callback) + add_callback(callback_map['callbacks_on_reload'], callback, name=name, category='on_reload') -def on_model_loaded(callback): +def on_model_loaded(callback, *, name=None): """register a function to be called when the stable diffusion model is created; the model is passed as an argument; this function is also called when the script is reloaded. """ - add_callback(callback_map['callbacks_model_loaded'], callback) + add_callback(callback_map['callbacks_model_loaded'], callback, name=name, category='model_loaded') -def on_ui_tabs(callback): +def on_ui_tabs(callback, *, name=None): """register a function to be called when the UI is creating new tabs. The function must either return a None, which means no new tabs to be added, or a list, where each element is a tuple: @@ -378,71 +474,71 @@ def on_ui_tabs(callback): title is tab text displayed to user in the UI elem_id is HTML id for the tab """ - add_callback(callback_map['callbacks_ui_tabs'], callback) + add_callback(callback_map['callbacks_ui_tabs'], callback, name=name, category='ui_tabs') -def on_ui_train_tabs(callback): +def on_ui_train_tabs(callback, *, name=None): """register a function to be called when the UI is creating new tabs for the train tab. Create your new tabs with gr.Tab. """ - add_callback(callback_map['callbacks_ui_train_tabs'], callback) + add_callback(callback_map['callbacks_ui_train_tabs'], callback, name=name, category='ui_train_tabs') -def on_ui_settings(callback): +def on_ui_settings(callback, *, name=None): """register a function to be called before UI settings are populated; add your settings by using shared.opts.add_option(shared.OptionInfo(...)) """ - add_callback(callback_map['callbacks_ui_settings'], callback) + add_callback(callback_map['callbacks_ui_settings'], callback, name=name, category='ui_settings') -def on_before_image_saved(callback): +def on_before_image_saved(callback, *, name=None): """register a function to be called before an image is saved to a file. The callback is called with one argument: - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object. """ - add_callback(callback_map['callbacks_before_image_saved'], callback) + add_callback(callback_map['callbacks_before_image_saved'], callback, name=name, category='before_image_saved') -def on_image_saved(callback): +def on_image_saved(callback, *, name=None): """register a function to be called after an image is saved to a file. The callback is called with one argument: - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing. """ - add_callback(callback_map['callbacks_image_saved'], callback) + add_callback(callback_map['callbacks_image_saved'], callback, name=name, category='image_saved') -def on_extra_noise(callback): +def on_extra_noise(callback, *, name=None): """register a function to be called before adding extra noise in img2img or hires fix; The callback is called with one argument: - params: ExtraNoiseParams - contains noise determined by seed and latent representation of image """ - add_callback(callback_map['callbacks_extra_noise'], callback) + add_callback(callback_map['callbacks_extra_noise'], callback, name=name, category='extra_noise') -def on_cfg_denoiser(callback): +def on_cfg_denoiser(callback, *, name=None): """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs. The callback is called with one argument: - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details. """ - add_callback(callback_map['callbacks_cfg_denoiser'], callback) + add_callback(callback_map['callbacks_cfg_denoiser'], callback, name=name, category='cfg_denoiser') -def on_cfg_denoised(callback): +def on_cfg_denoised(callback, *, name=None): """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs. The callback is called with one argument: - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details. """ - add_callback(callback_map['callbacks_cfg_denoised'], callback) + add_callback(callback_map['callbacks_cfg_denoised'], callback, name=name, category='cfg_denoised') -def on_cfg_after_cfg(callback): +def on_cfg_after_cfg(callback, *, name=None): """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed. The callback is called with one argument: - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation. """ - add_callback(callback_map['callbacks_cfg_after_cfg'], callback) + add_callback(callback_map['callbacks_cfg_after_cfg'], callback, name=name, category='cfg_after_cfg') -def on_before_component(callback): +def on_before_component(callback, *, name=None): """register a function to be called before a component is created. The callback is called with arguments: - component - gradio component that is about to be created. @@ -451,61 +547,61 @@ def on_before_component(callback): Use elem_id/label fields of kwargs to figure out which component it is. This can be useful to inject your own components somewhere in the middle of vanilla UI. """ - add_callback(callback_map['callbacks_before_component'], callback) + add_callback(callback_map['callbacks_before_component'], callback, name=name, category='before_component') -def on_after_component(callback): +def on_after_component(callback, *, name=None): """register a function to be called after a component is created. See on_before_component for more.""" - add_callback(callback_map['callbacks_after_component'], callback) + add_callback(callback_map['callbacks_after_component'], callback, name=name, category='after_component') -def on_image_grid(callback): +def on_image_grid(callback, *, name=None): """register a function to be called before making an image grid. The callback is called with one argument: - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified. """ - add_callback(callback_map['callbacks_image_grid'], callback) + add_callback(callback_map['callbacks_image_grid'], callback, name=name, category='image_grid') -def on_infotext_pasted(callback): +def on_infotext_pasted(callback, *, name=None): """register a function to be called before applying an infotext. The callback is called with two arguments: - infotext: str - raw infotext. - result: dict[str, any] - parsed infotext parameters. """ - add_callback(callback_map['callbacks_infotext_pasted'], callback) + add_callback(callback_map['callbacks_infotext_pasted'], callback, name=name, category='infotext_pasted') -def on_script_unloaded(callback): +def on_script_unloaded(callback, *, name=None): """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that the script did should be reverted here""" - add_callback(callback_map['callbacks_script_unloaded'], callback) + add_callback(callback_map['callbacks_script_unloaded'], callback, name=name, category='script_unloaded') -def on_before_ui(callback): +def on_before_ui(callback, *, name=None): """register a function to be called before the UI is created.""" - add_callback(callback_map['callbacks_before_ui'], callback) + add_callback(callback_map['callbacks_before_ui'], callback, name=name, category='before_ui') -def on_list_optimizers(callback): +def on_list_optimizers(callback, *, name=None): """register a function to be called when UI is making a list of cross attention optimization options. The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization to it.""" - add_callback(callback_map['callbacks_list_optimizers'], callback) + add_callback(callback_map['callbacks_list_optimizers'], callback, name=name, category='list_optimizers') -def on_list_unets(callback): +def on_list_unets(callback, *, name=None): """register a function to be called when UI is making a list of alternative options for unet. The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it.""" - add_callback(callback_map['callbacks_list_unets'], callback) + add_callback(callback_map['callbacks_list_unets'], callback, name=name, category='list_unets') -def on_before_token_counter(callback): +def on_before_token_counter(callback, *, name=None): """register a function to be called when UI is counting tokens for a prompt. The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary.""" - add_callback(callback_map['callbacks_before_token_counter'], callback) + add_callback(callback_map['callbacks_before_token_counter'], callback, name=name, category='before_token_counter') diff --git a/modules/scripts.py b/modules/scripts.py index 77f5e4f3e..264503ca3 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -7,7 +7,9 @@ from dataclasses import dataclass import gradio as gr -from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer +from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer, util + +topological_sort = util.topological_sort AlwaysVisible = object() @@ -138,7 +140,6 @@ class Script: """ pass - def before_process(self, p, *args): """ This function is called very early during processing begins for AlwaysVisible scripts. @@ -351,6 +352,9 @@ class ScriptBuiltinUI(Script): return f'{tabname}{item_id}' + def show(self, is_img2img): + return AlwaysVisible + current_basedir = paths.script_path @@ -369,29 +373,6 @@ scripts_data = [] postprocessing_scripts_data = [] ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) -def topological_sort(dependencies): - """Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies. - Ignores errors relating to missing dependeencies or circular dependencies - """ - - visited = {} - result = [] - - def inner(name): - visited[name] = True - - for dep in dependencies.get(name, []): - if dep in dependencies and dep not in visited: - inner(dep) - - result.append(name) - - for depname in dependencies: - if depname not in visited: - inner(depname) - - return result - @dataclass class ScriptWithDependencies: @@ -562,6 +543,25 @@ class ScriptRunner: self.paste_field_names = [] self.inputs = [None] + self.callback_map = {} + self.callback_names = [ + 'before_process', + 'process', + 'before_process_batch', + 'after_extra_networks_activate', + 'process_batch', + 'postprocess', + 'postprocess_batch', + 'postprocess_batch_list', + 'post_sample', + 'on_mask_blend', + 'postprocess_image', + 'postprocess_maskoverlay', + 'postprocess_image_after_composite', + 'before_component', + 'after_component', + ] + self.on_before_component_elem_id = {} """dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks""" @@ -600,6 +600,8 @@ class ScriptRunner: self.scripts.append(script) self.selectable_scripts.append(script) + self.callback_map.clear() + self.apply_on_before_component_callbacks() def apply_on_before_component_callbacks(self): @@ -769,8 +771,42 @@ class ScriptRunner: return processed + def list_scripts_for_method(self, method_name): + if method_name in ('before_component', 'after_component'): + return self.scripts + else: + return self.alwayson_scripts + + def create_ordered_callbacks_list(self, method_name, *, enable_user_sort=True): + script_list = self.list_scripts_for_method(method_name) + category = f'script_{method_name}' + callbacks = [] + + for script in script_list: + if getattr(script.__class__, method_name, None) == getattr(Script, method_name, None): + continue + + script_callbacks.add_callback(callbacks, script, category=category, name=script.__class__.__name__, filename=script.filename) + + return script_callbacks.sort_callbacks(category, callbacks, enable_user_sort=enable_user_sort) + + def ordered_callbacks(self, method_name, *, enable_user_sort=True): + script_list = self.list_scripts_for_method(method_name) + category = f'script_{method_name}' + + scrpts_len, callbacks = self.callback_map.get(category, (-1, None)) + + if callbacks is None or scrpts_len != len(script_list): + callbacks = self.create_ordered_callbacks_list(method_name, enable_user_sort=enable_user_sort) + self.callback_map[category] = len(script_list), callbacks + + return callbacks + + def ordered_scripts(self, method_name): + return [x.callback for x in self.ordered_callbacks(method_name)] + def before_process(self, p): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('before_process'): try: script_args = p.script_args[script.args_from:script.args_to] script.before_process(p, *script_args) @@ -778,7 +814,7 @@ class ScriptRunner: errors.report(f"Error running before_process: {script.filename}", exc_info=True) def process(self, p): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('process'): try: script_args = p.script_args[script.args_from:script.args_to] script.process(p, *script_args) @@ -786,7 +822,7 @@ class ScriptRunner: errors.report(f"Error running process: {script.filename}", exc_info=True) def before_process_batch(self, p, **kwargs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('before_process_batch'): try: script_args = p.script_args[script.args_from:script.args_to] script.before_process_batch(p, *script_args, **kwargs) @@ -794,7 +830,7 @@ class ScriptRunner: errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True) def after_extra_networks_activate(self, p, **kwargs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('after_extra_networks_activate'): try: script_args = p.script_args[script.args_from:script.args_to] script.after_extra_networks_activate(p, *script_args, **kwargs) @@ -802,7 +838,7 @@ class ScriptRunner: errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True) def process_batch(self, p, **kwargs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('process_batch'): try: script_args = p.script_args[script.args_from:script.args_to] script.process_batch(p, *script_args, **kwargs) @@ -810,7 +846,7 @@ class ScriptRunner: errors.report(f"Error running process_batch: {script.filename}", exc_info=True) def postprocess(self, p, processed): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('postprocess'): try: script_args = p.script_args[script.args_from:script.args_to] script.postprocess(p, processed, *script_args) @@ -818,7 +854,7 @@ class ScriptRunner: errors.report(f"Error running postprocess: {script.filename}", exc_info=True) def postprocess_batch(self, p, images, **kwargs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('postprocess_batch'): try: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_batch(p, *script_args, images=images, **kwargs) @@ -826,7 +862,7 @@ class ScriptRunner: errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True) def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('postprocess_batch_list'): try: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_batch_list(p, pp, *script_args, **kwargs) @@ -834,7 +870,7 @@ class ScriptRunner: errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True) def post_sample(self, p, ps: PostSampleArgs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('post_sample'): try: script_args = p.script_args[script.args_from:script.args_to] script.post_sample(p, ps, *script_args) @@ -842,7 +878,7 @@ class ScriptRunner: errors.report(f"Error running post_sample: {script.filename}", exc_info=True) def on_mask_blend(self, p, mba: MaskBlendArgs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('on_mask_blend'): try: script_args = p.script_args[script.args_from:script.args_to] script.on_mask_blend(p, mba, *script_args) @@ -850,7 +886,7 @@ class ScriptRunner: errors.report(f"Error running post_sample: {script.filename}", exc_info=True) def postprocess_image(self, p, pp: PostprocessImageArgs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('postprocess_image'): try: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_image(p, pp, *script_args) @@ -858,7 +894,7 @@ class ScriptRunner: errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('postprocess_maskoverlay'): try: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_maskoverlay(p, ppmo, *script_args) @@ -866,7 +902,7 @@ class ScriptRunner: errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('postprocess_image_after_composite'): try: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_image_after_composite(p, pp, *script_args) @@ -880,7 +916,7 @@ class ScriptRunner: except Exception: errors.report(f"Error running on_before_component: {script.filename}", exc_info=True) - for script in self.scripts: + for script in self.ordered_scripts('before_component'): try: script.before_component(component, **kwargs) except Exception: @@ -893,7 +929,7 @@ class ScriptRunner: except Exception: errors.report(f"Error running on_after_component: {script.filename}", exc_info=True) - for script in self.scripts: + for script in self.ordered_scripts('after_component'): try: script.after_component(component, **kwargs) except Exception: @@ -921,7 +957,7 @@ class ScriptRunner: self.scripts[si].args_to = args_to def before_hr(self, p): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('before_hr'): try: script_args = p.script_args[script.args_from:script.args_to] script.before_hr(p, *script_args) @@ -929,7 +965,7 @@ class ScriptRunner: errors.report(f"Error running before_hr: {script.filename}", exc_info=True) def setup_scrips(self, p, *, is_ui=True): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('setup'): if not is_ui and script.setup_for_ui_only: continue diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index a58528a0b..6b7b84b6d 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,7 +1,12 @@ -from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared +from __future__ import annotations + +import functools + +from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common, sd_schedulers # imports for functions that previously were here and are used by other modules -from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401 +samples_to_image_grid = sd_samplers_common.samples_to_image_grid +sample_to_image = sd_samplers_common.sample_to_image all_samplers = [ *sd_samplers_kdiffusion.samplers_data_k_diffusion, @@ -10,8 +15,8 @@ all_samplers = [ ] all_samplers_map = {x.name: x for x in all_samplers} -samplers = [] -samplers_for_img2img = [] +samplers: list[sd_samplers_common.SamplerData] = [] +samplers_for_img2img: list[sd_samplers_common.SamplerData] = [] samplers_map = {} samplers_hidden = {} @@ -57,4 +62,64 @@ def visible_sampler_names(): return [x.name for x in samplers if x.name not in samplers_hidden] +def visible_samplers(): + return [x for x in samplers if x.name not in samplers_hidden] + + +def get_sampler_from_infotext(d: dict): + return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0] + + +def get_scheduler_from_infotext(d: dict): + return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1] + + +def get_hr_sampler_and_scheduler(d: dict): + hr_sampler = d.get("Hires sampler", "Use same sampler") + sampler = d.get("Sampler") if hr_sampler == "Use same sampler" else hr_sampler + + hr_scheduler = d.get("Hires schedule type", "Use same scheduler") + scheduler = d.get("Schedule type") if hr_scheduler == "Use same scheduler" else hr_scheduler + + sampler, scheduler = get_sampler_and_scheduler(sampler, scheduler) + + sampler = sampler if sampler != d.get("Sampler") else "Use same sampler" + scheduler = scheduler if scheduler != d.get("Schedule type") else "Use same scheduler" + + return sampler, scheduler + + +def get_hr_sampler_from_infotext(d: dict): + return get_hr_sampler_and_scheduler(d)[0] + + +def get_hr_scheduler_from_infotext(d: dict): + return get_hr_sampler_and_scheduler(d)[1] + + +@functools.cache +def get_sampler_and_scheduler(sampler_name, scheduler_name): + default_sampler = samplers[0] + found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0]) + + name = sampler_name or default_sampler.name + + for scheduler in sd_schedulers.schedulers: + name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])] + + for name_option in name_options: + if name.endswith(" " + name_option): + found_scheduler = scheduler + name = name[0:-(len(name_option) + 1)] + break + + sampler = all_samplers_map.get(name, default_sampler) + + # revert back to Automatic if it's the default scheduler for the selected sampler + if sampler.options.get('scheduler', None) == found_scheduler.name: + found_scheduler = sd_schedulers.schedulers[0] + + return sampler.name, found_scheduler.label + + set_samplers() diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 337106c02..b45f85b07 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -1,7 +1,7 @@ import torch import inspect import k_diffusion.sampling -from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser +from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401 from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback @@ -9,32 +9,20 @@ from modules.shared import opts import modules.shared as shared samplers_k_diffusion = [ - ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), - ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), - ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}), - ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), + ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {'scheduler': 'karras'}), + ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), + ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde'], {'scheduler': 'exponential', "brownian_noise": True}), + ('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}), + ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}), + ('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}), ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}), ('Euler', 'sample_euler', ['k_euler'], {}), ('LMS', 'sample_lms', ['k_lms'], {}), ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}), - ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True, "second_order": True}), - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), - ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}), - ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), - ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}), - ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}), - ('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {"brownian_noise": True, "solver_type": "heun"}), - ('DPM++ 2M SDE Heun Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_ka'], {'scheduler': 'karras', "brownian_noise": True, "solver_type": "heun"}), - ('DPM++ 2M SDE Heun Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_exp'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}), - ('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "brownian_noise": True}), - ('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "brownian_noise": True}), - ('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}), + ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "second_order": True}), + ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}), ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}), - ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), - ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), - ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), - ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}), ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras', "second_order": True}), ] @@ -58,12 +46,7 @@ sampler_extra_params = { } k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion} -k_diffusion_scheduler = { - 'Automatic': None, - 'karras': k_diffusion.sampling.get_sigmas_karras, - 'exponential': k_diffusion.sampling.get_sigmas_exponential, - 'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential -} +k_diffusion_scheduler = {x.name: x.function for x in sd_schedulers.schedulers} class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser): @@ -96,42 +79,43 @@ class KDiffusionSampler(sd_samplers_common.Sampler): steps += 1 if discard_next_to_last_sigma else 0 + scheduler_name = (p.hr_scheduler if p.is_hr_pass else p.scheduler) or 'Automatic' + if scheduler_name == 'Automatic': + scheduler_name = self.config.options.get('scheduler', None) + + scheduler = sd_schedulers.schedulers_map.get(scheduler_name) + + m_sigma_min, m_sigma_max = self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item() + sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max) + if p.sampler_noise_scheduler_override: sigmas = p.sampler_noise_scheduler_override(steps) - elif opts.k_sched_type != "Automatic": - m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) - sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max) - sigmas_kwargs = { - 'sigma_min': sigma_min, - 'sigma_max': sigma_max, - } + elif scheduler is None or scheduler.function is None: + sigmas = self.model_wrap.get_sigmas(steps) + else: + sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max} - sigmas_func = k_diffusion_scheduler[opts.k_sched_type] - p.extra_generation_params["Schedule type"] = opts.k_sched_type + if scheduler.label != 'Automatic' and not p.is_hr_pass: + p.extra_generation_params["Schedule type"] = scheduler.label + elif scheduler.label != p.extra_generation_params.get("Schedule type"): + p.extra_generation_params["Hires schedule type"] = scheduler.label - if opts.sigma_min != m_sigma_min and opts.sigma_min != 0: + if opts.sigma_min != 0 and opts.sigma_min != m_sigma_min: sigmas_kwargs['sigma_min'] = opts.sigma_min p.extra_generation_params["Schedule min sigma"] = opts.sigma_min - if opts.sigma_max != m_sigma_max and opts.sigma_max != 0: + + if opts.sigma_max != 0 and opts.sigma_max != m_sigma_max: sigmas_kwargs['sigma_max'] = opts.sigma_max p.extra_generation_params["Schedule max sigma"] = opts.sigma_max - default_rho = 1. if opts.k_sched_type == "polyexponential" else 7. - - if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho: + if scheduler.default_rho != -1 and opts.rho != 0 and opts.rho != scheduler.default_rho: sigmas_kwargs['rho'] = opts.rho p.extra_generation_params["Schedule rho"] = opts.rho - sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device) - elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': - sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) + if scheduler.need_inner_model: + sigmas_kwargs['inner_model'] = self.model_wrap - sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device) - elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential': - m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) - sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device) - else: - sigmas = self.model_wrap.get_sigmas(steps) + sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=shared.device) if discard_next_to_last_sigma: sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) diff --git a/modules/sd_schedulers.py b/modules/sd_schedulers.py new file mode 100644 index 000000000..75eb3ac03 --- /dev/null +++ b/modules/sd_schedulers.py @@ -0,0 +1,43 @@ +import dataclasses + +import torch + +import k_diffusion + + +@dataclasses.dataclass +class Scheduler: + name: str + label: str + function: any + + default_rho: float = -1 + need_inner_model: bool = False + aliases: list = None + + +def uniform(n, sigma_min, sigma_max, inner_model, device): + return inner_model.get_sigmas(n) + + +def sgm_uniform(n, sigma_min, sigma_max, inner_model, device): + start = inner_model.sigma_to_t(torch.tensor(sigma_max)) + end = inner_model.sigma_to_t(torch.tensor(sigma_min)) + sigs = [ + inner_model.t_to_sigma(ts) + for ts in torch.linspace(start, end, n + 1)[:-1] + ] + sigs += [0.0] + return torch.FloatTensor(sigs).to(device) + + +schedulers = [ + Scheduler('automatic', 'Automatic', None), + Scheduler('uniform', 'Uniform', uniform, need_inner_model=True), + Scheduler('karras', 'Karras', k_diffusion.sampling.get_sigmas_karras, default_rho=7.0), + Scheduler('exponential', 'Exponential', k_diffusion.sampling.get_sigmas_exponential), + Scheduler('polyexponential', 'Polyexponential', k_diffusion.sampling.get_sigmas_polyexponential, default_rho=1.0), + Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]), +] + +schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}} diff --git a/modules/shared.py b/modules/shared.py index b4ba14ad7..4cf7f6a81 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -6,6 +6,10 @@ import gradio as gr from modules import shared_cmd_options, shared_gradio_themes, options, shared_items, sd_models_types from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 from modules import util +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from modules import shared_state, styles, interrogate, shared_total_tqdm, memmon cmd_opts = shared_cmd_options.cmd_opts parser = shared_cmd_options.parser @@ -16,11 +20,11 @@ styles_filename = cmd_opts.styles_file = cmd_opts.styles_file if len(cmd_opts.st config_filename = cmd_opts.ui_settings_file hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} -demo = None +demo: gr.Blocks = None -device = None +device: str = None -weight_load_location = None +weight_load_location: str = None xformers_available = False @@ -28,21 +32,21 @@ hypernetworks = {} loaded_hypernetworks = [] -state = None +state: 'shared_state.State' = None -prompt_styles = None +prompt_styles: 'styles.StyleDatabase' = None -interrogator = None +interrogator: 'interrogate.InterrogateModels' = None face_restorers = [] -options_templates = None -opts = None -restricted_opts = None +options_templates: dict = None +opts: options.Options = None +restricted_opts: set[str] = None sd_model: sd_models_types.WebuiSdModel = None -settings_components = None +settings_components: dict = None """assigned from ui.py, a mapping on setting names to gradio components repsponsible for those settings""" tab_names = [] @@ -65,9 +69,9 @@ progress_print_out = sys.stdout gradio_theme = gr.themes.Base() -total_tqdm = None +total_tqdm: 'shared_total_tqdm.TotalTQDM' = None -mem_mon = None +mem_mon: 'memmon.MemUsageMonitor' = None options_section = options.options_section OptionInfo = options.OptionInfo diff --git a/modules/shared_items.py b/modules/shared_items.py index 88f636452..11f10b3f7 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -1,5 +1,8 @@ +import html import sys +from modules import script_callbacks, scripts, ui_components +from modules.options import OptionHTML, OptionInfo from modules.shared_cmd_options import cmd_opts @@ -118,6 +121,45 @@ def ui_reorder_categories(): yield "scripts" +def callbacks_order_settings(): + options = { + "sd_vae_explanation": OptionHTML(""" + For categories below, callbacks added to dropdowns happen before others, in order listed. + """), + + } + + callback_options = {} + + for category, _ in script_callbacks.enumerate_callbacks(): + callback_options[category] = script_callbacks.ordered_callbacks(category, enable_user_sort=False) + + for method_name in scripts.scripts_txt2img.callback_names: + callback_options["script_" + method_name] = scripts.scripts_txt2img.create_ordered_callbacks_list(method_name, enable_user_sort=False) + + for method_name in scripts.scripts_img2img.callback_names: + callbacks = callback_options.get("script_" + method_name, []) + + for addition in scripts.scripts_img2img.create_ordered_callbacks_list(method_name, enable_user_sort=False): + if any(x.name == addition.name for x in callbacks): + continue + + callbacks.append(addition) + + callback_options["script_" + method_name] = callbacks + + for category, callbacks in callback_options.items(): + if not callbacks: + continue + + option_info = OptionInfo([], f"{category} callback priority", ui_components.DropdownMulti, {"choices": [x.name for x in callbacks]}) + option_info.needs_restart() + option_info.html("