mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2026-01-11 03:31:38 -08:00
Merge branch 'master' into master
This commit is contained in:
commit
9b2dcb04bc
44 changed files with 967 additions and 969 deletions
|
|
@ -6,8 +6,11 @@ import uvicorn
|
|||
from threading import Lock
|
||||
from io import BytesIO
|
||||
from gradio.processing_utils import decode_base64_to_file
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
|
||||
from fastapi import APIRouter, Depends, FastAPI, Request, Response
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from secrets import compare_digest
|
||||
|
||||
import modules.shared as shared
|
||||
|
|
@ -18,7 +21,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
|
|||
from modules.textual_inversion.preprocess import preprocess
|
||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||
from PIL import PngImagePlugin,Image
|
||||
from modules.sd_models import checkpoints_list
|
||||
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
|
||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from modules import devices
|
||||
|
|
@ -90,6 +93,16 @@ def encode_pil_to_base64(image):
|
|||
return base64.b64encode(bytes_data)
|
||||
|
||||
def api_middleware(app: FastAPI):
|
||||
rich_available = True
|
||||
try:
|
||||
import anyio # importing just so it can be placed on silent list
|
||||
import starlette # importing just so it can be placed on silent list
|
||||
from rich.console import Console
|
||||
console = Console()
|
||||
except:
|
||||
import traceback
|
||||
rich_available = False
|
||||
|
||||
@app.middleware("http")
|
||||
async def log_and_time(req: Request, call_next):
|
||||
ts = time.time()
|
||||
|
|
@ -110,6 +123,36 @@ def api_middleware(app: FastAPI):
|
|||
))
|
||||
return res
|
||||
|
||||
def handle_exception(request: Request, e: Exception):
|
||||
err = {
|
||||
"error": type(e).__name__,
|
||||
"detail": vars(e).get('detail', ''),
|
||||
"body": vars(e).get('body', ''),
|
||||
"errors": str(e),
|
||||
}
|
||||
print(f"API error: {request.method}: {request.url} {err}")
|
||||
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
|
||||
if rich_available:
|
||||
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
|
||||
else:
|
||||
traceback.print_exc()
|
||||
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
|
||||
|
||||
@app.middleware("http")
|
||||
async def exception_handling(request: Request, call_next):
|
||||
try:
|
||||
return await call_next(request)
|
||||
except Exception as e:
|
||||
return handle_exception(request, e)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def fastapi_exception_handler(request: Request, e: Exception):
|
||||
return handle_exception(request, e)
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, e: HTTPException):
|
||||
return handle_exception(request, e)
|
||||
|
||||
|
||||
class Api:
|
||||
def __init__(self, app: FastAPI, queue_lock: Lock):
|
||||
|
|
@ -150,6 +193,8 @@ class Api:
|
|||
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
|
||||
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
|
||||
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
|
||||
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
|
||||
|
||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||
|
|
@ -412,6 +457,16 @@ class Api:
|
|||
|
||||
return {}
|
||||
|
||||
def unloadapi(self):
|
||||
unload_model_weights()
|
||||
|
||||
return {}
|
||||
|
||||
def reloadapi(self):
|
||||
reload_model_weights()
|
||||
|
||||
return {}
|
||||
|
||||
def skip(self):
|
||||
shared.state.skip()
|
||||
|
||||
|
|
|
|||
102
modules/cmd_args.py
Normal file
102
modules/cmd_args.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import argparse
|
||||
import os
|
||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--update-all-extensions", action='store_true', help="launch.py argument: download updates for all extensions when starting the program")
|
||||
parser.add_argument("--skip-python-version-check", action='store_true', help="launch.py argument: do not check python version")
|
||||
parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
|
||||
parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
|
||||
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
|
||||
parser.add_argument("--update-check", action='store_true', help="launch.py argument: chck for updates at startup")
|
||||
parser.add_argument("--tests", type=str, default=None, help="launch.py argument: run tests in the specified directory")
|
||||
parser.add_argument("--no-tests", action='store_true', help="launch.py argument: do not run tests even if --tests option is specified")
|
||||
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
|
||||
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
|
||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
||||
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
||||
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
||||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
|
||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
||||
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
||||
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
|
||||
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
|
||||
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
||||
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
||||
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
||||
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
|
||||
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
|
||||
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
|
||||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
|
||||
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
|
||||
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
|
||||
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
|
||||
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
|
||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
|
||||
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
|
||||
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
|
||||
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
|
||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
||||
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
||||
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
|
||||
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
|
||||
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
|
||||
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
|
||||
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
|
||||
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
|
||||
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
|
||||
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
|
||||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
|
||||
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
|
||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||
|
|
@ -8,11 +8,9 @@ import git
|
|||
from modules import paths, shared
|
||||
|
||||
extensions = []
|
||||
extensions_dir = os.path.join(paths.data_path, "extensions")
|
||||
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
|
||||
|
||||
if not os.path.exists(extensions_dir):
|
||||
os.makedirs(extensions_dir)
|
||||
if not os.path.exists(paths.extensions_dir):
|
||||
os.makedirs(paths.extensions_dir)
|
||||
|
||||
def active():
|
||||
return [x for x in extensions if x.enabled]
|
||||
|
|
@ -86,11 +84,11 @@ class Extension:
|
|||
def list_extensions():
|
||||
extensions.clear()
|
||||
|
||||
if not os.path.isdir(extensions_dir):
|
||||
if not os.path.isdir(paths.extensions_dir):
|
||||
return
|
||||
|
||||
paths = []
|
||||
for dirname in [extensions_dir, extensions_builtin_dir]:
|
||||
extension_paths = []
|
||||
for dirname in [paths.extensions_dir, paths.extensions_builtin_dir]:
|
||||
if not os.path.isdir(dirname):
|
||||
return
|
||||
|
||||
|
|
@ -99,9 +97,9 @@ def list_extensions():
|
|||
if not os.path.isdir(path):
|
||||
continue
|
||||
|
||||
paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
|
||||
extension_paths.append((extension_dirname, path, dirname == paths.extensions_builtin_dir))
|
||||
|
||||
for dirname, path, is_builtin in paths:
|
||||
for dirname, path, is_builtin in extension_paths:
|
||||
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
|
||||
extensions.append(extension)
|
||||
|
||||
|
|
|
|||
|
|
@ -401,9 +401,14 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
|
|||
|
||||
button.click(
|
||||
fn=paste_func,
|
||||
_js=f"recalculate_prompts_{tabname}",
|
||||
inputs=[input_comp],
|
||||
outputs=[x[0] for x in paste_fields],
|
||||
)
|
||||
button.click(
|
||||
fn=None,
|
||||
_js=f"recalculate_prompts_{tabname}",
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -645,6 +645,8 @@ Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}
|
|||
|
||||
|
||||
def image_data(data):
|
||||
import gradio as gr
|
||||
|
||||
try:
|
||||
image = Image.open(io.BytesIO(data))
|
||||
textinfo, _ = read_info_from_image(image)
|
||||
|
|
@ -660,7 +662,7 @@ def image_data(data):
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
return '', None
|
||||
return gr.update(), None
|
||||
|
||||
|
||||
def flatten(img, bgcolor):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
import platform
|
||||
from modules import paths
|
||||
from modules.sd_hijack_utils import CondFunc
|
||||
from packaging import version
|
||||
|
|
@ -32,6 +33,10 @@ if has_mps:
|
|||
# MPS fix for randn in torchsde
|
||||
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
|
||||
|
||||
if platform.mac_ver()[0].startswith("13.2."):
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
||||
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
|
||||
|
||||
if version.parse(torch.__version__) < version.parse("1.13"):
|
||||
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
||||
|
||||
|
|
@ -49,4 +54,6 @@ if has_mps:
|
|||
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
||||
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
||||
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
||||
|
||||
if version.parse(torch.__version__) == version.parse("2.0"):
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
|
||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import shutil
|
|||
import importlib
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from modules import shared
|
||||
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
||||
from modules.paths import script_path, models_path
|
||||
|
|
@ -59,6 +58,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
|||
|
||||
if model_url is not None and len(output) == 0:
|
||||
if download_name is not None:
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
dl = load_file_from_url(model_url, model_path, True, download_name)
|
||||
output.append(dl)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,16 +1,9 @@
|
|||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir
|
||||
|
||||
import modules.safe
|
||||
|
||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
|
||||
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||
cmd_opts_pre = parser.parse_known_args()[0]
|
||||
data_path = cmd_opts_pre.data_dir
|
||||
models_path = os.path.join(data_path, "models")
|
||||
|
||||
# data_path = cmd_opts_pre.data
|
||||
sys.path.insert(0, script_path)
|
||||
|
|
|
|||
22
modules/paths_internal.py
Normal file
22
modules/paths_internal.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
"""this module defines internal paths used by program and is safe to import before dependencies are installed in launch.py"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
|
||||
sd_configs_path = os.path.join(script_path, "configs")
|
||||
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
|
||||
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
||||
parser_pre = argparse.ArgumentParser(add_help=False)
|
||||
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||
cmd_opts_pre = parser_pre.parse_known_args()[0]
|
||||
|
||||
data_path = cmd_opts_pre.data_dir
|
||||
|
||||
models_path = os.path.join(data_path, "models")
|
||||
extensions_dir = os.path.join(data_path, "extensions")
|
||||
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
||||
|
|
@ -689,6 +689,22 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
image.info["parameters"] = text
|
||||
output_images.append(image)
|
||||
|
||||
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
|
||||
image_mask = p.mask_for_overlay.convert('RGB')
|
||||
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), p.mask_for_overlay.convert('L')).convert('RGBA')
|
||||
|
||||
if opts.save_mask:
|
||||
images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
|
||||
|
||||
if opts.save_mask_composite:
|
||||
images.save_image(image_mask_composite, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
|
||||
|
||||
if opts.return_mask:
|
||||
output_images.append(image_mask)
|
||||
|
||||
if opts.return_mask_composite:
|
||||
output_images.append(image_mask_composite)
|
||||
|
||||
del x_samples_ddim
|
||||
|
||||
devices.torch_gc()
|
||||
|
|
|
|||
|
|
@ -239,7 +239,15 @@ def load_scripts():
|
|||
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
|
||||
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||
|
||||
for scriptfile in sorted(scripts_list):
|
||||
def orderby(basedir):
|
||||
# 1st webui, 2nd extensions-builtin, 3rd extensions
|
||||
priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
|
||||
for key in priority:
|
||||
if basedir.startswith(key):
|
||||
return priority[key]
|
||||
return 9999
|
||||
|
||||
for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
|
||||
try:
|
||||
if scriptfile.basedir != paths.script_path:
|
||||
sys.path = [scriptfile.basedir] + sys.path
|
||||
|
|
@ -513,6 +521,18 @@ def reload_scripts():
|
|||
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
|
||||
|
||||
|
||||
def add_classes_to_gradio_component(comp):
|
||||
"""
|
||||
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
|
||||
"""
|
||||
|
||||
comp.elem_classes = ["gradio-" + comp.get_block_name(), *(comp.elem_classes or [])]
|
||||
|
||||
if getattr(comp, 'multiselect', False):
|
||||
comp.elem_classes.append('multiselect')
|
||||
|
||||
|
||||
|
||||
def IOComponent_init(self, *args, **kwargs):
|
||||
if scripts_current is not None:
|
||||
scripts_current.before_component(self, **kwargs)
|
||||
|
|
@ -521,6 +541,8 @@ def IOComponent_init(self, *args, **kwargs):
|
|||
|
||||
res = original_IOComponent_init(self, *args, **kwargs)
|
||||
|
||||
add_classes_to_gradio_component(self)
|
||||
|
||||
script_callbacks.after_component_callback(self, **kwargs)
|
||||
|
||||
if scripts_current is not None:
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ class ScriptPostprocessingRunner:
|
|||
inputs = []
|
||||
|
||||
for script in self.scripts_in_preferred_order():
|
||||
with gr.Box() as group:
|
||||
with gr.Row() as group:
|
||||
self.create_script_ui(script, inputs)
|
||||
|
||||
script.group = group
|
||||
|
|
|
|||
|
|
@ -337,7 +337,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
|||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
q, k, v = q.float(), k.float(), v.float()
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
|
||||
|
||||
|
|
@ -372,7 +372,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
|
|||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
q, k, v = q.float(), k.float(), v.float()
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ def hijack_ddpm_edit():
|
|||
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||
if version.parse(torch.__version__) <= version.parse("1.13.1"):
|
||||
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
|
||||
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
||||
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
||||
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|
||||
|
|
|
|||
|
|
@ -178,7 +178,7 @@ def select_checkpoint():
|
|||
return checkpoint_info
|
||||
|
||||
|
||||
chckpoint_dict_replacements = {
|
||||
checkpoint_dict_replacements = {
|
||||
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
|
||||
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
|
||||
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
|
||||
|
|
@ -186,7 +186,7 @@ chckpoint_dict_replacements = {
|
|||
|
||||
|
||||
def transform_checkpoint_dict_key(k):
|
||||
for text, replacement in chckpoint_dict_replacements.items():
|
||||
for text, replacement in checkpoint_dict_replacements.items():
|
||||
if k.startswith(text):
|
||||
k = replacement + k[len(text):]
|
||||
|
||||
|
|
@ -494,7 +494,7 @@ def reload_model_weights(sd_model=None, info=None):
|
|||
if sd_model is None or checkpoint_config != sd_model.used_config:
|
||||
del sd_model
|
||||
checkpoints_loaded.clear()
|
||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
|
||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
||||
return shared.sd_model
|
||||
|
||||
try:
|
||||
|
|
@ -517,3 +517,23 @@ def reload_model_weights(sd_model=None, info=None):
|
|||
print(f"Weights loaded in {timer.summary()}.")
|
||||
|
||||
return sd_model
|
||||
|
||||
def unload_model_weights(sd_model=None, info=None):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
timer = Timer()
|
||||
|
||||
if shared.sd_model:
|
||||
|
||||
# shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
# shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
shared.sd_model.to(devices.cpu)
|
||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||
shared.sd_model = None
|
||||
sd_model = None
|
||||
gc.collect()
|
||||
devices.torch_gc()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
print(f"Unloaded weights {timer.summary()}.")
|
||||
|
||||
return sd_model
|
||||
|
|
@ -13,114 +13,22 @@ import modules.interrogate
|
|||
import modules.memmon
|
||||
import modules.styles
|
||||
import modules.devices as devices
|
||||
from modules import localization, extensions, script_loading, errors, ui_components, shared_items
|
||||
from modules.paths import models_path, script_path, data_path
|
||||
|
||||
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
|
||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
|
||||
|
||||
demo = None
|
||||
|
||||
sd_configs_path = os.path.join(script_path, "configs")
|
||||
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
parser = cmd_args.parser
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
|
||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
||||
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
||||
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
||||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
|
||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
||||
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
||||
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
|
||||
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
|
||||
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
||||
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
||||
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
||||
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
|
||||
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
|
||||
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
|
||||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
|
||||
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
|
||||
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
|
||||
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
|
||||
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
|
||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
|
||||
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
|
||||
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
|
||||
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
|
||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
||||
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
||||
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
|
||||
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
|
||||
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
|
||||
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
|
||||
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
|
||||
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
|
||||
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
|
||||
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
|
||||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||
parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
|
||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||
|
||||
|
||||
script_loading.preload_extensions(extensions.extensions_dir, parser)
|
||||
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
|
||||
script_loading.preload_extensions(extensions_dir, parser)
|
||||
script_loading.preload_extensions(extensions_builtin_dir, parser)
|
||||
|
||||
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
|
||||
cmd_opts = parser.parse_args()
|
||||
else:
|
||||
cmd_opts, _ = parser.parse_known_args()
|
||||
|
||||
|
||||
restricted_opts = {
|
||||
"samples_filename_pattern",
|
||||
"directories_filename_pattern",
|
||||
|
|
@ -332,6 +240,8 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
|||
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
||||
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||
"save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"),
|
||||
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
|
||||
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
|
||||
"export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
|
||||
|
|
@ -448,12 +358,16 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
|||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
|
||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
|
||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
|
||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||
|
|
|
|||
|
|
@ -152,7 +152,11 @@ class EmbeddingDatabase:
|
|||
name = data.get('name', name)
|
||||
else:
|
||||
data = extract_image_data_embed(embed_image)
|
||||
name = data.get('name', name)
|
||||
if data:
|
||||
name = data.get('name', name)
|
||||
else:
|
||||
# if data is None, means this is not an embeding, just a preview image
|
||||
return
|
||||
elif ext in ['.BIN', '.PT']:
|
||||
data = torch.load(path, map_location="cpu")
|
||||
elif ext in ['.SAFETENSORS']:
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from PIL import Image, PngImagePlugin
|
|||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||
|
||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
|
||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
||||
from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML
|
||||
from modules.paths import script_path, data_path
|
||||
|
||||
from modules.shared import opts, cmd_opts, restricted_opts
|
||||
|
|
@ -89,7 +89,7 @@ paste_symbol = '\u2199\ufe0f' # ↙
|
|||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
apply_style_symbol = '\U0001f4cb' # 📋
|
||||
clear_prompt_symbol = '\U0001F5D1' # 🗑️
|
||||
clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️
|
||||
extra_networks_symbol = '\U0001F3B4' # 🎴
|
||||
switch_values_symbol = '\U000021C5' # ⇅
|
||||
|
||||
|
|
@ -179,14 +179,13 @@ def interrogate_deepbooru(image):
|
|||
|
||||
|
||||
def create_seed_inputs(target_interface):
|
||||
with FormRow(elem_id=target_interface + '_seed_row'):
|
||||
with FormRow(elem_id=target_interface + '_seed_row', variant="compact"):
|
||||
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
|
||||
seed.style(container=False)
|
||||
random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
|
||||
reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
|
||||
random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed')
|
||||
reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed')
|
||||
|
||||
with gr.Group(elem_id=target_interface + '_subseed_show_box'):
|
||||
seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
|
||||
seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
|
||||
|
||||
# Components to show/hide based on the 'Extra' checkbox
|
||||
seed_extras = []
|
||||
|
|
@ -195,8 +194,8 @@ def create_seed_inputs(target_interface):
|
|||
seed_extras.append(seed_extra_row_1)
|
||||
subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
|
||||
subseed.style(container=False)
|
||||
random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
|
||||
reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
|
||||
random_subseed = ToolButton(random_symbol, elem_id=target_interface + '_random_subseed')
|
||||
reuse_subseed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
|
||||
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
|
||||
|
||||
with FormRow(visible=False) as seed_extra_row_2:
|
||||
|
|
@ -291,19 +290,19 @@ def create_toprow(is_img2img):
|
|||
with gr.Row():
|
||||
with gr.Column(scale=80):
|
||||
with gr.Row():
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
|
||||
|
||||
button_interrogate = None
|
||||
button_deepbooru = None
|
||||
if is_img2img:
|
||||
with gr.Column(scale=1, elem_id="interrogate_col"):
|
||||
with gr.Column(scale=1, elem_classes="interrogate-col"):
|
||||
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
|
||||
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
||||
|
||||
with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
|
||||
with gr.Row(elem_id=f"{id_part}_generate_box"):
|
||||
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
|
||||
skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
|
||||
with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
|
||||
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
|
||||
skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
|
||||
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
|
||||
|
||||
skip.click(
|
||||
|
|
@ -325,9 +324,9 @@ def create_toprow(is_img2img):
|
|||
prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
|
||||
save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
|
||||
|
||||
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
||||
token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
|
||||
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
||||
negative_token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_negative_token_counter")
|
||||
negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
|
||||
negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
|
||||
|
||||
clear_prompt_button.click(
|
||||
|
|
@ -479,7 +478,9 @@ def create_ui():
|
|||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
|
||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
|
||||
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
|
||||
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
|
||||
|
||||
if opts.dimensions_and_batch_together:
|
||||
with gr.Column(elem_id="txt2img_column_batch"):
|
||||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
|
||||
|
|
@ -492,7 +493,7 @@ def create_ui():
|
|||
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
|
||||
|
||||
elif category == "checkboxes":
|
||||
with FormRow(elem_id="txt2img_checkboxes", variant="compact"):
|
||||
with FormRow(elem_classes="checkboxes-row", variant="compact"):
|
||||
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
|
||||
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
|
||||
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
|
||||
|
|
@ -586,7 +587,7 @@ def create_ui():
|
|||
txt2img_prompt.submit(**txt2img_args)
|
||||
submit.click(**txt2img_args)
|
||||
|
||||
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
|
||||
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
|
||||
|
||||
txt_prompt_img.change(
|
||||
fn=modules.images.image_data,
|
||||
|
|
@ -757,7 +758,9 @@ def create_ui():
|
|||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
|
||||
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
||||
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
||||
|
||||
if opts.dimensions_and_batch_together:
|
||||
with gr.Column(elem_id="img2img_column_batch"):
|
||||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
|
||||
|
|
@ -774,7 +777,7 @@ def create_ui():
|
|||
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
|
||||
|
||||
elif category == "checkboxes":
|
||||
with FormRow(elem_id="img2img_checkboxes", variant="compact"):
|
||||
with FormRow(elem_classes="checkboxes-row", variant="compact"):
|
||||
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
|
||||
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
|
||||
|
||||
|
|
@ -904,7 +907,7 @@ def create_ui():
|
|||
|
||||
img2img_prompt.submit(**img2img_args)
|
||||
submit.click(**img2img_args)
|
||||
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
|
||||
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
|
||||
|
||||
img2img_interrogate.click(
|
||||
fn=lambda *args: process_interrogate(interrogate, *args),
|
||||
|
|
@ -1491,11 +1494,33 @@ def create_ui():
|
|||
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
||||
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
||||
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
||||
with gr.Row():
|
||||
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
|
||||
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
|
||||
|
||||
with gr.TabItem("Licenses"):
|
||||
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
|
||||
|
||||
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
||||
|
||||
|
||||
def unload_sd_weights():
|
||||
modules.sd_models.unload_model_weights()
|
||||
|
||||
def reload_sd_weights():
|
||||
modules.sd_models.reload_model_weights()
|
||||
|
||||
unload_sd_model.click(
|
||||
fn=unload_sd_weights,
|
||||
inputs=[],
|
||||
outputs=[]
|
||||
)
|
||||
|
||||
reload_sd_model.click(
|
||||
fn=reload_sd_weights,
|
||||
inputs=[],
|
||||
outputs=[]
|
||||
)
|
||||
|
||||
request_notifications.click(
|
||||
fn=lambda: None,
|
||||
|
|
@ -1598,11 +1623,13 @@ def create_ui():
|
|||
|
||||
for i, k, item in quicksettings_list:
|
||||
component = component_dict[k]
|
||||
info = opts.data_labels[k]
|
||||
|
||||
component.change(
|
||||
fn=lambda value, k=k: run_settings_single(value, key=k),
|
||||
inputs=[component],
|
||||
outputs=[component, text_settings],
|
||||
show_progress=info.refresh is not None,
|
||||
)
|
||||
|
||||
text_settings.change(
|
||||
|
|
|
|||
|
|
@ -129,8 +129,8 @@ Requested path was: {f}
|
|||
|
||||
generation_info = None
|
||||
with gr.Column():
|
||||
with gr.Row(elem_id=f"image_buttons_{tabname}"):
|
||||
open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
|
||||
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
|
||||
open_folder_button = gr.Button(folder_symbol, visible=not shared.cmd_opts.hide_ui_dir_config)
|
||||
|
||||
if tabname != "extras":
|
||||
save = gr.Button('Save', elem_id=f'save_{tabname}')
|
||||
|
|
@ -149,7 +149,7 @@ Requested path was: {f}
|
|||
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
|
||||
|
||||
with gr.Group():
|
||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
|
||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
|
||||
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
||||
|
||||
generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
|
||||
|
|
@ -160,6 +160,7 @@ Requested path was: {f}
|
|||
_js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
|
||||
inputs=[generation_info, html_info, html_info],
|
||||
outputs=[html_info, html_info],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
save.click(
|
||||
|
|
@ -195,7 +196,7 @@ Requested path was: {f}
|
|||
|
||||
else:
|
||||
html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
|
||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
|
||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
|
||||
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
||||
|
||||
paste_field_names = []
|
||||
|
|
|
|||
|
|
@ -1,55 +1,61 @@
|
|||
import gradio as gr
|
||||
|
||||
|
||||
class ToolButton(gr.Button, gr.components.FormComponent):
|
||||
class FormComponent:
|
||||
def get_expected_parent(self):
|
||||
return gr.components.Form
|
||||
|
||||
|
||||
gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent
|
||||
|
||||
|
||||
class ToolButton(FormComponent, gr.Button):
|
||||
"""Small button with single emoji as text, fits inside gradio forms"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(variant="tool", **kwargs)
|
||||
def __init__(self, *args, **kwargs):
|
||||
classes = kwargs.pop("elem_classes", [])
|
||||
super().__init__(*args, elem_classes=["tool", *classes], **kwargs)
|
||||
|
||||
def get_block_name(self):
|
||||
return "button"
|
||||
|
||||
|
||||
class ToolButtonTop(gr.Button, gr.components.FormComponent):
|
||||
"""Small button with single emoji as text, with extra margin at top, fits inside gradio forms"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(variant="tool-top", **kwargs)
|
||||
|
||||
def get_block_name(self):
|
||||
return "button"
|
||||
|
||||
|
||||
class FormRow(gr.Row, gr.components.FormComponent):
|
||||
class FormRow(FormComponent, gr.Row):
|
||||
"""Same as gr.Row but fits inside gradio forms"""
|
||||
|
||||
def get_block_name(self):
|
||||
return "row"
|
||||
|
||||
|
||||
class FormGroup(gr.Group, gr.components.FormComponent):
|
||||
class FormColumn(FormComponent, gr.Column):
|
||||
"""Same as gr.Column but fits inside gradio forms"""
|
||||
|
||||
def get_block_name(self):
|
||||
return "column"
|
||||
|
||||
|
||||
class FormGroup(FormComponent, gr.Group):
|
||||
"""Same as gr.Row but fits inside gradio forms"""
|
||||
|
||||
def get_block_name(self):
|
||||
return "group"
|
||||
|
||||
|
||||
class FormHTML(gr.HTML, gr.components.FormComponent):
|
||||
class FormHTML(FormComponent, gr.HTML):
|
||||
"""Same as gr.HTML but fits inside gradio forms"""
|
||||
|
||||
def get_block_name(self):
|
||||
return "html"
|
||||
|
||||
|
||||
class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
|
||||
class FormColorPicker(FormComponent, gr.ColorPicker):
|
||||
"""Same as gr.ColorPicker but fits inside gradio forms"""
|
||||
|
||||
def get_block_name(self):
|
||||
return "colorpicker"
|
||||
|
||||
|
||||
class DropdownMulti(gr.Dropdown):
|
||||
class DropdownMulti(FormComponent, gr.Dropdown):
|
||||
"""Same as gr.Dropdown but always multiselect"""
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(multiselect=True, **kwargs)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import json
|
||||
import os.path
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
|
|
@ -141,22 +140,20 @@ def install_extension_from_url(dirname, url):
|
|||
|
||||
try:
|
||||
shutil.rmtree(tmpdir, True)
|
||||
|
||||
repo = git.Repo.clone_from(url, tmpdir)
|
||||
repo.remote().fetch()
|
||||
|
||||
with git.Repo.clone_from(url, tmpdir) as repo:
|
||||
repo.remote().fetch()
|
||||
for submodule in repo.submodules:
|
||||
submodule.update()
|
||||
try:
|
||||
os.rename(tmpdir, target_dir)
|
||||
except OSError as err:
|
||||
# TODO what does this do on windows? I think it'll be a different error code but I don't have a system to check it
|
||||
# Shouldn't cause any new issues at least but we probably want to handle it there too.
|
||||
if err.errno == errno.EXDEV:
|
||||
# Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
|
||||
# Since we can't use a rename, do the slower but more versitile shutil.move()
|
||||
shutil.move(tmpdir, target_dir)
|
||||
else:
|
||||
# Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
|
||||
raise(err)
|
||||
raise err
|
||||
|
||||
import launch
|
||||
launch.run_extension_installer(target_dir)
|
||||
|
|
@ -255,7 +252,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
|
|||
hidden += 1
|
||||
continue
|
||||
|
||||
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
|
||||
install_code = f"""<button onclick="install_extension_from_index(this, '{html.escape(url)}')" {"disabled=disabled" if existing else ""} class="lg secondary gradio-button custom-button">{"Install" if not existing else "Installed"}</button>"""
|
||||
|
||||
tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
|
||||
|
||||
|
|
|
|||
|
|
@ -22,21 +22,37 @@ def register_page(page):
|
|||
allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
|
||||
|
||||
|
||||
def fetch_file(filename: str = ""):
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
||||
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
if ext not in (".png", ".jpg", ".webp"):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
|
||||
|
||||
# would profit from returning 304
|
||||
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
||||
|
||||
|
||||
def get_metadata(page: str = "", item: str = ""):
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
page = next(iter([x for x in extra_pages if x.name == page]), None)
|
||||
if page is None:
|
||||
return JSONResponse({})
|
||||
|
||||
metadata = page.metadata.get(item)
|
||||
if metadata is None:
|
||||
return JSONResponse({})
|
||||
|
||||
return JSONResponse({"metadata": metadata})
|
||||
|
||||
|
||||
def add_pages_to_demo(app):
|
||||
def fetch_file(filename: str = ""):
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
||||
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
if ext not in (".png", ".jpg", ".webp"):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
|
||||
|
||||
# would profit from returning 304
|
||||
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
||||
|
||||
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
|
||||
app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
|
||||
|
||||
|
||||
class ExtraNetworksPage:
|
||||
|
|
@ -45,6 +61,7 @@ class ExtraNetworksPage:
|
|||
self.name = title.lower()
|
||||
self.card_page = shared.html("extra-networks-card.html")
|
||||
self.allow_negative_prompt = False
|
||||
self.metadata = {}
|
||||
|
||||
def refresh(self):
|
||||
pass
|
||||
|
|
@ -66,6 +83,8 @@ class ExtraNetworksPage:
|
|||
view = shared.opts.extra_networks_default_view
|
||||
items_html = ''
|
||||
|
||||
self.metadata = {}
|
||||
|
||||
subdirs = {}
|
||||
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
|
||||
for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
|
||||
|
|
@ -86,12 +105,16 @@ class ExtraNetworksPage:
|
|||
subdirs = {"": 1, **subdirs}
|
||||
|
||||
subdirs_html = "".join([f"""
|
||||
<button class='gr-button gr-button-lg gr-button-secondary{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
|
||||
<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
|
||||
{html.escape(subdir if subdir!="" else "all")}
|
||||
</button>
|
||||
""" for subdir in subdirs])
|
||||
|
||||
for item in self.list_items():
|
||||
metadata = item.get("metadata")
|
||||
if metadata:
|
||||
self.metadata[item["name"]] = metadata
|
||||
|
||||
items_html += self.create_html_for_item(item, tabname)
|
||||
|
||||
if items_html == '':
|
||||
|
|
@ -124,14 +147,16 @@ class ExtraNetworksPage:
|
|||
if onclick is None:
|
||||
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
|
||||
|
||||
height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
|
||||
width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
|
||||
background_image = f"background-image: url(\"{html.escape(preview)}\");" if preview else ''
|
||||
metadata_button = ""
|
||||
metadata = item.get("metadata")
|
||||
if metadata:
|
||||
metadata_onclick = '"' + html.escape(f"""extraNetworksShowMetadata({json.dumps(metadata)}); return false;""") + '"'
|
||||
metadata_button = f"<div class='metadata-button' title='Show metadata' onclick={metadata_onclick}></div>"
|
||||
metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
|
||||
|
||||
args = {
|
||||
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
|
||||
"style": f"'{height}{width}{background_image}'",
|
||||
"prompt": item.get("prompt", None),
|
||||
"tabname": json.dumps(tabname),
|
||||
"local_preview": json.dumps(item["local_preview"]),
|
||||
|
|
@ -215,6 +240,7 @@ def create_ui(container, button, tabname):
|
|||
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
|
||||
for page in ui.stored_extra_pages:
|
||||
with gr.Tab(page.title):
|
||||
|
||||
page_elem = gr.HTML(page.create_html(ui.tabname))
|
||||
ui.pages.append(page_elem)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue