merge dev into branch

This commit is contained in:
Sj-Si 2024-03-26 14:14:13 -04:00
commit d88e91c508
58 changed files with 926 additions and 371 deletions

View file

@ -2,48 +2,55 @@ import json
import os
import os.path
import threading
import time
import diskcache
import tqdm
from modules.paths import data_path, script_path
cache_filename = os.environ.get('SD_WEBUI_CACHE_FILE', os.path.join(data_path, "cache.json"))
cache_data = None
cache_dir = os.environ.get('SD_WEBUI_CACHE_DIR', os.path.join(data_path, "cache"))
caches = {}
cache_lock = threading.Lock()
dump_cache_after = None
dump_cache_thread = None
def dump_cache():
"""
Marks cache for writing to disk. 5 seconds after no one else flags the cache for writing, it is written.
"""
"""old function for dumping cache to disk; does nothing since diskcache."""
global dump_cache_after
global dump_cache_thread
pass
def thread_func():
global dump_cache_after
global dump_cache_thread
while dump_cache_after is not None and time.time() < dump_cache_after:
time.sleep(1)
def make_cache(subsection: str) -> diskcache.Cache:
return diskcache.Cache(
os.path.join(cache_dir, subsection),
size_limit=2**32, # 4 GB, culling oldest first
disk_min_file_size=2**18, # keep up to 256KB in Sqlite
)
with cache_lock:
cache_filename_tmp = cache_filename + "-"
with open(cache_filename_tmp, "w", encoding="utf8") as file:
json.dump(cache_data, file, indent=4, ensure_ascii=False)
os.replace(cache_filename_tmp, cache_filename)
def convert_old_cached_data():
try:
with open(cache_filename, "r", encoding="utf8") as file:
data = json.load(file)
except FileNotFoundError:
return
except Exception:
os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
print('[ERROR] issue occurred while trying to read cache.json; old cache has been moved to tmp/cache.json')
return
dump_cache_after = None
dump_cache_thread = None
total_count = sum(len(keyvalues) for keyvalues in data.values())
with cache_lock:
dump_cache_after = time.time() + 5
if dump_cache_thread is None:
dump_cache_thread = threading.Thread(name='cache-writer', target=thread_func)
dump_cache_thread.start()
with tqdm.tqdm(total=total_count, desc="converting cache") as progress:
for subsection, keyvalues in data.items():
cache_obj = caches.get(subsection)
if cache_obj is None:
cache_obj = make_cache(subsection)
caches[subsection] = cache_obj
for key, value in keyvalues.items():
cache_obj[key] = value
progress.update(1)
def cache(subsection):
@ -54,28 +61,21 @@ def cache(subsection):
subsection (str): The subsection identifier for the cache.
Returns:
dict: The cache data for the specified subsection.
diskcache.Cache: The cache data for the specified subsection.
"""
global cache_data
if cache_data is None:
cache_obj = caches.get(subsection)
if not cache_obj:
with cache_lock:
if cache_data is None:
try:
with open(cache_filename, "r", encoding="utf8") as file:
cache_data = json.load(file)
except FileNotFoundError:
cache_data = {}
except Exception:
os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
cache_data = {}
if not os.path.exists(cache_dir) and os.path.isfile(cache_filename):
convert_old_cached_data()
s = cache_data.get(subsection, {})
cache_data[subsection] = s
cache_obj = caches.get(subsection)
if not cache_obj:
cache_obj = make_cache(subsection)
caches[subsection] = cache_obj
return s
return cache_obj
def cached_data_for_file(subsection, title, filename, func):

View file

@ -1,6 +1,7 @@
from __future__ import annotations
import configparser
import dataclasses
import os
import threading
import re
@ -9,6 +10,10 @@ from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
extensions: list[Extension] = []
extension_paths: dict[str, Extension] = {}
loaded_extensions: dict[str, Exception] = {}
os.makedirs(extensions_dir, exist_ok=True)
@ -22,6 +27,13 @@ def active():
return [x for x in extensions if x.enabled]
@dataclasses.dataclass
class CallbackOrderInfo:
name: str
before: list
after: list
class ExtensionMetadata:
filename = "metadata.ini"
config: configparser.ConfigParser
@ -42,7 +54,7 @@ class ExtensionMetadata:
self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
self.canonical_name = canonical_name.lower().strip()
self.requires = self.get_script_requirements("Requires", "Extension")
self.requires = None
def get_script_requirements(self, field, section, extra_section=None):
"""reads a list of requirements from the config; field is the name of the field in the ini file,
@ -54,7 +66,15 @@ class ExtensionMetadata:
if extra_section:
x = x + ', ' + self.config.get(extra_section, field, fallback='')
return self.parse_list(x.lower())
listed_requirements = self.parse_list(x.lower())
res = []
for requirement in listed_requirements:
loaded_requirements = (x for x in requirement.split("|") if x in loaded_extensions)
relevant_requirement = next(loaded_requirements, requirement)
res.append(relevant_requirement)
return res
def parse_list(self, text):
"""converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])"""
@ -65,6 +85,22 @@ class ExtensionMetadata:
# both "," and " " are accepted as separator
return [x for x in re.split(r"[,\s]+", text.strip()) if x]
def list_callback_order_instructions(self):
for section in self.config.sections():
if not section.startswith("callbacks/"):
continue
callback_name = section[10:]
if not callback_name.startswith(self.canonical_name):
errors.report(f"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}")
continue
before = self.parse_list(self.config.get(section, 'Before', fallback=''))
after = self.parse_list(self.config.get(section, 'After', fallback=''))
yield CallbackOrderInfo(callback_name, before, after)
class Extension:
lock = threading.Lock()
@ -156,6 +192,8 @@ class Extension:
def check_updates(self):
repo = Repo(self.path)
for fetch in repo.remote().fetch(dry_run=True):
if self.branch and fetch.name != f'{repo.remote().name}/{self.branch}':
continue
if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True
self.status = "new commits"
@ -186,6 +224,8 @@ class Extension:
def list_extensions():
extensions.clear()
extension_paths.clear()
loaded_extensions.clear()
if shared.cmd_opts.disable_all_extensions:
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
@ -196,7 +236,6 @@ def list_extensions():
elif shared.opts.disable_all_extensions == "extra":
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
loaded_extensions = {}
# scan through extensions directory and load metadata
for dirname in [extensions_builtin_dir, extensions_dir]:
@ -220,8 +259,12 @@ def list_extensions():
is_builtin = dirname == extensions_builtin_dir
extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
extensions.append(extension)
extension_paths[extension.path] = extension
loaded_extensions[canonical_name] = extension
for extension in extensions:
extension.metadata.requires = extension.metadata.get_script_requirements("Requires", "Extension")
# check for requirements
for extension in extensions:
if not extension.enabled:
@ -238,4 +281,16 @@ def list_extensions():
continue
extensions: list[Extension] = []
def find_extension(filename):
parentdir = os.path.dirname(os.path.realpath(filename))
while parentdir != filename:
extension = extension_paths.get(parentdir)
if extension is not None:
return extension
filename = parentdir
parentdir = os.path.dirname(filename)
return None

View file

@ -146,7 +146,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
return batch_results
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
def img2img(id_task: str, request: gr.Request, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
@ -193,10 +193,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
prompt=prompt,
negative_prompt=negative_prompt,
styles=prompt_styles,
sampler_name=sampler_name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
cfg_scale=cfg_scale,
width=width,
height=height,

View file

@ -265,17 +265,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
else:
prompt += ("" if prompt == "" else "\n") + line
if shared.opts.infotext_styles != "Ignore":
found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)
if shared.opts.infotext_styles == "Apply":
res["Styles array"] = found_styles
elif shared.opts.infotext_styles == "Apply if any" and found_styles:
res["Styles array"] = found_styles
res["Prompt"] = prompt
res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline):
try:
if v[0] == '"' and v[-1] == '"':
@ -290,6 +279,26 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
except Exception:
print(f"Error parsing \"{k}: {v}\"")
# Extract styles from prompt
if shared.opts.infotext_styles != "Ignore":
found_styles, prompt_no_styles, negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)
same_hr_styles = True
if ("Hires prompt" in res or "Hires negative prompt" in res) and (infotext_ver > infotext_versions.v180_hr_styles if (infotext_ver := infotext_versions.parse_version(res.get("Version"))) else True):
hr_prompt, hr_negative_prompt = res.get("Hires prompt", prompt), res.get("Hires negative prompt", negative_prompt)
hr_found_styles, hr_prompt_no_styles, hr_negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(hr_prompt, hr_negative_prompt)
if same_hr_styles := found_styles == hr_found_styles:
res["Hires prompt"] = '' if hr_prompt_no_styles == prompt_no_styles else hr_prompt_no_styles
res['Hires negative prompt'] = '' if hr_negative_prompt_no_styles == negative_prompt_no_styles else hr_negative_prompt_no_styles
if same_hr_styles:
prompt, negative_prompt = prompt_no_styles, negative_prompt_no_styles
if (shared.opts.infotext_styles == "Apply if any" and found_styles) or shared.opts.infotext_styles == "Apply":
res['Styles array'] = found_styles
res["Prompt"] = prompt
res["Negative prompt"] = negative_prompt
# Missing CLIP skip means it was set to 1 (the default)
if "Clip skip" not in res:
res["Clip skip"] = "1"
@ -305,6 +314,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "Hires sampler" not in res:
res["Hires sampler"] = "Use same sampler"
if "Hires schedule type" not in res:
res["Hires schedule type"] = "Use same scheduler"
if "Hires checkpoint" not in res:
res["Hires checkpoint"] = "Use same checkpoint"

View file

@ -6,6 +6,7 @@ import re
v160 = version.parse("1.6.0")
v170_tsnr = version.parse("v1.7.0-225")
v180 = version.parse("1.8.0")
v180_hr_styles = version.parse("1.8.0-139")
def parse_version(text):

View file

@ -51,6 +51,7 @@ def check_versions():
def initialize():
from modules import initialize_util
initialize_util.fix_torch_version()
initialize_util.fix_pytorch_lightning()
initialize_util.fix_asyncio_event_loop_policy()
initialize_util.validate_tls_options()
initialize_util.configure_sigint_handler()
@ -109,7 +110,7 @@ def initialize_rest(*, reload_script_modules=False):
with startup_timer.subcategory("load scripts"):
scripts.load_scripts()
if reload_script_modules:
if reload_script_modules and shared.opts.enable_reloading_ui_scripts:
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
importlib.reload(module)
startup_timer.record("reload script modules")

View file

@ -24,6 +24,13 @@ def fix_torch_version():
torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
def fix_pytorch_lightning():
# Checks if pytorch_lightning.utilities.distributed already exists in the sys.modules cache
if 'pytorch_lightning.utilities.distributed' not in sys.modules:
import pytorch_lightning
# Lets the user know that the library was not found and then will set it to pytorch_lightning.utilities.rank_zero
print("Pytorch_lightning.distributed not found, attempting pytorch_lightning.rank_zero")
sys.modules["pytorch_lightning.utilities.distributed"] = pytorch_lightning.utilities.rank_zero
def fix_asyncio_event_loop_policy():
"""

View file

@ -32,6 +32,6 @@ models_path = os.path.join(data_path, "models")
extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
config_states_dir = os.path.join(script_path, "config_states")
default_output_dir = os.path.join(data_path, "output")
default_output_dir = os.path.join(data_path, "outputs")
roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')

View file

@ -66,7 +66,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
if parameters:
existing_pnginfo["parameters"] = parameters
initial_pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
initial_pp = scripts_postprocessing.PostprocessedImage(image_data if image_data.mode in ("RGBA", "RGB") else image_data.convert("RGB"))
scripts.scripts_postproc.run(initial_pp, args)
@ -122,8 +122,6 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
if extras_mode != 2 or show_extras_results:
outputs.append(pp.image)
image_data.close()
devices.torch_gc()
shared.state.end()
return outputs, ui_common.plaintext_to_html(infotext), ''

View file

@ -152,6 +152,7 @@ class StableDiffusionProcessing:
seed_resize_from_w: int = -1
seed_enable_extras: bool = True
sampler_name: str = None
scheduler: str = None
batch_size: int = 1
n_iter: int = 1
steps: int = 50
@ -702,7 +703,7 @@ def program_version():
return res
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None):
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None, all_hr_prompts=None, all_hr_negative_prompts=None):
if index is None:
index = position_in_batch + iteration * p.batch_size
@ -721,6 +722,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
generation_params = {
"Steps": p.steps,
"Sampler": p.sampler_name,
"Schedule type": p.scheduler,
"CFG scale": p.cfg_scale,
"Image CFG scale": getattr(p, 'image_cfg_scale', None),
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
@ -745,11 +747,18 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
"Tiling": "True" if p.tiling else None,
"Hires prompt": None, # This is set later, insert here to keep order
"Hires negative prompt": None, # This is set later, insert here to keep order
**p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None,
"User": p.user if opts.add_user_name_to_info else None,
}
if all_hr_prompts := all_hr_prompts or getattr(p, 'all_hr_prompts', None):
generation_params['Hires prompt'] = all_hr_prompts[index] if all_hr_prompts[index] != all_prompts[index] else None
if all_hr_negative_prompts := all_hr_negative_prompts or getattr(p, 'all_hr_negative_prompts', None):
generation_params['Hires negative prompt'] = all_hr_negative_prompts[index] if all_hr_negative_prompts[index] != all_negative_prompts[index] else None
generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None])
prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
@ -1106,6 +1115,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
hr_resize_y: int = 0
hr_checkpoint_name: str = None
hr_sampler_name: str = None
hr_scheduler: str = None
hr_prompt: str = ''
hr_negative_prompt: str = ''
force_task_id: str = None
@ -1194,11 +1204,10 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
if tuple(self.hr_prompt) != tuple(self.prompt):
self.extra_generation_params["Hires prompt"] = self.hr_prompt
self.extra_generation_params["Hires schedule type"] = None # to be set in sd_samplers_kdiffusion.py
if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
if self.hr_scheduler is None:
self.hr_scheduler = self.scheduler
self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
if self.enable_hr and self.latent_scale_mode is None:

View file

@ -26,6 +26,13 @@ class ScriptStripComments(scripts.Script):
p.main_prompt = strip_comments(p.main_prompt)
p.main_negative_prompt = strip_comments(p.main_negative_prompt)
if getattr(p, 'enable_hr', False):
p.all_hr_prompts = [strip_comments(x) for x in p.all_hr_prompts]
p.all_hr_negative_prompts = [strip_comments(x) for x in p.all_hr_negative_prompts]
p.hr_prompt = strip_comments(p.hr_prompt)
p.hr_negative_prompt = strip_comments(p.hr_negative_prompt)
def before_token_counter(params: script_callbacks.BeforeTokenCounterParams):
if not shared.opts.enable_prompt_comments:

View file

@ -0,0 +1,45 @@
import gradio as gr
from modules import scripts, sd_samplers, sd_schedulers, shared
from modules.infotext_utils import PasteField
from modules.ui_components import FormRow, FormGroup
class ScriptSampler(scripts.ScriptBuiltinUI):
section = "sampler"
def __init__(self):
self.steps = None
self.sampler_name = None
self.scheduler = None
def title(self):
return "Sampler"
def ui(self, is_img2img):
sampler_names = [x.name for x in sd_samplers.visible_samplers()]
scheduler_names = [x.label for x in sd_schedulers.schedulers]
if shared.opts.samplers_in_dropdown:
with FormRow(elem_id=f"sampler_selection_{self.tabname}"):
self.sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{self.tabname}_sampling", choices=sampler_names, value=sampler_names[0])
self.scheduler = gr.Dropdown(label='Schedule type', elem_id=f"{self.tabname}_scheduler", choices=scheduler_names, value=scheduler_names[0])
self.steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{self.tabname}_steps", label="Sampling steps", value=20)
else:
with FormGroup(elem_id=f"sampler_selection_{self.tabname}"):
self.steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{self.tabname}_steps", label="Sampling steps", value=20)
self.sampler_name = gr.Radio(label='Sampling method', elem_id=f"{self.tabname}_sampling", choices=sampler_names, value=sampler_names[0])
self.scheduler = gr.Dropdown(label='Schedule type', elem_id=f"{self.tabname}_scheduler", choices=scheduler_names, value=scheduler_names[0])
self.infotext_fields = [
PasteField(self.steps, "Steps", api="steps"),
PasteField(self.sampler_name, sd_samplers.get_sampler_from_infotext, api="sampler_name"),
PasteField(self.scheduler, sd_samplers.get_scheduler_from_infotext, api="scheduler"),
]
return self.steps, self.sampler_name, self.scheduler
def setup(self, p, steps, sampler_name, scheduler):
p.steps = steps
p.sampler_name = sampler_name
p.scheduler = scheduler

View file

@ -1,13 +1,14 @@
from __future__ import annotations
import dataclasses
import inspect
import os
from collections import namedtuple
from typing import Optional, Any
from fastapi import FastAPI
from gradio import Blocks
from modules import errors, timer
from modules import errors, timer, extensions, shared, util
def report_exception(c, job):
@ -116,7 +117,105 @@ class BeforeTokenCounterParams:
is_positive: bool = True
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
@dataclasses.dataclass
class ScriptCallback:
script: str
callback: any
name: str = "unnamed"
def add_callback(callbacks, fun, *, name=None, category='unknown', filename=None):
if filename is None:
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if stack else 'unknown file'
extension = extensions.find_extension(filename)
extension_name = extension.canonical_name if extension else 'base'
callback_name = f"{extension_name}/{os.path.basename(filename)}/{category}"
if name is not None:
callback_name += f'/{name}'
unique_callback_name = callback_name
for index in range(1000):
existing = any(x.name == unique_callback_name for x in callbacks)
if not existing:
break
unique_callback_name = f'{callback_name}-{index+1}'
callbacks.append(ScriptCallback(filename, fun, unique_callback_name))
def sort_callbacks(category, unordered_callbacks, *, enable_user_sort=True):
callbacks = unordered_callbacks.copy()
callback_lookup = {x.name: x for x in callbacks}
dependencies = {}
order_instructions = {}
for extension in extensions.extensions:
for order_instruction in extension.metadata.list_callback_order_instructions():
if order_instruction.name in callback_lookup:
if order_instruction.name not in order_instructions:
order_instructions[order_instruction.name] = []
order_instructions[order_instruction.name].append(order_instruction)
if order_instructions:
for callback in callbacks:
dependencies[callback.name] = []
for callback in callbacks:
for order_instruction in order_instructions.get(callback.name, []):
for after in order_instruction.after:
if after not in callback_lookup:
continue
dependencies[callback.name].append(after)
for before in order_instruction.before:
if before not in callback_lookup:
continue
dependencies[before].append(callback.name)
sorted_names = util.topological_sort(dependencies)
callbacks = [callback_lookup[x] for x in sorted_names]
if enable_user_sort:
for name in reversed(getattr(shared.opts, 'prioritized_callbacks_' + category, [])):
index = next((i for i, callback in enumerate(callbacks) if callback.name == name), None)
if index is not None:
callbacks.insert(0, callbacks.pop(index))
return callbacks
def ordered_callbacks(category, unordered_callbacks=None, *, enable_user_sort=True):
if unordered_callbacks is None:
unordered_callbacks = callback_map.get('callbacks_' + category, [])
if not enable_user_sort:
return sort_callbacks(category, unordered_callbacks, enable_user_sort=False)
callbacks = ordered_callbacks_map.get(category)
if callbacks is not None and len(callbacks) == len(unordered_callbacks):
return callbacks
callbacks = sort_callbacks(category, unordered_callbacks)
ordered_callbacks_map[category] = callbacks
return callbacks
def enumerate_callbacks():
for category, callbacks in callback_map.items():
if category.startswith('callbacks_'):
category = category[10:]
yield category, callbacks
callback_map = dict(
callbacks_app_started=[],
callbacks_model_loaded=[],
@ -141,14 +240,18 @@ callback_map = dict(
callbacks_before_token_counter=[],
)
ordered_callbacks_map = {}
def clear_callbacks():
for callback_list in callback_map.values():
callback_list.clear()
ordered_callbacks_map.clear()
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callback_map['callbacks_app_started']:
for c in ordered_callbacks('app_started'):
try:
c.callback(demo, app)
timer.startup_timer.record(os.path.basename(c.script))
@ -157,7 +260,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
def app_reload_callback():
for c in callback_map['callbacks_on_reload']:
for c in ordered_callbacks('on_reload'):
try:
c.callback()
except Exception:
@ -165,7 +268,7 @@ def app_reload_callback():
def model_loaded_callback(sd_model):
for c in callback_map['callbacks_model_loaded']:
for c in ordered_callbacks('model_loaded'):
try:
c.callback(sd_model)
except Exception:
@ -175,7 +278,7 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback():
res = []
for c in callback_map['callbacks_ui_tabs']:
for c in ordered_callbacks('ui_tabs'):
try:
res += c.callback() or []
except Exception:
@ -185,7 +288,7 @@ def ui_tabs_callback():
def ui_train_tabs_callback(params: UiTrainTabParams):
for c in callback_map['callbacks_ui_train_tabs']:
for c in ordered_callbacks('ui_train_tabs'):
try:
c.callback(params)
except Exception:
@ -193,7 +296,7 @@ def ui_train_tabs_callback(params: UiTrainTabParams):
def ui_settings_callback():
for c in callback_map['callbacks_ui_settings']:
for c in ordered_callbacks('ui_settings'):
try:
c.callback()
except Exception:
@ -201,7 +304,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams):
for c in callback_map['callbacks_before_image_saved']:
for c in ordered_callbacks('before_image_saved'):
try:
c.callback(params)
except Exception:
@ -209,7 +312,7 @@ def before_image_saved_callback(params: ImageSaveParams):
def image_saved_callback(params: ImageSaveParams):
for c in callback_map['callbacks_image_saved']:
for c in ordered_callbacks('image_saved'):
try:
c.callback(params)
except Exception:
@ -217,7 +320,7 @@ def image_saved_callback(params: ImageSaveParams):
def extra_noise_callback(params: ExtraNoiseParams):
for c in callback_map['callbacks_extra_noise']:
for c in ordered_callbacks('extra_noise'):
try:
c.callback(params)
except Exception:
@ -225,7 +328,7 @@ def extra_noise_callback(params: ExtraNoiseParams):
def cfg_denoiser_callback(params: CFGDenoiserParams):
for c in callback_map['callbacks_cfg_denoiser']:
for c in ordered_callbacks('cfg_denoiser'):
try:
c.callback(params)
except Exception:
@ -233,7 +336,7 @@ def cfg_denoiser_callback(params: CFGDenoiserParams):
def cfg_denoised_callback(params: CFGDenoisedParams):
for c in callback_map['callbacks_cfg_denoised']:
for c in ordered_callbacks('cfg_denoised'):
try:
c.callback(params)
except Exception:
@ -241,7 +344,7 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
for c in callback_map['callbacks_cfg_after_cfg']:
for c in ordered_callbacks('cfg_after_cfg'):
try:
c.callback(params)
except Exception:
@ -249,7 +352,7 @@ def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']:
for c in ordered_callbacks('before_component'):
try:
c.callback(component, **kwargs)
except Exception:
@ -257,7 +360,7 @@ def before_component_callback(component, **kwargs):
def after_component_callback(component, **kwargs):
for c in callback_map['callbacks_after_component']:
for c in ordered_callbacks('after_component'):
try:
c.callback(component, **kwargs)
except Exception:
@ -265,7 +368,7 @@ def after_component_callback(component, **kwargs):
def image_grid_callback(params: ImageGridLoopParams):
for c in callback_map['callbacks_image_grid']:
for c in ordered_callbacks('image_grid'):
try:
c.callback(params)
except Exception:
@ -273,7 +376,7 @@ def image_grid_callback(params: ImageGridLoopParams):
def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
for c in callback_map['callbacks_infotext_pasted']:
for c in ordered_callbacks('infotext_pasted'):
try:
c.callback(infotext, params)
except Exception:
@ -281,7 +384,7 @@ def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
def script_unloaded_callback():
for c in reversed(callback_map['callbacks_script_unloaded']):
for c in reversed(ordered_callbacks('script_unloaded')):
try:
c.callback()
except Exception:
@ -289,7 +392,7 @@ def script_unloaded_callback():
def before_ui_callback():
for c in reversed(callback_map['callbacks_before_ui']):
for c in reversed(ordered_callbacks('before_ui')):
try:
c.callback()
except Exception:
@ -299,7 +402,7 @@ def before_ui_callback():
def list_optimizers_callback():
res = []
for c in callback_map['callbacks_list_optimizers']:
for c in ordered_callbacks('list_optimizers'):
try:
c.callback(res)
except Exception:
@ -311,7 +414,7 @@ def list_optimizers_callback():
def list_unets_callback():
res = []
for c in callback_map['callbacks_list_unets']:
for c in ordered_callbacks('list_unets'):
try:
c.callback(res)
except Exception:
@ -321,20 +424,13 @@ def list_unets_callback():
def before_token_counter_callback(params: BeforeTokenCounterParams):
for c in callback_map['callbacks_before_token_counter']:
for c in ordered_callbacks('before_token_counter'):
try:
c.callback(params)
except Exception:
report_exception(c, 'before_token_counter')
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if stack else 'unknown file'
callbacks.append(ScriptCallback(filename, fun))
def remove_current_script_callbacks():
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if stack else 'unknown file'
@ -351,24 +447,24 @@ def remove_callbacks_for_function(callback_func):
callback_list.remove(callback_to_remove)
def on_app_started(callback):
def on_app_started(callback, *, name=None):
"""register a function to be called when the webui started, the gradio `Block` component and
fastapi `FastAPI` object are passed as the arguments"""
add_callback(callback_map['callbacks_app_started'], callback)
add_callback(callback_map['callbacks_app_started'], callback, name=name, category='app_started')
def on_before_reload(callback):
def on_before_reload(callback, *, name=None):
"""register a function to be called just before the server reloads."""
add_callback(callback_map['callbacks_on_reload'], callback)
add_callback(callback_map['callbacks_on_reload'], callback, name=name, category='on_reload')
def on_model_loaded(callback):
def on_model_loaded(callback, *, name=None):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument; this function is also called when the script is reloaded. """
add_callback(callback_map['callbacks_model_loaded'], callback)
add_callback(callback_map['callbacks_model_loaded'], callback, name=name, category='model_loaded')
def on_ui_tabs(callback):
def on_ui_tabs(callback, *, name=None):
"""register a function to be called when the UI is creating new tabs.
The function must either return a None, which means no new tabs to be added, or a list, where
each element is a tuple:
@ -378,71 +474,71 @@ def on_ui_tabs(callback):
title is tab text displayed to user in the UI
elem_id is HTML id for the tab
"""
add_callback(callback_map['callbacks_ui_tabs'], callback)
add_callback(callback_map['callbacks_ui_tabs'], callback, name=name, category='ui_tabs')
def on_ui_train_tabs(callback):
def on_ui_train_tabs(callback, *, name=None):
"""register a function to be called when the UI is creating new tabs for the train tab.
Create your new tabs with gr.Tab.
"""
add_callback(callback_map['callbacks_ui_train_tabs'], callback)
add_callback(callback_map['callbacks_ui_train_tabs'], callback, name=name, category='ui_train_tabs')
def on_ui_settings(callback):
def on_ui_settings(callback, *, name=None):
"""register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
add_callback(callback_map['callbacks_ui_settings'], callback)
add_callback(callback_map['callbacks_ui_settings'], callback, name=name, category='ui_settings')
def on_before_image_saved(callback):
def on_before_image_saved(callback, *, name=None):
"""register a function to be called before an image is saved to a file.
The callback is called with one argument:
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
"""
add_callback(callback_map['callbacks_before_image_saved'], callback)
add_callback(callback_map['callbacks_before_image_saved'], callback, name=name, category='before_image_saved')
def on_image_saved(callback):
def on_image_saved(callback, *, name=None):
"""register a function to be called after an image is saved to a file.
The callback is called with one argument:
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
"""
add_callback(callback_map['callbacks_image_saved'], callback)
add_callback(callback_map['callbacks_image_saved'], callback, name=name, category='image_saved')
def on_extra_noise(callback):
def on_extra_noise(callback, *, name=None):
"""register a function to be called before adding extra noise in img2img or hires fix;
The callback is called with one argument:
- params: ExtraNoiseParams - contains noise determined by seed and latent representation of image
"""
add_callback(callback_map['callbacks_extra_noise'], callback)
add_callback(callback_map['callbacks_extra_noise'], callback, name=name, category='extra_noise')
def on_cfg_denoiser(callback):
def on_cfg_denoiser(callback, *, name=None):
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
The callback is called with one argument:
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
"""
add_callback(callback_map['callbacks_cfg_denoiser'], callback)
add_callback(callback_map['callbacks_cfg_denoiser'], callback, name=name, category='cfg_denoiser')
def on_cfg_denoised(callback):
def on_cfg_denoised(callback, *, name=None):
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
The callback is called with one argument:
- params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
"""
add_callback(callback_map['callbacks_cfg_denoised'], callback)
add_callback(callback_map['callbacks_cfg_denoised'], callback, name=name, category='cfg_denoised')
def on_cfg_after_cfg(callback):
def on_cfg_after_cfg(callback, *, name=None):
"""register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
The callback is called with one argument:
- params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
"""
add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
add_callback(callback_map['callbacks_cfg_after_cfg'], callback, name=name, category='cfg_after_cfg')
def on_before_component(callback):
def on_before_component(callback, *, name=None):
"""register a function to be called before a component is created.
The callback is called with arguments:
- component - gradio component that is about to be created.
@ -451,61 +547,61 @@ def on_before_component(callback):
Use elem_id/label fields of kwargs to figure out which component it is.
This can be useful to inject your own components somewhere in the middle of vanilla UI.
"""
add_callback(callback_map['callbacks_before_component'], callback)
add_callback(callback_map['callbacks_before_component'], callback, name=name, category='before_component')
def on_after_component(callback):
def on_after_component(callback, *, name=None):
"""register a function to be called after a component is created. See on_before_component for more."""
add_callback(callback_map['callbacks_after_component'], callback)
add_callback(callback_map['callbacks_after_component'], callback, name=name, category='after_component')
def on_image_grid(callback):
def on_image_grid(callback, *, name=None):
"""register a function to be called before making an image grid.
The callback is called with one argument:
- params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
"""
add_callback(callback_map['callbacks_image_grid'], callback)
add_callback(callback_map['callbacks_image_grid'], callback, name=name, category='image_grid')
def on_infotext_pasted(callback):
def on_infotext_pasted(callback, *, name=None):
"""register a function to be called before applying an infotext.
The callback is called with two arguments:
- infotext: str - raw infotext.
- result: dict[str, any] - parsed infotext parameters.
"""
add_callback(callback_map['callbacks_infotext_pasted'], callback)
add_callback(callback_map['callbacks_infotext_pasted'], callback, name=name, category='infotext_pasted')
def on_script_unloaded(callback):
def on_script_unloaded(callback, *, name=None):
"""register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
the script did should be reverted here"""
add_callback(callback_map['callbacks_script_unloaded'], callback)
add_callback(callback_map['callbacks_script_unloaded'], callback, name=name, category='script_unloaded')
def on_before_ui(callback):
def on_before_ui(callback, *, name=None):
"""register a function to be called before the UI is created."""
add_callback(callback_map['callbacks_before_ui'], callback)
add_callback(callback_map['callbacks_before_ui'], callback, name=name, category='before_ui')
def on_list_optimizers(callback):
def on_list_optimizers(callback, *, name=None):
"""register a function to be called when UI is making a list of cross attention optimization options.
The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization
to it."""
add_callback(callback_map['callbacks_list_optimizers'], callback)
add_callback(callback_map['callbacks_list_optimizers'], callback, name=name, category='list_optimizers')
def on_list_unets(callback):
def on_list_unets(callback, *, name=None):
"""register a function to be called when UI is making a list of alternative options for unet.
The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
add_callback(callback_map['callbacks_list_unets'], callback)
add_callback(callback_map['callbacks_list_unets'], callback, name=name, category='list_unets')
def on_before_token_counter(callback):
def on_before_token_counter(callback, *, name=None):
"""register a function to be called when UI is counting tokens for a prompt.
The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary."""
add_callback(callback_map['callbacks_before_token_counter'], callback)
add_callback(callback_map['callbacks_before_token_counter'], callback, name=name, category='before_token_counter')

View file

@ -7,7 +7,9 @@ from dataclasses import dataclass
import gradio as gr
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer, util
topological_sort = util.topological_sort
AlwaysVisible = object()
@ -138,7 +140,6 @@ class Script:
"""
pass
def before_process(self, p, *args):
"""
This function is called very early during processing begins for AlwaysVisible scripts.
@ -351,6 +352,9 @@ class ScriptBuiltinUI(Script):
return f'{tabname}{item_id}'
def show(self, is_img2img):
return AlwaysVisible
current_basedir = paths.script_path
@ -369,29 +373,6 @@ scripts_data = []
postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
def topological_sort(dependencies):
"""Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
Ignores errors relating to missing dependeencies or circular dependencies
"""
visited = {}
result = []
def inner(name):
visited[name] = True
for dep in dependencies.get(name, []):
if dep in dependencies and dep not in visited:
inner(dep)
result.append(name)
for depname in dependencies:
if depname not in visited:
inner(depname)
return result
@dataclass
class ScriptWithDependencies:
@ -562,6 +543,25 @@ class ScriptRunner:
self.paste_field_names = []
self.inputs = [None]
self.callback_map = {}
self.callback_names = [
'before_process',
'process',
'before_process_batch',
'after_extra_networks_activate',
'process_batch',
'postprocess',
'postprocess_batch',
'postprocess_batch_list',
'post_sample',
'on_mask_blend',
'postprocess_image',
'postprocess_maskoverlay',
'postprocess_image_after_composite',
'before_component',
'after_component',
]
self.on_before_component_elem_id = {}
"""dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""
@ -600,6 +600,8 @@ class ScriptRunner:
self.scripts.append(script)
self.selectable_scripts.append(script)
self.callback_map.clear()
self.apply_on_before_component_callbacks()
def apply_on_before_component_callbacks(self):
@ -769,8 +771,42 @@ class ScriptRunner:
return processed
def list_scripts_for_method(self, method_name):
if method_name in ('before_component', 'after_component'):
return self.scripts
else:
return self.alwayson_scripts
def create_ordered_callbacks_list(self, method_name, *, enable_user_sort=True):
script_list = self.list_scripts_for_method(method_name)
category = f'script_{method_name}'
callbacks = []
for script in script_list:
if getattr(script.__class__, method_name, None) == getattr(Script, method_name, None):
continue
script_callbacks.add_callback(callbacks, script, category=category, name=script.__class__.__name__, filename=script.filename)
return script_callbacks.sort_callbacks(category, callbacks, enable_user_sort=enable_user_sort)
def ordered_callbacks(self, method_name, *, enable_user_sort=True):
script_list = self.list_scripts_for_method(method_name)
category = f'script_{method_name}'
scrpts_len, callbacks = self.callback_map.get(category, (-1, None))
if callbacks is None or scrpts_len != len(script_list):
callbacks = self.create_ordered_callbacks_list(method_name, enable_user_sort=enable_user_sort)
self.callback_map[category] = len(script_list), callbacks
return callbacks
def ordered_scripts(self, method_name):
return [x.callback for x in self.ordered_callbacks(method_name)]
def before_process(self, p):
for script in self.alwayson_scripts:
for script in self.ordered_scripts('before_process'):
try:
script_args = p.script_args[script.args_from:script.args_to]
script.before_process(p, *script_args)
@ -778,7 +814,7 @@ class ScriptRunner:
errors.report(f"Error running before_process: {script.filename}", exc_info=True)
def process(self, p):
for script in self.alwayson_scripts:
for script in self.ordered_scripts('process'):
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args)
@ -786,7 +822,7 @@ class ScriptRunner:
errors.report(f"Error running process: {script.filename}", exc_info=True)
def before_process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
for script in self.ordered_scripts('before_process_batch'):
try:
script_args = p.script_args[script.args_from:script.args_to]
script.before_process_batch(p, *script_args, **kwargs)
@ -794,7 +830,7 @@ class ScriptRunner:
errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
def after_extra_networks_activate(self, p, **kwargs):
for script in self.alwayson_scripts:
for script in self.ordered_scripts('after_extra_networks_activate'):
try:
script_args = p.script_args[script.args_from:script.args_to]
script.after_extra_networks_activate(p, *script_args, **kwargs)
@ -802,7 +838,7 @@ class ScriptRunner:
errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
for script in self.ordered_scripts('process_batch'):
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process_batch(p, *script_args, **kwargs)
@ -810,7 +846,7 @@ class ScriptRunner:
errors.report(f"Error running process_batch: {script.filename}", exc_info=True)
def postprocess(self, p, processed):
for script in self.alwayson_scripts:
for script in self.ordered_scripts('postprocess'):
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess(p, processed, *script_args)
@ -818,7 +854,7 @@ class ScriptRunner:
errors.report(f"Error running postprocess: {script.filename}", exc_info=True)
def postprocess_batch(self, p, images, **kwargs):
for script in self.alwayson_scripts:
for script in self.ordered_scripts('postprocess_batch'):
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch(p, *script_args, images=images, **kwargs)
@ -826,7 +862,7 @@ class ScriptRunner:
errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
for script in self.alwayson_scripts:
for script in self.ordered_scripts('postprocess_batch_list'):
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch_list(p, pp, *script_args, **kwargs)
@ -834,7 +870,7 @@ class ScriptRunner:
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
def post_sample(self, p, ps: PostSampleArgs):
for script in self.alwayson_scripts:
for script in self.ordered_scripts('post_sample'):
try:
script_args = p.script_args[script.args_from:script.args_to]
script.post_sample(p, ps, *script_args)
@ -842,7 +878,7 @@ class ScriptRunner:
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
def on_mask_blend(self, p, mba: MaskBlendArgs):
for script in self.alwayson_scripts:
for script in self.ordered_scripts('on_mask_blend'):
try:
script_args = p.script_args[script.args_from:script.args_to]
script.on_mask_blend(p, mba, *script_args)
@ -850,7 +886,7 @@ class ScriptRunner:
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
for script in self.ordered_scripts('postprocess_image'):
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image(p, pp, *script_args)
@ -858,7 +894,7 @@ class ScriptRunner:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
for script in self.alwayson_scripts:
for script in self.ordered_scripts('postprocess_maskoverlay'):
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_maskoverlay(p, ppmo, *script_args)
@ -866,7 +902,7 @@ class ScriptRunner:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
for script in self.ordered_scripts('postprocess_image_after_composite'):
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image_after_composite(p, pp, *script_args)
@ -880,7 +916,7 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running on_before_component: {script.filename}", exc_info=True)
for script in self.scripts:
for script in self.ordered_scripts('before_component'):
try:
script.before_component(component, **kwargs)
except Exception:
@ -893,7 +929,7 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running on_after_component: {script.filename}", exc_info=True)
for script in self.scripts:
for script in self.ordered_scripts('after_component'):
try:
script.after_component(component, **kwargs)
except Exception:
@ -921,7 +957,7 @@ class ScriptRunner:
self.scripts[si].args_to = args_to
def before_hr(self, p):
for script in self.alwayson_scripts:
for script in self.ordered_scripts('before_hr'):
try:
script_args = p.script_args[script.args_from:script.args_to]
script.before_hr(p, *script_args)
@ -929,7 +965,7 @@ class ScriptRunner:
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
def setup_scrips(self, p, *, is_ui=True):
for script in self.alwayson_scripts:
for script in self.ordered_scripts('setup'):
if not is_ui and script.setup_for_ui_only:
continue

View file

@ -1,7 +1,12 @@
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared
from __future__ import annotations
import functools
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common, sd_schedulers
# imports for functions that previously were here and are used by other modules
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
samples_to_image_grid = sd_samplers_common.samples_to_image_grid
sample_to_image = sd_samplers_common.sample_to_image
all_samplers = [
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
@ -10,8 +15,8 @@ all_samplers = [
]
all_samplers_map = {x.name: x for x in all_samplers}
samplers = []
samplers_for_img2img = []
samplers: list[sd_samplers_common.SamplerData] = []
samplers_for_img2img: list[sd_samplers_common.SamplerData] = []
samplers_map = {}
samplers_hidden = {}
@ -57,4 +62,64 @@ def visible_sampler_names():
return [x.name for x in samplers if x.name not in samplers_hidden]
def visible_samplers():
return [x for x in samplers if x.name not in samplers_hidden]
def get_sampler_from_infotext(d: dict):
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0]
def get_scheduler_from_infotext(d: dict):
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1]
def get_hr_sampler_and_scheduler(d: dict):
hr_sampler = d.get("Hires sampler", "Use same sampler")
sampler = d.get("Sampler") if hr_sampler == "Use same sampler" else hr_sampler
hr_scheduler = d.get("Hires schedule type", "Use same scheduler")
scheduler = d.get("Schedule type") if hr_scheduler == "Use same scheduler" else hr_scheduler
sampler, scheduler = get_sampler_and_scheduler(sampler, scheduler)
sampler = sampler if sampler != d.get("Sampler") else "Use same sampler"
scheduler = scheduler if scheduler != d.get("Schedule type") else "Use same scheduler"
return sampler, scheduler
def get_hr_sampler_from_infotext(d: dict):
return get_hr_sampler_and_scheduler(d)[0]
def get_hr_scheduler_from_infotext(d: dict):
return get_hr_sampler_and_scheduler(d)[1]
@functools.cache
def get_sampler_and_scheduler(sampler_name, scheduler_name):
default_sampler = samplers[0]
found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])
name = sampler_name or default_sampler.name
for scheduler in sd_schedulers.schedulers:
name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])]
for name_option in name_options:
if name.endswith(" " + name_option):
found_scheduler = scheduler
name = name[0:-(len(name_option) + 1)]
break
sampler = all_samplers_map.get(name, default_sampler)
# revert back to Automatic if it's the default scheduler for the selected sampler
if sampler.options.get('scheduler', None) == found_scheduler.name:
found_scheduler = sd_schedulers.schedulers[0]
return sampler.name, found_scheduler.label
set_samplers()

View file

@ -1,7 +1,7 @@
import torch
import inspect
import k_diffusion.sampling
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
@ -9,32 +9,20 @@ from modules.shared import opts
import modules.shared as shared
samplers_k_diffusion = [
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {'scheduler': 'karras'}),
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde'], {'scheduler': 'exponential', "brownian_noise": True}),
('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True, "second_order": True}),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {"brownian_noise": True, "solver_type": "heun"}),
('DPM++ 2M SDE Heun Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_ka'], {'scheduler': 'karras', "brownian_noise": True, "solver_type": "heun"}),
('DPM++ 2M SDE Heun Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_exp'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "brownian_noise": True}),
('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "second_order": True}),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras', "second_order": True}),
]
@ -58,12 +46,7 @@ sampler_extra_params = {
}
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
k_diffusion_scheduler = {
'Automatic': None,
'karras': k_diffusion.sampling.get_sigmas_karras,
'exponential': k_diffusion.sampling.get_sigmas_exponential,
'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential
}
k_diffusion_scheduler = {x.name: x.function for x in sd_schedulers.schedulers}
class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
@ -96,42 +79,43 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
steps += 1 if discard_next_to_last_sigma else 0
scheduler_name = (p.hr_scheduler if p.is_hr_pass else p.scheduler) or 'Automatic'
if scheduler_name == 'Automatic':
scheduler_name = self.config.options.get('scheduler', None)
scheduler = sd_schedulers.schedulers_map.get(scheduler_name)
m_sigma_min, m_sigma_max = self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif opts.k_sched_type != "Automatic":
m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
sigmas_kwargs = {
'sigma_min': sigma_min,
'sigma_max': sigma_max,
}
elif scheduler is None or scheduler.function is None:
sigmas = self.model_wrap.get_sigmas(steps)
else:
sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}
sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
p.extra_generation_params["Schedule type"] = opts.k_sched_type
if scheduler.label != 'Automatic' and not p.is_hr_pass:
p.extra_generation_params["Schedule type"] = scheduler.label
elif scheduler.label != p.extra_generation_params.get("Schedule type"):
p.extra_generation_params["Hires schedule type"] = scheduler.label
if opts.sigma_min != m_sigma_min and opts.sigma_min != 0:
if opts.sigma_min != 0 and opts.sigma_min != m_sigma_min:
sigmas_kwargs['sigma_min'] = opts.sigma_min
p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
if opts.sigma_max != m_sigma_max and opts.sigma_max != 0:
if opts.sigma_max != 0 and opts.sigma_max != m_sigma_max:
sigmas_kwargs['sigma_max'] = opts.sigma_max
p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
default_rho = 1. if opts.k_sched_type == "polyexponential" else 7.
if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
if scheduler.default_rho != -1 and opts.rho != 0 and opts.rho != scheduler.default_rho:
sigmas_kwargs['rho'] = opts.rho
p.extra_generation_params["Schedule rho"] = opts.rho
sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
if scheduler.need_inner_model:
sigmas_kwargs['inner_model'] = self.model_wrap
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential':
m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device)
else:
sigmas = self.model_wrap.get_sigmas(steps)
sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=shared.device)
if discard_next_to_last_sigma:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

43
modules/sd_schedulers.py Normal file
View file

@ -0,0 +1,43 @@
import dataclasses
import torch
import k_diffusion
@dataclasses.dataclass
class Scheduler:
name: str
label: str
function: any
default_rho: float = -1
need_inner_model: bool = False
aliases: list = None
def uniform(n, sigma_min, sigma_max, inner_model, device):
return inner_model.get_sigmas(n)
def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
start = inner_model.sigma_to_t(torch.tensor(sigma_max))
end = inner_model.sigma_to_t(torch.tensor(sigma_min))
sigs = [
inner_model.t_to_sigma(ts)
for ts in torch.linspace(start, end, n + 1)[:-1]
]
sigs += [0.0]
return torch.FloatTensor(sigs).to(device)
schedulers = [
Scheduler('automatic', 'Automatic', None),
Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
Scheduler('karras', 'Karras', k_diffusion.sampling.get_sigmas_karras, default_rho=7.0),
Scheduler('exponential', 'Exponential', k_diffusion.sampling.get_sigmas_exponential),
Scheduler('polyexponential', 'Polyexponential', k_diffusion.sampling.get_sigmas_polyexponential, default_rho=1.0),
Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]),
]
schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}

View file

@ -6,6 +6,10 @@ import gradio as gr
from modules import shared_cmd_options, shared_gradio_themes, options, shared_items, sd_models_types
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
from modules import util
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from modules import shared_state, styles, interrogate, shared_total_tqdm, memmon
cmd_opts = shared_cmd_options.cmd_opts
parser = shared_cmd_options.parser
@ -16,11 +20,11 @@ styles_filename = cmd_opts.styles_file = cmd_opts.styles_file if len(cmd_opts.st
config_filename = cmd_opts.ui_settings_file
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
demo = None
demo: gr.Blocks = None
device = None
device: str = None
weight_load_location = None
weight_load_location: str = None
xformers_available = False
@ -28,21 +32,21 @@ hypernetworks = {}
loaded_hypernetworks = []
state = None
state: 'shared_state.State' = None
prompt_styles = None
prompt_styles: 'styles.StyleDatabase' = None
interrogator = None
interrogator: 'interrogate.InterrogateModels' = None
face_restorers = []
options_templates = None
opts = None
restricted_opts = None
options_templates: dict = None
opts: options.Options = None
restricted_opts: set[str] = None
sd_model: sd_models_types.WebuiSdModel = None
settings_components = None
settings_components: dict = None
"""assigned from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
tab_names = []
@ -65,9 +69,9 @@ progress_print_out = sys.stdout
gradio_theme = gr.themes.Base()
total_tqdm = None
total_tqdm: 'shared_total_tqdm.TotalTQDM' = None
mem_mon = None
mem_mon: 'memmon.MemUsageMonitor' = None
options_section = options.options_section
OptionInfo = options.OptionInfo

View file

@ -1,5 +1,8 @@
import html
import sys
from modules import script_callbacks, scripts, ui_components
from modules.options import OptionHTML, OptionInfo
from modules.shared_cmd_options import cmd_opts
@ -118,6 +121,45 @@ def ui_reorder_categories():
yield "scripts"
def callbacks_order_settings():
options = {
"sd_vae_explanation": OptionHTML("""
For categories below, callbacks added to dropdowns happen before others, in order listed.
"""),
}
callback_options = {}
for category, _ in script_callbacks.enumerate_callbacks():
callback_options[category] = script_callbacks.ordered_callbacks(category, enable_user_sort=False)
for method_name in scripts.scripts_txt2img.callback_names:
callback_options["script_" + method_name] = scripts.scripts_txt2img.create_ordered_callbacks_list(method_name, enable_user_sort=False)
for method_name in scripts.scripts_img2img.callback_names:
callbacks = callback_options.get("script_" + method_name, [])
for addition in scripts.scripts_img2img.create_ordered_callbacks_list(method_name, enable_user_sort=False):
if any(x.name == addition.name for x in callbacks):
continue
callbacks.append(addition)
callback_options["script_" + method_name] = callbacks
for category, callbacks in callback_options.items():
if not callbacks:
continue
option_info = OptionInfo([], f"{category} callback priority", ui_components.DropdownMulti, {"choices": [x.name for x in callbacks]})
option_info.needs_restart()
option_info.html("<div class='info'>Default order: <ol>" + "".join(f"<li>{html.escape(x.name)}</li>\n" for x in callbacks) + "</ol></div>")
options['prioritized_callbacks_' + category] = option_info
return options
class Shared(sys.modules[__name__].__class__):
"""
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than

View file

@ -101,6 +101,7 @@ options_templates.update(options_section(('upscaling', "Upscaling", "postprocess
"DAT_tile": OptionInfo(192, "Tile size for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
"DAT_tile_overlap": OptionInfo(8, "Tile overlap for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
"set_scale_by_when_changing_upscaler": OptionInfo(False, "Automatically set the Scale by factor based on the name of the selected Upscaler."),
}))
options_templates.update(options_section(('face-restoration', "Face restoration", "postprocessing"), {
@ -315,6 +316,8 @@ options_templates.update(options_section(('ui', "User interface", "ui"), {
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
"enable_reloading_ui_scripts": OptionInfo(False, "Reload UI scripts when using Reload UI option").info("useful for developing: if you make changes to UI scripts code, it is applied when the UI is reloded."),
}))
@ -366,7 +369,6 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule min sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule max sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),

View file

@ -1,3 +1,4 @@
from __future__ import annotations
from pathlib import Path
from modules import errors
import csv

View file

@ -11,7 +11,7 @@ from PIL import Image
import gradio as gr
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_scheduler: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
override_settings = create_override_settings_dict(override_settings_texts)
if force_enable_hr:
@ -24,10 +24,8 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
prompt=prompt,
styles=prompt_styles,
negative_prompt=negative_prompt,
sampler_name=sampler_name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
cfg_scale=cfg_scale,
width=width,
height=height,
@ -40,6 +38,7 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
hr_resize_y=hr_resize_y,
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
hr_scheduler=None if hr_scheduler == 'Use same scheduler' else hr_scheduler,
hr_prompt=hr_prompt,
hr_negative_prompt=hr_negative_prompt,
override_settings=override_settings,

View file

@ -12,7 +12,7 @@ import numpy as np
from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import gradio_extensons # noqa: F401
from modules import gradio_extensons, sd_schedulers # noqa: F401
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
from modules.paths import script_path
@ -229,19 +229,6 @@ def create_output_panel(tabname, outdir, toprow=None):
return ui_common.create_output_panel(tabname, outdir, toprow)
def create_sampler_and_steps_selection(choices, tabname):
if opts.samplers_in_dropdown:
with FormRow(elem_id=f"sampler_selection_{tabname}"):
sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
else:
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
sampler_name = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
return steps, sampler_name
def ordered_ui_categories():
user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder_list)}
@ -295,9 +282,6 @@ def create_ui():
if category == "prompt":
toprow.create_inline_toprow_prompts()
if category == "sampler":
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
elif category == "dimensions":
with FormRow():
with gr.Column(elem_id="txt2img_column_size", scale=4):
@ -338,10 +322,11 @@ def create_ui():
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
hr_checkpoint_name = gr.Dropdown(label='Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
hr_sampler_name = gr.Dropdown(label='Sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
hr_scheduler = gr.Dropdown(label='Schedule type', elem_id="hr_scheduler", choices=["Use same scheduler"] + [x.label for x in sd_schedulers.schedulers], value="Use same scheduler")
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
with gr.Column(scale=80):
@ -396,8 +381,6 @@ def create_ui():
toprow.prompt,
toprow.negative_prompt,
toprow.ui_styles.dropdown,
steps,
sampler_name,
batch_count,
batch_size,
cfg_scale,
@ -412,6 +395,7 @@ def create_ui():
hr_resize_y,
hr_checkpoint_name,
hr_sampler_name,
hr_scheduler,
hr_prompt,
hr_negative_prompt,
override_settings,
@ -461,8 +445,6 @@ def create_ui():
txt2img_paste_fields = [
PasteField(toprow.prompt, "Prompt", api="prompt"),
PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"),
PasteField(steps, "Steps", api="steps"),
PasteField(sampler_name, "Sampler", api="sampler_name"),
PasteField(cfg_scale, "CFG scale", api="cfg_scale"),
PasteField(width, "Size-1", api="width"),
PasteField(height, "Size-2", api="height"),
@ -476,8 +458,9 @@ def create_ui():
PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"),
PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
PasteField(hr_sampler_name, sd_samplers.get_hr_sampler_from_infotext, api="hr_sampler_name"),
PasteField(hr_scheduler, sd_samplers.get_hr_scheduler_from_infotext, api="hr_scheduler"),
PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" or d.get("Hires schedule type", "Use same scheduler") != "Use same scheduler" else gr.update()),
PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
@ -488,11 +471,13 @@ def create_ui():
paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
))
steps = scripts.scripts_txt2img.script('Sampler').steps
txt2img_preview_params = [
toprow.prompt,
toprow.negative_prompt,
steps,
sampler_name,
scripts.scripts_txt2img.script('Sampler').sampler_name,
cfg_scale,
scripts.scripts_txt2img.script('Seed').seed,
width,
@ -623,9 +608,6 @@ def create_ui():
with FormRow():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
if category == "sampler":
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
elif category == "dimensions":
with FormRow():
with gr.Column(elem_id="img2img_column_size", scale=4):
@ -754,8 +736,6 @@ def create_ui():
inpaint_color_sketch_orig,
init_img_inpaint,
init_mask_inpaint,
steps,
sampler_name,
mask_blur,
mask_alpha,
inpainting_fill,
@ -840,6 +820,8 @@ def create_ui():
**interrogate_args,
)
steps = scripts.scripts_img2img.script('Sampler').steps
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
@ -848,8 +830,6 @@ def create_ui():
img2img_paste_fields = [
(toprow.prompt, "Prompt"),
(toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"),
(sampler_name, "Sampler"),
(cfg_scale, "CFG scale"),
(image_cfg_scale, "Image CFG scale"),
(width, "Size-1"),

View file

@ -1,6 +1,8 @@
import functools
import os.path
import urllib.parse
from base64 import b64decode
from io import BytesIO
from pathlib import Path
from typing import Optional, Union
from dataclasses import dataclass
@ -14,6 +16,7 @@ import gradio as gr
import json
import html
from fastapi.exceptions import HTTPException
from PIL import Image
from modules.infotext_utils import image_from_url_text
@ -114,6 +117,31 @@ def fetch_file(filename: str = ""):
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
def fetch_cover_images(page: str = "", item: str = "", index: int = 0):
from starlette.responses import Response
page = next(iter([x for x in extra_pages if x.name == page]), None)
if page is None:
raise HTTPException(status_code=404, detail="File not found")
metadata = page.metadata.get(item)
if metadata is None:
raise HTTPException(status_code=404, detail="File not found")
cover_images = json.loads(metadata.get('ssmd_cover_images', {}))
image = cover_images[index] if index < len(cover_images) else None
if not image:
raise HTTPException(status_code=404, detail="File not found")
try:
image = Image.open(BytesIO(b64decode(image)))
buffer = BytesIO()
image.save(buffer, format=image.format)
return Response(content=buffer.getvalue(), media_type=image.get_format_mimetype())
except Exception as err:
raise ValueError(f"File cannot be fetched: {item}. Failed to load cover image.") from err
def get_metadata(page: str = "", item: str = ""):
from starlette.responses import JSONResponse
@ -125,6 +153,8 @@ def get_metadata(page: str = "", item: str = ""):
if metadata is None:
return JSONResponse({})
metadata = {i:metadata[i] for i in metadata if i != 'ssmd_cover_images'} # those are cover images, and they are too big to display in UI as text
return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)})
@ -148,6 +178,7 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""):
def add_pages_to_demo(app):
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
app.add_api_route("/sd_extra_networks/cover-images", fetch_cover_images, methods=["GET"])
app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"])
@ -157,6 +188,7 @@ def quote_js(s):
s = s.replace('"', '\\"')
return f'"{s}"'
class ExtraNetworksPage:
def __init__(self, title):
self.title = title
@ -428,14 +460,12 @@ class ExtraNetworksPage:
btn_metadata = self.btn_metadata_tpl.format(
**{
"extra_networks_tabname": self.extra_networks_tabname,
"name": html.escape(item["name"]),
}
)
btn_edit_item = self.btn_edit_item_tpl.format(
**{
"tabname": tabname,
"extra_networks_tabname": self.extra_networks_tabname,
"name": html.escape(item["name"]),
}
)
@ -673,6 +703,17 @@ class ExtraNetworksPage:
return None
def find_embedded_preview(self, path, name, metadata):
"""
Find if embedded preview exists in safetensors metadata and return endpoint for it.
"""
file = f"{path}.safetensors"
if self.lister.exists(file) and 'ssmd_cover_images' in metadata and len(list(filter(None, json.loads(metadata['ssmd_cover_images'])))) > 0:
return f"./sd_extra_networks/cover-images?page={self.extra_networks_tabname}&item={name}"
return None
def find_description(self, path):
"""
Find and read a description file for a given path (without extension).

View file

@ -104,6 +104,8 @@ class UiLoadsave:
apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
if type(x) == InputAccordion:
if hasattr(x, 'custom_script_source'):
x.accordion.custom_script_source = x.custom_script_source
if x.accordion.visible:
apply_field(x.accordion, 'visible')
apply_field(x, 'value')

View file

@ -12,7 +12,7 @@ def create_ui():
with gr.Column(variant='compact'):
with gr.Tabs(elem_id="mode_extras"):
with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image", image_mode="RGBA")
with gr.TabItem('Batch Process', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")

View file

@ -1,7 +1,8 @@
import gradio as gr
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer, shared_items
from modules.call_queue import wrap_gradio_call
from modules.options import options_section
from modules.shared import opts
from modules.ui_components import FormRow
from modules.ui_gradio_extensions import reload_javascript
@ -108,6 +109,11 @@ class UiSettings:
shared.settings_components = self.component_dict
# we add this as late as possible so that scripts have already registered their callbacks
opts.data_labels.update(options_section(('callbacks', "Callbacks", "system"), {
**shared_items.callbacks_order_settings(),
}))
opts.reorder()
with gr.Blocks(analytics_enabled=False) as settings_interface:

View file

@ -20,7 +20,7 @@ class Upscaler:
filter = None
model = None
user_path = None
scalers: []
scalers: list
tile = True
def __init__(self, create_dirs=False):

View file

@ -148,8 +148,26 @@ class MassFileLister:
"""Clear the cache of all directories."""
self.cached_dirs.clear()
def update_file_entry(self, path):
"""Update the cache for a specific directory."""
dirname, filename = os.path.split(path)
if cached_dir := self.cached_dirs.get(dirname):
cached_dir.update_entry(filename)
def topological_sort(dependencies):
"""Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
Ignores errors relating to missing dependeencies or circular dependencies
"""
visited = {}
result = []
def inner(name):
visited[name] = True
for dep in dependencies.get(name, []):
if dep in dependencies and dep not in visited:
inner(dep)
result.append(name)
for depname in dependencies:
if depname not in visited:
inner(depname)
return result