Merge branch 'master' into racecond_fix

This commit is contained in:
AUTOMATIC1111 2022-12-03 10:19:51 +03:00 committed by GitHub
commit c9a2cfdf2a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
93 changed files with 3589 additions and 8465 deletions

View file

@ -2,14 +2,22 @@ import base64
import io
import time
import uvicorn
from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, HTTPException
from threading import Lock
from io import BytesIO
from gradio.processing_utils import decode_base64_to_file
from fastapi import APIRouter, Depends, FastAPI, HTTPException
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
import modules.shared as shared
from modules import sd_samplers, deepbooru
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid
from modules.extras import run_extras, run_pnginfo
from PIL import PngImagePlugin,Image
from modules.sd_models import checkpoints_list
from modules.realesrgan_model import get_realesrgan_models
from typing import List
def upscaler_to_index(name: str):
try:
@ -18,8 +26,12 @@ def upscaler_to_index(name: str):
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
def validate_sampler_name(name):
config = sd_samplers.all_samplers_map.get(name, None)
if config is None:
raise HTTPException(status_code=404, detail="Sampler not found")
return name
def setUpscalers(req: dict):
reqDict = vars(req)
@ -29,39 +41,84 @@ def setUpscalers(req: dict):
reqDict.pop('upscaler_2')
return reqDict
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
return Image.open(BytesIO(base64.b64decode(encoding)))
def encode_pil_to_base64(image):
buffer = io.BytesIO()
image.save(buffer, format="png")
return base64.b64encode(buffer.getvalue())
with io.BytesIO() as output_bytes:
# Copy any text-only metadata
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
image.save(
output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
)
bytes_data = output_bytes.getvalue()
return base64.b64encode(bytes_data)
class Api:
def __init__(self, app, queue_lock):
def __init__(self, app: FastAPI, queue_lock: Lock):
if shared.cmd_opts.api_auth:
self.credenticals = dict()
for auth in shared.cmd_opts.api_auth.split(","):
user, password = auth.split(":")
self.credenticals[user] = password
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
self.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
def add_api_route(self, path: str, endpoint, **kwargs):
if shared.cmd_opts.api_auth:
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
return self.app.add_api_route(path, endpoint, **kwargs)
def auth(self, credenticals: HTTPBasicCredentials = Depends(HTTPBasic())):
if credenticals.username in self.credenticals:
if compare_digest(credenticals.password, self.credenticals[credenticals.username]):
return True
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")
populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True
}
)
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
@ -77,12 +134,6 @@ class Api:
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")
init_images = img2imgreq.init_images
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
@ -91,16 +142,20 @@ class Api:
if mask:
mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True,
"mask": mask
}
)
p = StableDiffusionProcessingImg2Img(**vars(populate))
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
args = vars(populate)
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
p = StableDiffusionProcessingImg2Img(**args)
imgs = []
for img in init_images:
@ -118,7 +173,7 @@ class Api:
b64images = list(map(encode_pil_to_base64, processed.images))
if (not img2imgreq.include_init_images):
if not img2imgreq.include_init_images:
img2imgreq.init_images = None
img2imgreq.mask = None
@ -186,11 +241,92 @@ class Api:
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
def interrogateapi(self, interrogatereq: InterrogateRequest):
image_b64 = interrogatereq.image
if image_b64 is None:
raise HTTPException(status_code=404, detail="Image not found")
img = decode_base64_to_image(image_b64)
img = img.convert('RGB')
# Override object param
with self.queue_lock:
if interrogatereq.model == "clip":
processed = shared.interrogator.interrogate(img)
elif interrogatereq.model == "deepdanbooru":
processed = deepbooru.model.tag(img)
else:
raise HTTPException(status_code=404, detail="Model not found")
return InterrogateResponse(caption=processed)
def interruptapi(self):
shared.state.interrupt()
return {}
def skip(self):
shared.state.skip()
def get_config(self):
options = {}
for key in shared.opts.data.keys():
metadata = shared.opts.data_labels.get(key)
if(metadata is not None):
options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
else:
options.update({key: shared.opts.data.get(key, None)})
return options
def set_config(self, req: Dict[str, Any]):
for k, v in req.items():
shared.opts.set(k, v)
shared.opts.save(shared.config_filename)
return
def get_cmd_flags(self):
return vars(shared.cmd_opts)
def get_samplers(self):
return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
def get_upscalers(self):
upscalers = []
for upscaler in shared.sd_upscalers:
u = upscaler.scaler
upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
return upscalers
def get_sd_models(self):
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
def get_face_restorers(self):
return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
def get_realesrgan_models(self):
return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
def get_promp_styles(self):
styleList = []
for k in shared.prompt_styles.styles:
style = shared.prompt_styles.styles[k]
styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
return styleList
def get_artists_categories(self):
return shared.artist_db.cats
def get_artists(self):
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
def launch(self, server_name, port):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port)

View file

@ -1,11 +1,11 @@
import inspect
from click import prompt
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers
from modules.shared import sd_upscalers, opts, parser
from typing import Dict, List
API_NOT_ALLOWED = [
"self",
@ -65,6 +65,7 @@ class PydanticModelGenerator:
self._model_name = model_name
self._class_data = merge_class_params(class_instance)
self._model_def = [
ModelDef(
field=underscore(k),
@ -109,12 +110,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
).generate_model()
class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
class ImageToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
@ -147,10 +148,10 @@ class FileData(BaseModel):
name: str = Field(title="File name")
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
images: List[str] = Field(title="Images", description="The generated images in base64 format.")
class PNGInfoRequest(BaseModel):
image: str = Field(title="Image", description="The base64 encoded PNG image")
@ -166,3 +167,76 @@ class ProgressResponse(BaseModel):
eta_relative: float = Field(title="ETA in secs")
state: dict = Field(title="State", description="The current state snapshot")
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
class InterrogateRequest(BaseModel):
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
model: str = Field(default="clip", title="Model", description="The interrogate model used.")
class InterrogateResponse(BaseModel):
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
fields = {}
for key, metadata in opts.data_labels.items():
value = opts.data.get(key)
optType = opts.typemap.get(type(metadata.default), type(value))
if (metadata is not None):
fields.update({key: (Optional[optType], Field(
default=metadata.default ,description=metadata.label))})
else:
fields.update({key: (Optional[optType], Field())})
OptionsModel = create_model("Options", **fields)
flags = {}
_options = vars(parser)['_option_string_actions']
for key in _options:
if(_options[key].dest != 'help'):
flag = _options[key]
_type = str
if _options[key].default is not None: _type = type(_options[key].default)
flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
FlagsModel = create_model("Flags", **flags)
class SamplerItem(BaseModel):
name: str = Field(title="Name")
aliases: List[str] = Field(title="Aliases")
options: Dict[str, str] = Field(title="Options")
class UpscalerItem(BaseModel):
name: str = Field(title="Name")
model_name: Optional[str] = Field(title="Model Name")
model_path: Optional[str] = Field(title="Path")
model_url: Optional[str] = Field(title="URL")
class SDModelItem(BaseModel):
title: str = Field(title="Title")
model_name: str = Field(title="Model Name")
hash: str = Field(title="Hash")
filename: str = Field(title="Filename")
config: str = Field(title="Config file")
class HypernetworkItem(BaseModel):
name: str = Field(title="Name")
path: Optional[str] = Field(title="Path")
class FaceRestorerItem(BaseModel):
name: str = Field(title="Name")
cmd_dir: Optional[str] = Field(title="Path")
class RealesrganItem(BaseModel):
name: str = Field(title="Name")
path: Optional[str] = Field(title="Path")
scale: Optional[int] = Field(title="Scale")
class PromptStyleItem(BaseModel):
name: str = Field(title="Name")
prompt: Optional[str] = Field(title="Prompt")
negative_prompt: Optional[str] = Field(title="Negative Prompt")
class ArtistItem(BaseModel):
name: str = Field(title="Name")
score: float = Field(title="Score")
category: str = Field(title="Category")

98
modules/call_queue.py Normal file
View file

@ -0,0 +1,98 @@
import html
import sys
import threading
import traceback
import time
from modules import shared
queue_lock = threading.Lock()
def wrap_queued_call(func):
def f(*args, **kwargs):
with queue_lock:
res = func(*args, **kwargs)
return res
return f
def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs):
shared.state.begin()
with queue_lock:
res = func(*args, **kwargs)
shared.state.end()
return res
return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon:
shared.mem_mon.monitor()
t = time.perf_counter()
try:
res = list(func(*args, **kwargs))
except Exception as e:
# When printing out our debug argument list, do not print out more than a MB of text
max_debug_str_len = 131072 # (1024*1024)/8
print("Error completing request", file=sys.stderr)
argStr = f"Arguments: {str(args)} {str(kwargs)}"
print(argStr[:max_debug_str_len], file=sys.stderr)
if len(argStr) > max_debug_str_len:
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
shared.state.job = ""
shared.state.job_count = 0
if extra_outputs_array is None:
extra_outputs_array = [None, '']
res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
shared.state.skipped = False
shared.state.interrupted = False
shared.state.job_count = 0
if not add_stats:
return tuple(res)
elapsed = time.perf_counter() - t
elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60
elapsed_text = f"{elapsed_s:.2f}s"
if elapsed_m > 0:
elapsed_text = f"{elapsed_m}m "+elapsed_text
if run_memmon:
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
active_peak = mem_stats['active_peak']
reserved_peak = mem_stats['reserved_peak']
sys_peak = mem_stats['system_peak']
sys_total = mem_stats['total']
sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
else:
vram_html = ''
# last item is always HTML
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
return tuple(res)
return f

View file

@ -36,6 +36,7 @@ def setup_model(dirname):
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils import imwrite, img2tensor, tensor2img
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.detection.retinaface import retinaface
from modules.shared import cmd_opts
net_class = CodeFormer
@ -65,6 +66,8 @@ def setup_model(dirname):
net.load_state_dict(checkpoint)
net.eval()
if hasattr(retinaface, 'device'):
retinaface.device = devices.device_codeformer
face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
self.net = net

View file

@ -1,173 +1,97 @@
import os.path
from concurrent.futures import ProcessPoolExecutor
import multiprocessing
import time
import os
import re
import torch
from PIL import Image
import numpy as np
from modules import modelloader, paths, deepbooru_model, devices, images, shared
re_special = re.compile(r'([\\()])')
def get_deepbooru_tags(pil_image):
"""
This method is for running only one image at a time for simple use. Used to the img2img interrogate.
"""
from modules import shared # prevents circular reference
try:
create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
return get_tags_from_process(pil_image)
finally:
release_process()
class DeepDanbooru:
def __init__(self):
self.model = None
def load(self):
if self.model is not None:
return
OPT_INCLUDE_RANKS = "include_ranks"
def create_deepbooru_opts():
from modules import shared
files = modelloader.load_models(
model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
ext_filter=".pt",
download_name='model-resnet_custom_v3.pt',
)
return {
"use_spaces": shared.opts.deepbooru_use_spaces,
"use_escape": shared.opts.deepbooru_escape,
"alpha_sort": shared.opts.deepbooru_sort_alpha,
OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks,
}
self.model = deepbooru_model.DeepDanbooruModel()
self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
self.model.eval()
self.model.to(devices.cpu, devices.dtype)
def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
model, tags = get_deepbooru_tags_model()
while True: # while process is running, keep monitoring queue for new image
pil_image = queue.get()
if pil_image == "QUIT":
break
else:
deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
def start(self):
self.load()
self.model.to(devices.device)
def stop(self):
if not shared.opts.interrogate_keep_models_in_memory:
self.model.to(devices.cpu)
devices.torch_gc()
def create_deepbooru_process(threshold, deepbooru_opts):
"""
Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
to be processed in a row without reloading the model or creating a new process. To return the data, a shared
dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned
to the dictionary and the method adding the image to the queue should wait for this value to be updated with
the tags.
"""
from modules import shared # prevents circular reference
context = multiprocessing.get_context("spawn")
shared.deepbooru_process_manager = context.Manager()
shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
shared.deepbooru_process_return["value"] = -1
shared.deepbooru_process = context.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
shared.deepbooru_process.start()
def tag(self, pil_image):
self.start()
res = self.tag_multi(pil_image)
self.stop()
return res
def get_tags_from_process(image):
from modules import shared
def tag_multi(self, pil_image, force_disable_ranks=False):
threshold = shared.opts.interrogate_deepbooru_score_threshold
use_spaces = shared.opts.deepbooru_use_spaces
use_escape = shared.opts.deepbooru_escape
alpha_sort = shared.opts.deepbooru_sort_alpha
include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
shared.deepbooru_process_return["value"] = -1
shared.deepbooru_process_queue.put(image)
while shared.deepbooru_process_return["value"] == -1:
time.sleep(0.2)
caption = shared.deepbooru_process_return["value"]
shared.deepbooru_process_return["value"] = -1
pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
return caption
with torch.no_grad(), devices.autocast():
x = torch.from_numpy(a).to(devices.device)
y = self.model(x)[0].detach().cpu().numpy()
probability_dict = {}
def release_process():
"""
Stops the deepbooru process to return used memory
"""
from modules import shared # prevents circular reference
shared.deepbooru_process_queue.put("QUIT")
shared.deepbooru_process.join()
shared.deepbooru_process_queue = None
shared.deepbooru_process = None
shared.deepbooru_process_return = None
shared.deepbooru_process_manager = None
for tag, probability in zip(self.model.tags, y):
if probability < threshold:
continue
def get_deepbooru_tags_model():
import deepdanbooru as dd
import tensorflow as tf
import numpy as np
this_folder = os.path.dirname(__file__)
model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
if not os.path.exists(os.path.join(model_path, 'project.json')):
# there is no point importing these every time
import zipfile
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(
r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
model_path)
with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
zip_ref.extractall(model_path)
os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
tags = dd.project.load_tags_from_project(model_path)
model = dd.project.load_model_from_project(
model_path, compile_model=False
)
return model, tags
def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
import deepdanbooru as dd
import tensorflow as tf
import numpy as np
alpha_sort = deepbooru_opts['alpha_sort']
use_spaces = deepbooru_opts['use_spaces']
use_escape = deepbooru_opts['use_escape']
include_ranks = deepbooru_opts['include_ranks']
width = model.input_shape[2]
height = model.input_shape[1]
image = np.array(pil_image)
image = tf.image.resize(
image,
size=(height, width),
method=tf.image.ResizeMethod.AREA,
preserve_aspect_ratio=True,
)
image = image.numpy() # EagerTensor to np.array
image = dd.image.transform_and_pad_image(image, width, height)
image = image / 255.0
image_shape = image.shape
image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
y = model.predict(image)[0]
result_dict = {}
for i, tag in enumerate(tags):
result_dict[tag] = y[i]
unsorted_tags_in_theshold = []
result_tags_print = []
for tag in tags:
if result_dict[tag] >= threshold:
if tag.startswith("rating:"):
continue
unsorted_tags_in_theshold.append((result_dict[tag], tag))
result_tags_print.append(f'{result_dict[tag]} {tag}')
# sort tags
result_tags_out = []
sort_ndx = 0
if alpha_sort:
sort_ndx = 1
probability_dict[tag] = probability
# sort by reverse by likelihood and normal for alpha, and format tag text as requested
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
for weight, tag in unsorted_tags_in_theshold:
tag_outformat = tag
if use_spaces:
tag_outformat = tag_outformat.replace('_', ' ')
if use_escape:
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
if include_ranks:
tag_outformat = f"({tag_outformat}:{weight:.3f})"
if alpha_sort:
tags = sorted(probability_dict)
else:
tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
result_tags_out.append(tag_outformat)
res = []
print('\n'.join(sorted(result_tags_print, reverse=True)))
for tag in tags:
probability = probability_dict[tag]
tag_outformat = tag
if use_spaces:
tag_outformat = tag_outformat.replace('_', ' ')
if use_escape:
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
if include_ranks:
tag_outformat = f"({tag_outformat}:{probability:.3f})"
return ', '.join(result_tags_out)
res.append(tag_outformat)
return ", ".join(res)
model = DeepDanbooru()

676
modules/deepbooru_model.py Normal file
View file

@ -0,0 +1,676 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
class DeepDanbooruModel(nn.Module):
def __init__(self):
super(DeepDanbooruModel, self).__init__()
self.tags = []
self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
def forward(self, *inputs):
t_358, = inputs
t_359 = t_358.permute(*[0, 3, 1, 2])
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
t_360 = self.n_Conv_0(t_359_padded)
t_361 = F.relu(t_360)
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
t_362 = self.n_MaxPool_0(t_361)
t_363 = self.n_Conv_1(t_362)
t_364 = self.n_Conv_2(t_362)
t_365 = F.relu(t_364)
t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
t_366 = self.n_Conv_3(t_365_padded)
t_367 = F.relu(t_366)
t_368 = self.n_Conv_4(t_367)
t_369 = torch.add(t_368, t_363)
t_370 = F.relu(t_369)
t_371 = self.n_Conv_5(t_370)
t_372 = F.relu(t_371)
t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
t_373 = self.n_Conv_6(t_372_padded)
t_374 = F.relu(t_373)
t_375 = self.n_Conv_7(t_374)
t_376 = torch.add(t_375, t_370)
t_377 = F.relu(t_376)
t_378 = self.n_Conv_8(t_377)
t_379 = F.relu(t_378)
t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
t_380 = self.n_Conv_9(t_379_padded)
t_381 = F.relu(t_380)
t_382 = self.n_Conv_10(t_381)
t_383 = torch.add(t_382, t_377)
t_384 = F.relu(t_383)
t_385 = self.n_Conv_11(t_384)
t_386 = self.n_Conv_12(t_384)
t_387 = F.relu(t_386)
t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
t_388 = self.n_Conv_13(t_387_padded)
t_389 = F.relu(t_388)
t_390 = self.n_Conv_14(t_389)
t_391 = torch.add(t_390, t_385)
t_392 = F.relu(t_391)
t_393 = self.n_Conv_15(t_392)
t_394 = F.relu(t_393)
t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
t_395 = self.n_Conv_16(t_394_padded)
t_396 = F.relu(t_395)
t_397 = self.n_Conv_17(t_396)
t_398 = torch.add(t_397, t_392)
t_399 = F.relu(t_398)
t_400 = self.n_Conv_18(t_399)
t_401 = F.relu(t_400)
t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
t_402 = self.n_Conv_19(t_401_padded)
t_403 = F.relu(t_402)
t_404 = self.n_Conv_20(t_403)
t_405 = torch.add(t_404, t_399)
t_406 = F.relu(t_405)
t_407 = self.n_Conv_21(t_406)
t_408 = F.relu(t_407)
t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
t_409 = self.n_Conv_22(t_408_padded)
t_410 = F.relu(t_409)
t_411 = self.n_Conv_23(t_410)
t_412 = torch.add(t_411, t_406)
t_413 = F.relu(t_412)
t_414 = self.n_Conv_24(t_413)
t_415 = F.relu(t_414)
t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
t_416 = self.n_Conv_25(t_415_padded)
t_417 = F.relu(t_416)
t_418 = self.n_Conv_26(t_417)
t_419 = torch.add(t_418, t_413)
t_420 = F.relu(t_419)
t_421 = self.n_Conv_27(t_420)
t_422 = F.relu(t_421)
t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
t_423 = self.n_Conv_28(t_422_padded)
t_424 = F.relu(t_423)
t_425 = self.n_Conv_29(t_424)
t_426 = torch.add(t_425, t_420)
t_427 = F.relu(t_426)
t_428 = self.n_Conv_30(t_427)
t_429 = F.relu(t_428)
t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
t_430 = self.n_Conv_31(t_429_padded)
t_431 = F.relu(t_430)
t_432 = self.n_Conv_32(t_431)
t_433 = torch.add(t_432, t_427)
t_434 = F.relu(t_433)
t_435 = self.n_Conv_33(t_434)
t_436 = F.relu(t_435)
t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
t_437 = self.n_Conv_34(t_436_padded)
t_438 = F.relu(t_437)
t_439 = self.n_Conv_35(t_438)
t_440 = torch.add(t_439, t_434)
t_441 = F.relu(t_440)
t_442 = self.n_Conv_36(t_441)
t_443 = self.n_Conv_37(t_441)
t_444 = F.relu(t_443)
t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
t_445 = self.n_Conv_38(t_444_padded)
t_446 = F.relu(t_445)
t_447 = self.n_Conv_39(t_446)
t_448 = torch.add(t_447, t_442)
t_449 = F.relu(t_448)
t_450 = self.n_Conv_40(t_449)
t_451 = F.relu(t_450)
t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
t_452 = self.n_Conv_41(t_451_padded)
t_453 = F.relu(t_452)
t_454 = self.n_Conv_42(t_453)
t_455 = torch.add(t_454, t_449)
t_456 = F.relu(t_455)
t_457 = self.n_Conv_43(t_456)
t_458 = F.relu(t_457)
t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
t_459 = self.n_Conv_44(t_458_padded)
t_460 = F.relu(t_459)
t_461 = self.n_Conv_45(t_460)
t_462 = torch.add(t_461, t_456)
t_463 = F.relu(t_462)
t_464 = self.n_Conv_46(t_463)
t_465 = F.relu(t_464)
t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
t_466 = self.n_Conv_47(t_465_padded)
t_467 = F.relu(t_466)
t_468 = self.n_Conv_48(t_467)
t_469 = torch.add(t_468, t_463)
t_470 = F.relu(t_469)
t_471 = self.n_Conv_49(t_470)
t_472 = F.relu(t_471)
t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
t_473 = self.n_Conv_50(t_472_padded)
t_474 = F.relu(t_473)
t_475 = self.n_Conv_51(t_474)
t_476 = torch.add(t_475, t_470)
t_477 = F.relu(t_476)
t_478 = self.n_Conv_52(t_477)
t_479 = F.relu(t_478)
t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
t_480 = self.n_Conv_53(t_479_padded)
t_481 = F.relu(t_480)
t_482 = self.n_Conv_54(t_481)
t_483 = torch.add(t_482, t_477)
t_484 = F.relu(t_483)
t_485 = self.n_Conv_55(t_484)
t_486 = F.relu(t_485)
t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
t_487 = self.n_Conv_56(t_486_padded)
t_488 = F.relu(t_487)
t_489 = self.n_Conv_57(t_488)
t_490 = torch.add(t_489, t_484)
t_491 = F.relu(t_490)
t_492 = self.n_Conv_58(t_491)
t_493 = F.relu(t_492)
t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
t_494 = self.n_Conv_59(t_493_padded)
t_495 = F.relu(t_494)
t_496 = self.n_Conv_60(t_495)
t_497 = torch.add(t_496, t_491)
t_498 = F.relu(t_497)
t_499 = self.n_Conv_61(t_498)
t_500 = F.relu(t_499)
t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
t_501 = self.n_Conv_62(t_500_padded)
t_502 = F.relu(t_501)
t_503 = self.n_Conv_63(t_502)
t_504 = torch.add(t_503, t_498)
t_505 = F.relu(t_504)
t_506 = self.n_Conv_64(t_505)
t_507 = F.relu(t_506)
t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
t_508 = self.n_Conv_65(t_507_padded)
t_509 = F.relu(t_508)
t_510 = self.n_Conv_66(t_509)
t_511 = torch.add(t_510, t_505)
t_512 = F.relu(t_511)
t_513 = self.n_Conv_67(t_512)
t_514 = F.relu(t_513)
t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
t_515 = self.n_Conv_68(t_514_padded)
t_516 = F.relu(t_515)
t_517 = self.n_Conv_69(t_516)
t_518 = torch.add(t_517, t_512)
t_519 = F.relu(t_518)
t_520 = self.n_Conv_70(t_519)
t_521 = F.relu(t_520)
t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
t_522 = self.n_Conv_71(t_521_padded)
t_523 = F.relu(t_522)
t_524 = self.n_Conv_72(t_523)
t_525 = torch.add(t_524, t_519)
t_526 = F.relu(t_525)
t_527 = self.n_Conv_73(t_526)
t_528 = F.relu(t_527)
t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
t_529 = self.n_Conv_74(t_528_padded)
t_530 = F.relu(t_529)
t_531 = self.n_Conv_75(t_530)
t_532 = torch.add(t_531, t_526)
t_533 = F.relu(t_532)
t_534 = self.n_Conv_76(t_533)
t_535 = F.relu(t_534)
t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
t_536 = self.n_Conv_77(t_535_padded)
t_537 = F.relu(t_536)
t_538 = self.n_Conv_78(t_537)
t_539 = torch.add(t_538, t_533)
t_540 = F.relu(t_539)
t_541 = self.n_Conv_79(t_540)
t_542 = F.relu(t_541)
t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
t_543 = self.n_Conv_80(t_542_padded)
t_544 = F.relu(t_543)
t_545 = self.n_Conv_81(t_544)
t_546 = torch.add(t_545, t_540)
t_547 = F.relu(t_546)
t_548 = self.n_Conv_82(t_547)
t_549 = F.relu(t_548)
t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
t_550 = self.n_Conv_83(t_549_padded)
t_551 = F.relu(t_550)
t_552 = self.n_Conv_84(t_551)
t_553 = torch.add(t_552, t_547)
t_554 = F.relu(t_553)
t_555 = self.n_Conv_85(t_554)
t_556 = F.relu(t_555)
t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
t_557 = self.n_Conv_86(t_556_padded)
t_558 = F.relu(t_557)
t_559 = self.n_Conv_87(t_558)
t_560 = torch.add(t_559, t_554)
t_561 = F.relu(t_560)
t_562 = self.n_Conv_88(t_561)
t_563 = F.relu(t_562)
t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
t_564 = self.n_Conv_89(t_563_padded)
t_565 = F.relu(t_564)
t_566 = self.n_Conv_90(t_565)
t_567 = torch.add(t_566, t_561)
t_568 = F.relu(t_567)
t_569 = self.n_Conv_91(t_568)
t_570 = F.relu(t_569)
t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
t_571 = self.n_Conv_92(t_570_padded)
t_572 = F.relu(t_571)
t_573 = self.n_Conv_93(t_572)
t_574 = torch.add(t_573, t_568)
t_575 = F.relu(t_574)
t_576 = self.n_Conv_94(t_575)
t_577 = F.relu(t_576)
t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
t_578 = self.n_Conv_95(t_577_padded)
t_579 = F.relu(t_578)
t_580 = self.n_Conv_96(t_579)
t_581 = torch.add(t_580, t_575)
t_582 = F.relu(t_581)
t_583 = self.n_Conv_97(t_582)
t_584 = F.relu(t_583)
t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
t_585 = self.n_Conv_98(t_584_padded)
t_586 = F.relu(t_585)
t_587 = self.n_Conv_99(t_586)
t_588 = self.n_Conv_100(t_582)
t_589 = torch.add(t_587, t_588)
t_590 = F.relu(t_589)
t_591 = self.n_Conv_101(t_590)
t_592 = F.relu(t_591)
t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
t_593 = self.n_Conv_102(t_592_padded)
t_594 = F.relu(t_593)
t_595 = self.n_Conv_103(t_594)
t_596 = torch.add(t_595, t_590)
t_597 = F.relu(t_596)
t_598 = self.n_Conv_104(t_597)
t_599 = F.relu(t_598)
t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
t_600 = self.n_Conv_105(t_599_padded)
t_601 = F.relu(t_600)
t_602 = self.n_Conv_106(t_601)
t_603 = torch.add(t_602, t_597)
t_604 = F.relu(t_603)
t_605 = self.n_Conv_107(t_604)
t_606 = F.relu(t_605)
t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
t_607 = self.n_Conv_108(t_606_padded)
t_608 = F.relu(t_607)
t_609 = self.n_Conv_109(t_608)
t_610 = torch.add(t_609, t_604)
t_611 = F.relu(t_610)
t_612 = self.n_Conv_110(t_611)
t_613 = F.relu(t_612)
t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
t_614 = self.n_Conv_111(t_613_padded)
t_615 = F.relu(t_614)
t_616 = self.n_Conv_112(t_615)
t_617 = torch.add(t_616, t_611)
t_618 = F.relu(t_617)
t_619 = self.n_Conv_113(t_618)
t_620 = F.relu(t_619)
t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
t_621 = self.n_Conv_114(t_620_padded)
t_622 = F.relu(t_621)
t_623 = self.n_Conv_115(t_622)
t_624 = torch.add(t_623, t_618)
t_625 = F.relu(t_624)
t_626 = self.n_Conv_116(t_625)
t_627 = F.relu(t_626)
t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
t_628 = self.n_Conv_117(t_627_padded)
t_629 = F.relu(t_628)
t_630 = self.n_Conv_118(t_629)
t_631 = torch.add(t_630, t_625)
t_632 = F.relu(t_631)
t_633 = self.n_Conv_119(t_632)
t_634 = F.relu(t_633)
t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
t_635 = self.n_Conv_120(t_634_padded)
t_636 = F.relu(t_635)
t_637 = self.n_Conv_121(t_636)
t_638 = torch.add(t_637, t_632)
t_639 = F.relu(t_638)
t_640 = self.n_Conv_122(t_639)
t_641 = F.relu(t_640)
t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
t_642 = self.n_Conv_123(t_641_padded)
t_643 = F.relu(t_642)
t_644 = self.n_Conv_124(t_643)
t_645 = torch.add(t_644, t_639)
t_646 = F.relu(t_645)
t_647 = self.n_Conv_125(t_646)
t_648 = F.relu(t_647)
t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
t_649 = self.n_Conv_126(t_648_padded)
t_650 = F.relu(t_649)
t_651 = self.n_Conv_127(t_650)
t_652 = torch.add(t_651, t_646)
t_653 = F.relu(t_652)
t_654 = self.n_Conv_128(t_653)
t_655 = F.relu(t_654)
t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
t_656 = self.n_Conv_129(t_655_padded)
t_657 = F.relu(t_656)
t_658 = self.n_Conv_130(t_657)
t_659 = torch.add(t_658, t_653)
t_660 = F.relu(t_659)
t_661 = self.n_Conv_131(t_660)
t_662 = F.relu(t_661)
t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
t_663 = self.n_Conv_132(t_662_padded)
t_664 = F.relu(t_663)
t_665 = self.n_Conv_133(t_664)
t_666 = torch.add(t_665, t_660)
t_667 = F.relu(t_666)
t_668 = self.n_Conv_134(t_667)
t_669 = F.relu(t_668)
t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
t_670 = self.n_Conv_135(t_669_padded)
t_671 = F.relu(t_670)
t_672 = self.n_Conv_136(t_671)
t_673 = torch.add(t_672, t_667)
t_674 = F.relu(t_673)
t_675 = self.n_Conv_137(t_674)
t_676 = F.relu(t_675)
t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
t_677 = self.n_Conv_138(t_676_padded)
t_678 = F.relu(t_677)
t_679 = self.n_Conv_139(t_678)
t_680 = torch.add(t_679, t_674)
t_681 = F.relu(t_680)
t_682 = self.n_Conv_140(t_681)
t_683 = F.relu(t_682)
t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
t_684 = self.n_Conv_141(t_683_padded)
t_685 = F.relu(t_684)
t_686 = self.n_Conv_142(t_685)
t_687 = torch.add(t_686, t_681)
t_688 = F.relu(t_687)
t_689 = self.n_Conv_143(t_688)
t_690 = F.relu(t_689)
t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
t_691 = self.n_Conv_144(t_690_padded)
t_692 = F.relu(t_691)
t_693 = self.n_Conv_145(t_692)
t_694 = torch.add(t_693, t_688)
t_695 = F.relu(t_694)
t_696 = self.n_Conv_146(t_695)
t_697 = F.relu(t_696)
t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
t_698 = self.n_Conv_147(t_697_padded)
t_699 = F.relu(t_698)
t_700 = self.n_Conv_148(t_699)
t_701 = torch.add(t_700, t_695)
t_702 = F.relu(t_701)
t_703 = self.n_Conv_149(t_702)
t_704 = F.relu(t_703)
t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
t_705 = self.n_Conv_150(t_704_padded)
t_706 = F.relu(t_705)
t_707 = self.n_Conv_151(t_706)
t_708 = torch.add(t_707, t_702)
t_709 = F.relu(t_708)
t_710 = self.n_Conv_152(t_709)
t_711 = F.relu(t_710)
t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
t_712 = self.n_Conv_153(t_711_padded)
t_713 = F.relu(t_712)
t_714 = self.n_Conv_154(t_713)
t_715 = torch.add(t_714, t_709)
t_716 = F.relu(t_715)
t_717 = self.n_Conv_155(t_716)
t_718 = F.relu(t_717)
t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
t_719 = self.n_Conv_156(t_718_padded)
t_720 = F.relu(t_719)
t_721 = self.n_Conv_157(t_720)
t_722 = torch.add(t_721, t_716)
t_723 = F.relu(t_722)
t_724 = self.n_Conv_158(t_723)
t_725 = self.n_Conv_159(t_723)
t_726 = F.relu(t_725)
t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
t_727 = self.n_Conv_160(t_726_padded)
t_728 = F.relu(t_727)
t_729 = self.n_Conv_161(t_728)
t_730 = torch.add(t_729, t_724)
t_731 = F.relu(t_730)
t_732 = self.n_Conv_162(t_731)
t_733 = F.relu(t_732)
t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
t_734 = self.n_Conv_163(t_733_padded)
t_735 = F.relu(t_734)
t_736 = self.n_Conv_164(t_735)
t_737 = torch.add(t_736, t_731)
t_738 = F.relu(t_737)
t_739 = self.n_Conv_165(t_738)
t_740 = F.relu(t_739)
t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
t_741 = self.n_Conv_166(t_740_padded)
t_742 = F.relu(t_741)
t_743 = self.n_Conv_167(t_742)
t_744 = torch.add(t_743, t_738)
t_745 = F.relu(t_744)
t_746 = self.n_Conv_168(t_745)
t_747 = self.n_Conv_169(t_745)
t_748 = F.relu(t_747)
t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
t_749 = self.n_Conv_170(t_748_padded)
t_750 = F.relu(t_749)
t_751 = self.n_Conv_171(t_750)
t_752 = torch.add(t_751, t_746)
t_753 = F.relu(t_752)
t_754 = self.n_Conv_172(t_753)
t_755 = F.relu(t_754)
t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
t_756 = self.n_Conv_173(t_755_padded)
t_757 = F.relu(t_756)
t_758 = self.n_Conv_174(t_757)
t_759 = torch.add(t_758, t_753)
t_760 = F.relu(t_759)
t_761 = self.n_Conv_175(t_760)
t_762 = F.relu(t_761)
t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
t_763 = self.n_Conv_176(t_762_padded)
t_764 = F.relu(t_763)
t_765 = self.n_Conv_177(t_764)
t_766 = torch.add(t_765, t_760)
t_767 = F.relu(t_766)
t_768 = self.n_Conv_178(t_767)
t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
t_770 = torch.squeeze(t_769, 3)
t_770 = torch.squeeze(t_770, 2)
t_771 = torch.sigmoid(t_770)
return t_771
def load_state_dict(self, state_dict, **kwargs):
self.tags = state_dict.get('tags', [])
super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})

View file

@ -2,30 +2,43 @@ import sys, os, shlex
import contextlib
import torch
from modules import errors
from packaging import version
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu")
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
# check `getattr` and try it for compatibility
def has_mps() -> bool:
if not getattr(torch, 'has_mps', False):
return False
try:
torch.zeros(1).to(torch.device("mps"))
return True
except Exception:
return False
def extract_device_id(args, name):
for x in range(len(args)):
if name in args[x]: return args[x+1]
if name in args[x]:
return args[x + 1]
return None
def get_cuda_device_string():
from modules import shared
if shared.cmd_opts.device_id is not None:
return f"cuda:{shared.cmd_opts.device_id}"
return "cuda"
def get_optimal_device():
if torch.cuda.is_available():
from modules import shared
return torch.device(get_cuda_device_string())
device_id = shared.cmd_opts.device_id
if device_id is not None:
cuda_device = f"cuda:{device_id}"
return torch.device(cuda_device)
else:
return torch.device("cuda")
if has_mps:
if has_mps():
return torch.device("mps")
return cpu
@ -33,8 +46,9 @@ def get_optimal_device():
def torch_gc():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def enable_tf32():
@ -45,29 +59,22 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
cpu = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
if device.type == 'mps':
generator = torch.Generator(device=cpu)
generator.manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
return noise
def randn(seed, shape):
torch.manual_seed(seed)
if device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
def randn_without_seed(shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
if device.type == 'mps':
generator = torch.Generator(device=cpu)
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
return noise
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
@ -82,6 +89,27 @@ def autocast(disable=False):
return torch.autocast("cuda")
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)
orig_tensor_to = torch.Tensor.to
def tensor_to_fix(self, *args, **kwargs):
if self.device.type != 'mps' and \
((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
(isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
self = self.contiguous()
return orig_tensor_to(self, *args, **kwargs)
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
orig_layer_norm = torch.nn.functional.layer_norm
def layer_norm_fix(*args, **kwargs):
if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
args = list(args)
args[0] = args[0].contiguous()
return orig_layer_norm(*args, **kwargs)
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
torch.Tensor.to = tensor_to_fix
torch.nn.functional.layer_norm = layer_norm_fix

View file

@ -199,7 +199,7 @@ def upscale_without_tiling(model, img):
img = img[:, :, ::-1]
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float()
img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_esrgan)
img = img.unsqueeze(0).to(devices.device_esrgan)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()

View file

@ -6,7 +6,6 @@ import git
from modules import paths, shared
extensions = []
extensions_dir = os.path.join(paths.script_path, "extensions")
@ -34,8 +33,11 @@ class Extension:
if repo is None or repo.bare:
self.remote = None
else:
self.remote = next(repo.remote().urls, None)
self.status = 'unknown'
try:
self.remote = next(repo.remote().urls, None)
self.status = 'unknown'
except Exception:
self.remote = None
def list_files(self, subdir, extension):
from modules import scripts
@ -63,9 +65,12 @@ class Extension:
self.can_update = False
self.status = "latest"
def pull(self):
def fetch_and_reset_hard(self):
repo = git.Repo(self.path)
repo.remotes.origin.pull()
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
repo.git.fetch('--all')
repo.git.reset('--hard', 'origin')
def list_extensions():
@ -81,3 +86,4 @@ def list_extensions():
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
extensions.append(extension)

View file

@ -1,6 +1,8 @@
from __future__ import annotations
import math
import os
import sys
import traceback
import numpy as np
from PIL import Image
@ -12,7 +14,7 @@ from typing import Callable, List, OrderedDict, Tuple
from functools import partial
from dataclasses import dataclass
from modules import processing, shared, images, devices, sd_models
from modules import processing, shared, images, devices, sd_models, sd_samplers
from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
@ -20,7 +22,7 @@ import modules.codeformer_model
import piexif
import piexif.helper
import gradio as gr
import safetensors.torch
class LruCache(OrderedDict):
@dataclass(frozen=True)
@ -136,12 +138,13 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
blended_result: Image.Image = None
image_hash: str = hash(np.array(image.getdata()).tobytes())
for upscaler in params:
upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
cache_key = LruCache.Key(image_hash=image_hash,
info_hash=hash(info),
args_hash=hash((upscale_args, upscale_first)))
args_hash=hash(upscale_args))
cached_entry = cached_images.get(cache_key)
if cached_entry is None:
res = upscale(image, *upscale_args)
@ -212,25 +215,8 @@ def run_pnginfo(image):
if image is None:
return '', '', ''
items = image.info
geninfo = ''
if "exif" in image.info:
exif = piexif.load(image.info["exif"])
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
try:
exif_comment = piexif.helper.UserComment.load(exif_comment)
except ValueError:
exif_comment = exif_comment.decode('utf8', errors="ignore")
items['exif comment'] = exif_comment
geninfo = exif_comment
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
'loop', 'background', 'timestamp', 'duration']:
items.pop(field, None)
geninfo = items.get('parameters', geninfo)
geninfo, items = images.read_info_from_image(image)
items = {**{'parameters': geninfo}, **items}
info = ''
for key, text in items.items():
@ -248,7 +234,7 @@ def run_pnginfo(image):
return '', geninfo, info
def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
@ -263,19 +249,15 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
print(f"Loading {primary_model_info.filename}...")
primary_model = torch.load(primary_model_info.filename, map_location='cpu')
theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
print(f"Loading {secondary_model_info.filename}...")
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
if teritary_model_info is not None:
print(f"Loading {teritary_model_info.filename}...")
teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
theta_2 = sd_models.read_state_dict(teritary_model_info.filename, map_location='cpu')
else:
teritary_model = None
theta_2 = None
theta_funcs = {
@ -294,7 +276,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
theta_1[key] = theta_func1(theta_1[key], t2)
else:
theta_1[key] = torch.zeros_like(theta_1[key])
del theta_2, teritary_model
del theta_2
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
@ -313,12 +295,17 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
filename = filename if custom_name == '' else (custom_name + '.ckpt')
filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.' + checkpoint_format
filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
output_modelname = os.path.join(ckpt_dir, filename)
print(f"Saving to {output_modelname}...")
torch.save(primary_model, output_modelname)
_, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors":
safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
else:
torch.save(theta_0, output_modelname)
sd_models.list_models()

View file

@ -2,6 +2,8 @@ import base64
import io
import os
import re
from pathlib import Path
import gradio as gr
from modules.shared import script_path
from modules import shared
@ -35,9 +37,8 @@ def quote(text):
def image_from_url_text(filedata):
if type(filedata) == dict and filedata["is_file"]:
filename = filedata["name"]
tempdir = os.path.normpath(tempfile.gettempdir())
normfn = os.path.normpath(filename)
assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
is_in_right_dir = any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in shared.demo.temp_dirs)
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
return Image.open(filename)
@ -73,7 +74,9 @@ def integrate_settings_paste_fields(component_dict):
'sd_hypernetwork': 'Hypernet',
'sd_hypernetwork_strength': 'Hypernet strength',
'CLIP_stop_at_last_layers': 'Clip skip',
'inpainting_mask_weight': 'Conditional mask weight',
'sd_model_checkpoint': 'Model hash',
'eta_noise_seed_delta': 'ENSD',
}
settings_paste_fields = [
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
@ -181,6 +184,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
else:
res[k] = v
# Missing CLIP skip means it was set to 1 (the default)
if "Clip skip" not in res:
res["Clip skip"] = "1"
return res

View file

@ -36,7 +36,9 @@ def gfpgann():
else:
print("Unable to load gfpgan model!")
return None
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
if hasattr(facexlib.detection.retinaface, 'device'):
facexlib.detection.retinaface.device = devices.device_gfpgan
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
loaded_gfpgan_model = model
return model

View file

@ -12,7 +12,7 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared
from modules import devices, processing, sd_models, shared, sd_samplers
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@ -22,6 +22,8 @@ from collections import defaultdict, deque
from statistics import stdev, mean
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
@ -36,7 +38,7 @@ class HypernetworkModule(torch.nn.Module):
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True):
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
@ -142,6 +144,8 @@ class Hypernetwork:
self.use_dropout = use_dropout
self.activate_output = activate_output
self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
self.optimizer_name = None
self.optimizer_state_dict = None
for size in enable_sizes or []:
self.layers[size] = (
@ -150,19 +154,32 @@ class Hypernetwork:
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
)
self.eval_mode()
def weights(self):
res = []
for k, layers in self.layers.items():
for layer in layers:
res += layer.parameters()
return res
def train_mode(self):
for k, layers in self.layers.items():
for layer in layers:
layer.train()
res += layer.trainables()
for param in layer.parameters():
param.requires_grad = True
return res
def eval_mode(self):
for k, layers in self.layers.items():
for layer in layers:
layer.eval()
for param in layer.parameters():
param.requires_grad = False
def save(self, filename):
state_dict = {}
optimizer_saved_dict = {}
for k, v in self.layers.items():
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
@ -178,8 +195,15 @@ class Hypernetwork:
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
state_dict['activate_output'] = self.activate_output
state_dict['last_layer_dropout'] = self.last_layer_dropout
if self.optimizer_name is not None:
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
torch.save(state_dict, filename)
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
torch.save(optimizer_saved_dict, filename + '.optim')
def load(self, filename):
self.filename = filename
@ -202,6 +226,18 @@ class Hypernetwork:
print(f"Activate last layer is set to {self.activate_output}")
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
print(f"Optimizer name is {self.optimizer_name}")
if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
else:
self.optimizer_state_dict = None
if self.optimizer_state_dict:
print("Loaded existing optimizer from checkpoint")
else:
print("No saved optimizer exists in checkpoint")
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
@ -219,11 +255,11 @@ class Hypernetwork:
def list_hypernetworks(path):
res = {}
for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
res[name] = filename
res[name + f"({sd_models.model_hash(filename)})"] = filename
return res
@ -343,13 +379,13 @@ def report_statistics(loss_info:dict):
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
@ -358,6 +394,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
@ -378,17 +415,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint()
ititial_step = hypernetwork.step or 0
if ititial_step >= steps:
initial_step = hypernetwork.step or 0
if initial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
pin_memory = shared.opts.pin_memory
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
latent_sampling_method = ds.latent_sampling_method
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
old_parallel_processing_allowed = shared.parallel_processing_allowed
@ -396,19 +439,41 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.parallel_processing_allowed = False
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
size = len(ds.indexes)
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
losses = torch.zeros((size,))
previous_mean_losses = [0]
previous_mean_loss = 0
print("Mean loss of {} elements".format(size))
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
hypernetwork.train_mode()
# Here we use optimizer from saved HN, or we can specify as UI option.
if hypernetwork.optimizer_name in optimizer_dict:
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
optimizer_name = hypernetwork.optimizer_name
else:
print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
optimizer_name = 'AdamW'
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
try:
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
except RuntimeError as e:
print("Cannot resume from saved optimizer!")
print(e)
scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size
gradient_step = ds.gradient_step
# n steps = batch_size * gradient_step * n image processed
steps_per_epoch = len(ds) // batch_size // gradient_step
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
loss_step = 0
_loss_step = 0 #internal
# size = len(ds.indexes)
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
# losses = torch.zeros((size,))
# previous_mean_losses = [0]
# previous_mean_loss = 0
# print("Mean loss of {} elements".format(size))
steps_without_grad = 0
@ -416,124 +481,145 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
last_saved_image = "<none>"
forced_filename = "<none>"
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
hypernetwork.step = i + ititial_step
if len(loss_dict) > 0:
previous_mean_losses = [i[-1] for i in loss_dict.values()]
previous_mean_loss = mean(previous_mean_losses)
scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished:
break
pbar = tqdm.tqdm(total=steps - initial_step)
try:
for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
if shared.state.interrupted:
break
for j, batch in enumerate(dl):
# works as a drop_last=True for gradient accumulation
if j == max_steps_per_epoch:
break
scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished:
break
if shared.state.interrupted:
break
if shared.state.interrupted:
break
with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if tag_drop_out != 0 or shuffle_tags:
shared.sd_model.cond_stage_model.to(devices.device)
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
shared.sd_model.cond_stage_model.to(devices.cpu)
else:
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
loss = shared.sd_model(x, c)[0] / gradient_step
del x
del c
with torch.autocast("cuda"):
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
loss = shared.sd_model(x, c)[0]
del x
del c
_loss_step += loss.item()
scaler.scale(loss).backward()
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
# scaler.unscale_(optimizer)
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
# torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
scaler.step(optimizer)
scaler.update()
hypernetwork.step += 1
pbar.update()
optimizer.zero_grad(set_to_none=True)
loss_step = _loss_step
_loss_step = 0
losses[hypernetwork.step % losses.shape[0]] = loss.item()
for entry in entries:
loss_dict[entry.filename].append(loss.item())
steps_done = hypernetwork.step + 1
optimizer.zero_grad()
weights[0].grad = None
loss.backward()
epoch_num = hypernetwork.step // steps_per_epoch
epoch_step = hypernetwork.step % steps_per_epoch
if weights[0].grad is None:
steps_without_grad += 1
else:
steps_without_grad = 0
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
hypernetwork.optimizer_name = optimizer_name
if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
optimizer.step()
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
"loss": f"{loss_step:.7f}",
"learn_rate": scheduler.learn_rate
})
steps_done = hypernetwork.step + 1
if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{hypernetwork_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
hypernetwork.eval_mode()
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
raise RuntimeError("Loss diverged.")
if len(previous_mean_losses) > 1:
std = stdev(previous_mean_losses)
else:
std = 0
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
pbar.set_description(dataset_loss_info)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
do_not_save_samples=True,
)
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = batch.cond_text[0]
p.steps = 20
p.width = training_width
p.height = training_height
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{previous_mean_loss:.7f}",
"learn_rate": scheduler.learn_rate
})
preview_text = p.prompt
if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{hypernetwork_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
processed = processing.process_images(p)
image = processed.images[0] if len(processed.images) > 0 else None
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
hypernetwork.train_mode()
if image is not None:
shared.state.current_image = image
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
do_not_save_samples=True,
)
shared.state.job_no = hypernetwork.step
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_index = preview_sampler_index
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = entries[0].cond_text
p.steps = 20
preview_text = p.prompt
processed = processing.process_images(p)
image = processed.images[0] if len(processed.images)>0 else None
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
if image is not None:
shared.state.current_image = image
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
shared.state.textinfo = f"""
shared.state.textinfo = f"""
<p>
Loss: {previous_mean_loss:.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
Loss: {loss_step:.7f}<br/>
Step: {steps_done}<br/>
Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
report_statistics(loss_dict)
except Exception:
print(traceback.format_exc(), file=sys.stderr)
finally:
pbar.leave = False
pbar.close()
hypernetwork.eval_mode()
#report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name
if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
del optimizer
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
shared.parallel_processing_allowed = old_parallel_processing_allowed
return hypernetwork, filename

View file

@ -9,7 +9,7 @@ from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
not_available = ["hardswish", "multiheadattention"]
keys = ["linear"] + list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name.

View file

@ -15,6 +15,7 @@ import piexif.helper
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto
import string
import json
from modules import sd_samplers, shared, script_callbacks
from modules.shared import opts, cmd_opts
@ -303,8 +304,9 @@ class FilenameGenerator:
'width': lambda self: self.image.width,
'height': lambda self: self.image.height,
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
@ -524,6 +526,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
else:
image.save(fullfn, quality=opts.jpeg_quality)
image.already_saved_as = fullfn
target_side_length = 4000
oversize = image.width > target_side_length or image.height > target_side_length
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024):
@ -550,10 +554,45 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
return fullfn, txt_fullfn
def read_info_from_image(image):
items = image.info or {}
geninfo = items.pop('parameters', None)
if "exif" in items:
exif = piexif.load(items["exif"])
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
try:
exif_comment = piexif.helper.UserComment.load(exif_comment)
except ValueError:
exif_comment = exif_comment.decode('utf8', errors="ignore")
items['exif comment'] = exif_comment
geninfo = exif_comment
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
'loop', 'background', 'timestamp', 'duration']:
items.pop(field, None)
if items.get("Software", None) == "NovelAI":
try:
json_info = json.loads(items["Comment"])
sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
geninfo = f"""{items["Description"]}
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
except Exception:
print(f"Error parsing NovelAI iamge generation parameters:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return geninfo, items
def image_data(data):
try:
image = Image.open(io.BytesIO(data))
textinfo = image.text["parameters"]
textinfo, _ = read_info_from_image(image)
return textinfo, None
except Exception:
pass

View file

@ -4,9 +4,9 @@ import sys
import traceback
import numpy as np
from PIL import Image, ImageOps, ImageChops
from PIL import Image, ImageOps, ImageFilter, ImageEnhance
from modules import devices
from modules import devices, sd_samplers
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
import modules.shared as shared
@ -40,7 +40,7 @@ def process_batch(p, input_dir, output_dir, args):
img = Image.open(image)
# Use the EXIF orientation of photos taken by smartphones.
img = ImageOps.exif_transpose(img)
img = ImageOps.exif_transpose(img)
p.init_images = [img] * p.batch_size
proc = modules.scripts.scripts_img2img.run(p, *args)
@ -59,18 +59,30 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename))
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
is_inpaint = mode == 1
is_batch = mode == 2
if is_inpaint:
# Drawn mask
if mask_mode == 0:
image = init_img_with_mask['image']
mask = init_img_with_mask['mask']
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
image = image.convert('RGB')
image = init_img_with_mask
is_mask_sketch = isinstance(image, dict)
is_mask_paint = not is_mask_sketch
if is_mask_sketch:
# Sketch: mask iff. not transparent
image, mask = image["image"], image["mask"]
pred = np.array(mask)[..., -1] > 0
else:
# Color-sketch: mask iff. painted over
orig = init_img_with_mask_orig or image
pred = np.any(np.array(image) != np.array(orig), axis=-1)
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
if is_mask_paint:
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
blur = ImageFilter.GaussianBlur(mask_blur)
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
image = image.convert("RGB")
# Uploaded mask
else:
image = init_img_inpaint
@ -82,7 +94,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
# Use the EXIF orientation of photos taken by smartphones.
if image is not None:
image = ImageOps.exif_transpose(image)
image = ImageOps.exif_transpose(image)
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
@ -99,7 +111,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras,
sampler_index=sampler_index,
sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,

View file

@ -148,8 +148,7 @@ class InterrogateModels:
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
with torch.no_grad(), precision_scope("cuda"):
with torch.no_grad(), devices.autocast():
image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
image_features /= image_features.norm(dim=-1, keepdim=True)

View file

@ -101,8 +101,8 @@ class LDSR:
down_sample_rate = target_scale / 4
wd = width_og * down_sample_rate
hd = height_og * down_sample_rate
width_downsampled_pre = int(wd)
height_downsampled_pre = int(hd)
width_downsampled_pre = int(np.ceil(wd))
height_downsampled_pre = int(np.ceil(hd))
if down_sample_rate != 1:
print(
@ -110,7 +110,12 @@ class LDSR:
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
else:
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
logs = self.run(model["model"], im_og, diffusion_steps, eta)
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
logs = self.run(model["model"], im_padded, diffusion_steps, eta)
sample = logs["sample"]
sample = sample.detach().cpu()
@ -120,6 +125,9 @@ class LDSR:
sample = np.transpose(sample, (0, 2, 3, 1))
a = Image.fromarray(sample[0])
# remove padding
a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))
del model
gc.collect()
torch.cuda.empty_cache()

View file

@ -3,6 +3,7 @@ import os
import sys
import traceback
localizations = {}
@ -16,6 +17,11 @@ def list_localizations(dirname):
localizations[fn] = os.path.join(dirname, file)
from modules import scripts
for file in scripts.list_scripts("localizations", ".json"):
fn, ext = os.path.splitext(file.filename)
localizations[fn] = file.path
def localization_js(current_localization_name):
fn = localizations.get(current_localization_name, None)

View file

@ -51,6 +51,10 @@ def setup_for_low_vram(sd_model, use_medvram):
send_me_to_gpu(first_stage_model, None)
return first_stage_model_decode(z)
# for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
if hasattr(sd_model.cond_stage_model, 'model'):
sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
# remove three big modules, cond, first_stage, and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
@ -65,6 +69,10 @@ def setup_for_low_vram(sd_model, use_medvram):
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if hasattr(sd_model.cond_stage_model, 'model'):
sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
del sd_model.cond_stage_model.transformer
if use_medvram:
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
else:

View file

@ -82,6 +82,7 @@ def cleanup_models():
src_path = models_path
dest_path = os.path.join(models_path, "Stable-diffusion")
move_files(src_path, dest_path, ".ckpt")
move_files(src_path, dest_path, ".safetensors")
src_path = os.path.join(root_path, "ESRGAN")
dest_path = os.path.join(models_path, "ESRGAN")
move_files(src_path, dest_path)

View file

@ -1,14 +1,23 @@
from pyngrok import ngrok, conf, exception
def connect(token, port, region):
account = None
if token == None:
token = 'None'
else:
if ':' in token:
# token = authtoken:username:password
account = token.split(':')[1] + ':' + token.split(':')[-1]
token = token.split(':')[0]
config = conf.PyngrokConfig(
auth_token=token, region=region
)
try:
public_url = ngrok.connect(port, pyngrok_config=config).public_url
if account == None:
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
else:
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
except exception.PyngrokNgrokError:
print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')

View file

@ -9,7 +9,7 @@ sys.path.insert(0, script_path)
# search for directory of stable diffusion in following places
sd_path = None
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
for possible_sd_path in possible_sd_paths:
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
sd_path = os.path.abspath(possible_sd_path)

View file

@ -2,6 +2,7 @@ import json
import math
import os
import sys
import warnings
import torch
import numpy as np
@ -66,19 +67,15 @@ def apply_overlay(image, paste_loc, index, overlays):
return image
def get_correct_sampler(p):
if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
return sd_samplers.samplers
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
return sd_samplers.samplers_for_img2img
elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
return sd_samplers.samplers
class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None):
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
@ -91,7 +88,7 @@ class StableDiffusionProcessing():
self.subseed_strength: float = subseed_strength
self.seed_resize_from_h: int = seed_resize_from_h
self.seed_resize_from_w: int = seed_resize_from_w
self.sampler_index: int = sampler_index
self.sampler_name: str = sampler_name
self.batch_size: int = batch_size
self.n_iter: int = n_iter
self.steps: int = steps
@ -116,6 +113,7 @@ class StableDiffusionProcessing():
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
self.s_noise = s_noise or opts.s_noise
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
self.is_using_inpainting_conditioning = False
if not seed_enable_extras:
self.subseed = -1
@ -126,6 +124,7 @@ class StableDiffusionProcessing():
self.scripts = None
self.script_args = None
self.all_prompts = None
self.all_negative_prompts = None
self.all_seeds = None
self.all_subseeds = None
@ -136,6 +135,8 @@ class StableDiffusionProcessing():
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
return x.new_zeros(x.shape[0], 5, 1, 1)
self.is_using_inpainting_conditioning = True
height = height or self.height
width = width or self.width
@ -154,6 +155,8 @@ class StableDiffusionProcessing():
# Dummy zero conditioning if we're not using inpainting model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
self.is_using_inpainting_conditioning = True
# Handle the different mask inputs
if image_mask is not None:
if torch.is_tensor(image_mask):
@ -200,7 +203,7 @@ class StableDiffusionProcessing():
class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
self.images = images_list
self.prompt = p.prompt
self.negative_prompt = p.negative_prompt
@ -210,8 +213,7 @@ class Processed:
self.info = info
self.width = p.width
self.height = p.height
self.sampler_index = p.sampler_index
self.sampler = sd_samplers.samplers[p.sampler_index].name
self.sampler_name = p.sampler_name
self.cfg_scale = p.cfg_scale
self.steps = p.steps
self.batch_size = p.batch_size
@ -238,17 +240,20 @@ class Processed:
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
self.all_prompts = all_prompts or [self.prompt]
self.all_seeds = all_seeds or [self.seed]
self.all_subseeds = all_subseeds or [self.subseed]
self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
self.all_seeds = all_seeds or p.all_seeds or [self.seed]
self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
self.infotexts = infotexts or [info]
def js(self):
obj = {
"prompt": self.prompt,
"prompt": self.all_prompts[0],
"all_prompts": self.all_prompts,
"negative_prompt": self.negative_prompt,
"negative_prompt": self.all_negative_prompts[0],
"all_negative_prompts": self.all_negative_prompts,
"seed": self.seed,
"all_seeds": self.all_seeds,
"subseed": self.subseed,
@ -256,8 +261,7 @@ class Processed:
"subseed_strength": self.subseed_strength,
"width": self.width,
"height": self.height,
"sampler_index": self.sampler_index,
"sampler": self.sampler,
"sampler_name": self.sampler_name,
"cfg_scale": self.cfg_scale,
"steps": self.steps,
"batch_size": self.batch_size,
@ -273,6 +277,7 @@ class Processed:
"styles": self.styles,
"job_timestamp": self.job_timestamp,
"clip_skip": self.clip_skip,
"is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
}
return json.dumps(obj)
@ -384,7 +389,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params = {
"Steps": p.steps,
"Sampler": get_correct_sampler(p)[p.sampler_index].name,
"Sampler": p.sampler_name,
"CFG scale": p.cfg_scale,
"Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
@ -399,6 +404,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
@ -408,7 +414,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[0] if p.all_negative_prompts[0] else ""
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
@ -418,13 +424,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try:
for k, v in p.override_settings.items():
setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model, hypernet impossible
setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model impossible
if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet since it is relatively fast to load on-change, while SD models are not
res = process_images_inner(p)
finally:
finally: # restore opts to original state
for k, v in stored_opts.items():
setattr(opts, k, v)
if k == 'sd_hypernetwork': shared.reload_hypernetworks()
return res
@ -437,10 +445,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else:
assert p.prompt is not None
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
devices.torch_gc()
seed = get_fixed_seed(p.seed)
@ -451,12 +455,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
comments = {}
shared.prompt_styles.apply_styles(p)
if type(p.prompt) == list:
p.all_prompts = p.prompt
p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
else:
p.all_prompts = p.batch_size * p.n_iter * [p.prompt]
p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
if type(p.negative_prompt) == list:
p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
else:
p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
if type(seed) == list:
p.all_seeds = seed
@ -471,6 +478,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
def infotext(iteration=0, position_in_batch=0):
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
@ -495,14 +506,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
break
prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
if len(prompts) == 0:
break
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0:
@ -515,8 +530,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
samples_ddim = samples_ddim.to(devices.dtype_vae)
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
del samples_ddim
@ -588,7 +603,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
if p.scripts is not None:
p.scripts.postprocess(p, res)
@ -642,7 +657,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@ -665,17 +680,17 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
if opts.use_scale_latent_for_hires_fix:
for i in range(samples.shape[0]):
save_intermediate(samples, i)
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
# Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
else:
image_conditioning = self.txt2img_image_conditioning(samples)
for i in range(samples.shape[0]):
save_intermediate(samples, i)
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
@ -703,7 +718,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob()
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@ -727,7 +742,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.denoising_strength: float = denoising_strength
self.init_latent = None
self.image_mask = mask
#self.image_unblurred_mask = None
self.latent_mask = None
self.mask_for_overlay = None
self.mask_blur = mask_blur
@ -740,39 +754,39 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_conditioning = None
def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
crop_region = None
if self.image_mask is not None:
self.image_mask = self.image_mask.convert('L')
image_mask = self.image_mask
if image_mask is not None:
image_mask = image_mask.convert('L')
if self.inpainting_mask_invert:
self.image_mask = ImageOps.invert(self.image_mask)
#self.image_unblurred_mask = self.image_mask
image_mask = ImageOps.invert(image_mask)
if self.mask_blur > 0:
self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
if self.inpaint_full_res:
self.mask_for_overlay = self.image_mask
mask = self.image_mask.convert('L')
self.mask_for_overlay = image_mask
mask = image_mask.convert('L')
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
x1, y1, x2, y2 = crop_region
mask = mask.crop(crop_region)
self.image_mask = images.resize_image(2, mask, self.width, self.height)
image_mask = images.resize_image(2, mask, self.width, self.height)
self.paste_to = (x1, y1, x2-x1, y2-y1)
else:
self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
np_mask = np.array(self.image_mask)
image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
np_mask = np.array(image_mask)
np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
self.mask_for_overlay = Image.fromarray(np_mask)
self.overlay_images = []
latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask
latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
if add_color_corrections:
@ -784,7 +798,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if crop_region is None:
image = images.resize_image(self.resize_mode, image, self.width, self.height)
if self.image_mask is not None:
if image_mask is not None:
image_masked = Image.new('RGBa', (image.width, image.height))
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
@ -794,7 +808,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = image.crop(crop_region)
image = images.resize_image(2, image, self.width, self.height)
if self.image_mask is not None:
if image_mask is not None:
if self.inpainting_fill != 1:
image = masking.fill(image, latent_mask)
@ -826,7 +840,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
if self.image_mask is not None:
if image_mask is not None:
init_mask = latent_mask
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
@ -843,7 +857,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

View file

@ -23,11 +23,18 @@ def encode(*args):
class RestrictedUnpickler(pickle.Unpickler):
extra_handler = None
def persistent_load(self, saved_id):
assert saved_id[0] == 'storage'
return TypedStorage()
def find_class(self, module, name):
if self.extra_handler is not None:
res = self.extra_handler(module, name)
if res is not None:
return res
if module == 'collections' and name == 'OrderedDict':
return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
@ -52,32 +59,37 @@ class RestrictedUnpickler(pickle.Unpickler):
return set
# Forbid everything else.
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
raise Exception(f"global '{module}/{name}' is forbidden")
allowed_zip_names = ["archive/data.pkl", "archive/version"]
allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
def check_zip_filenames(filename, names):
for name in names:
if name in allowed_zip_names:
continue
if allowed_zip_names_re.match(name):
continue
raise Exception(f"bad file inside {filename}: {name}")
def check_pt(filename):
def check_pt(filename, extra_handler):
try:
# new pytorch format is a zip file
with zipfile.ZipFile(filename) as z:
check_zip_filenames(filename, z.namelist())
with z.open('archive/data.pkl') as file:
# find filename of data.pkl in zip file: '<directory name>/data.pkl'
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
if len(data_pkl_filenames) == 0:
raise Exception(f"data.pkl not found in {filename}")
if len(data_pkl_filenames) > 1:
raise Exception(f"Multiple data.pkl found in {filename}")
with z.open(data_pkl_filenames[0]) as file:
unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler
unpickler.load()
except zipfile.BadZipfile:
@ -85,16 +97,42 @@ def check_pt(filename):
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
with open(filename, "rb") as file:
unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler
for i in range(5):
unpickler.load()
def load(filename, *args, **kwargs):
return load_with_extra(filename, *args, **kwargs)
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
"""
this functon is intended to be used by extensions that want to load models with
some extra classes in them that the usual unpickler would find suspicious.
Use the extra_handler argument to specify a function that takes module and field name as text,
and returns that field's value:
```python
def extra(module, name):
if module == 'collections' and name == 'OrderedDict':
return collections.OrderedDict
return None
safe.load_with_extra('model.pt', extra_handler=extra)
```
The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
definitely unsafe.
"""
from modules import shared
try:
if not shared.cmd_opts.disable_safe_unpickle:
check_pt(filename)
check_pt(filename, extra_handler)
except pickle.UnpicklingError:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)

View file

@ -7,6 +7,7 @@ from typing import Optional
from fastapi import FastAPI
from gradio import Blocks
def report_exception(c, job):
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
@ -45,26 +46,33 @@ class CFGDenoiserParams:
"""Total number of sampling steps planned"""
class UiTrainTabParams:
def __init__(self, txt2img_preview_params):
self.txt2img_preview_params = txt2img_preview_params
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callbacks_app_started = []
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
callbacks_before_image_saved = []
callbacks_image_saved = []
callbacks_cfg_denoiser = []
callback_map = dict(
callbacks_app_started=[],
callbacks_model_loaded=[],
callbacks_ui_tabs=[],
callbacks_ui_train_tabs=[],
callbacks_ui_settings=[],
callbacks_before_image_saved=[],
callbacks_image_saved=[],
callbacks_cfg_denoiser=[],
callbacks_before_component=[],
callbacks_after_component=[],
)
def clear_callbacks():
callbacks_model_loaded.clear()
callbacks_ui_tabs.clear()
callbacks_ui_settings.clear()
callbacks_before_image_saved.clear()
callbacks_image_saved.clear()
callbacks_cfg_denoiser.clear()
for callback_list in callback_map.values():
callback_list.clear()
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callbacks_app_started:
for c in callback_map['callbacks_app_started']:
try:
c.callback(demo, app)
except Exception:
@ -72,7 +80,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
def model_loaded_callback(sd_model):
for c in callbacks_model_loaded:
for c in callback_map['callbacks_model_loaded']:
try:
c.callback(sd_model)
except Exception:
@ -81,8 +89,8 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback():
res = []
for c in callbacks_ui_tabs:
for c in callback_map['callbacks_ui_tabs']:
try:
res += c.callback() or []
except Exception:
@ -91,8 +99,16 @@ def ui_tabs_callback():
return res
def ui_train_tabs_callback(params: UiTrainTabParams):
for c in callback_map['callbacks_ui_train_tabs']:
try:
c.callback(params)
except Exception:
report_exception(c, 'callbacks_ui_train_tabs')
def ui_settings_callback():
for c in callbacks_ui_settings:
for c in callback_map['callbacks_ui_settings']:
try:
c.callback()
except Exception:
@ -100,7 +116,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams):
for c in callbacks_before_image_saved:
for c in callback_map['callbacks_before_image_saved']:
try:
c.callback(params)
except Exception:
@ -108,7 +124,7 @@ def before_image_saved_callback(params: ImageSaveParams):
def image_saved_callback(params: ImageSaveParams):
for c in callbacks_image_saved:
for c in callback_map['callbacks_image_saved']:
try:
c.callback(params)
except Exception:
@ -116,30 +132,62 @@ def image_saved_callback(params: ImageSaveParams):
def cfg_denoiser_callback(params: CFGDenoiserParams):
for c in callbacks_cfg_denoiser:
for c in callback_map['callbacks_cfg_denoiser']:
try:
c.callback(params)
except Exception:
report_exception(c, 'cfg_denoiser_callback')
def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']:
try:
c.callback(component, **kwargs)
except Exception:
report_exception(c, 'before_component_callback')
def after_component_callback(component, **kwargs):
for c in callback_map['callbacks_after_component']:
try:
c.callback(component, **kwargs)
except Exception:
report_exception(c, 'after_component_callback')
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
callbacks.append(ScriptCallback(filename, fun))
def remove_current_script_callbacks():
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
if filename == 'unknown file':
return
for callback_list in callback_map.values():
for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
callback_list.remove(callback_to_remove)
def remove_callbacks_for_function(callback_func):
for callback_list in callback_map.values():
for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
callback_list.remove(callback_to_remove)
def on_app_started(callback):
"""register a function to be called when the webui started, the gradio `Block` component and
fastapi `FastAPI` object are passed as the arguments"""
add_callback(callbacks_app_started, callback)
add_callback(callback_map['callbacks_app_started'], callback)
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
add_callback(callbacks_model_loaded, callback)
add_callback(callback_map['callbacks_model_loaded'], callback)
def on_ui_tabs(callback):
@ -152,13 +200,20 @@ def on_ui_tabs(callback):
title is tab text displayed to user in the UI
elem_id is HTML id for the tab
"""
add_callback(callbacks_ui_tabs, callback)
add_callback(callback_map['callbacks_ui_tabs'], callback)
def on_ui_train_tabs(callback):
"""register a function to be called when the UI is creating new tabs for the train tab.
Create your new tabs with gr.Tab.
"""
add_callback(callback_map['callbacks_ui_train_tabs'], callback)
def on_ui_settings(callback):
"""register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
add_callback(callbacks_ui_settings, callback)
add_callback(callback_map['callbacks_ui_settings'], callback)
def on_before_image_saved(callback):
@ -166,7 +221,7 @@ def on_before_image_saved(callback):
The callback is called with one argument:
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
"""
add_callback(callbacks_before_image_saved, callback)
add_callback(callback_map['callbacks_before_image_saved'], callback)
def on_image_saved(callback):
@ -174,7 +229,7 @@ def on_image_saved(callback):
The callback is called with one argument:
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
"""
add_callback(callbacks_image_saved, callback)
add_callback(callback_map['callbacks_image_saved'], callback)
def on_cfg_denoiser(callback):
@ -182,5 +237,21 @@ def on_cfg_denoiser(callback):
The callback is called with one argument:
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
"""
add_callback(callbacks_cfg_denoiser, callback)
add_callback(callback_map['callbacks_cfg_denoiser'], callback)
def on_before_component(callback):
"""register a function to be called before a component is created.
The callback is called with arguments:
- component - gradio component that is about to be created.
- **kwargs - args to gradio.components.IOComponent.__init__ function
Use elem_id/label fields of kwargs to figure out which component it is.
This can be useful to inject your own components somewhere in the middle of vanilla UI.
"""
add_callback(callback_map['callbacks_before_component'], callback)
def on_after_component(callback):
"""register a function to be called after a component is created. See on_before_component for more."""
add_callback(callback_map['callbacks_after_component'], callback)

34
modules/script_loading.py Normal file
View file

@ -0,0 +1,34 @@
import os
import sys
import traceback
from types import ModuleType
def load_module(path):
with open(path, "r", encoding="utf8") as file:
text = file.read()
compiled = compile(text, path, 'exec')
module = ModuleType(os.path.basename(path))
exec(compiled, module.__dict__)
return module
def preload_extensions(extensions_dir, parser):
if not os.path.isdir(extensions_dir):
return
for dirname in sorted(os.listdir(extensions_dir)):
preload_script = os.path.join(extensions_dir, dirname, "preload.py")
if not os.path.isfile(preload_script):
continue
try:
module = load_module(preload_script)
if hasattr(module, 'preload'):
module.preload(parser)
except Exception:
print(f"Error running preload() for {preload_script}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)

View file

@ -3,11 +3,10 @@ import sys
import traceback
from collections import namedtuple
import modules.ui as ui
import gradio as gr
from modules.processing import StableDiffusionProcessing
from modules import shared, paths, script_callbacks, extensions
from modules import shared, paths, script_callbacks, extensions, script_loading
AlwaysVisible = object()
@ -18,6 +17,9 @@ class Script:
args_to = None
alwayson = False
is_txt2img = False
is_img2img = False
"""A gr.Group component that has all script's UI inside it"""
group = None
@ -73,6 +75,19 @@ class Script:
pass
def process_batch(self, p, *args, **kwargs):
"""
Same as process(), but called for every batch.
**kwargs will have those items:
- batch_number - index of current batch, from 0 to number of batches-1
- prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
- seeds - list of seeds for current batch
- subseeds - list of subseeds for current batch
"""
pass
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
@ -81,6 +96,23 @@ class Script:
pass
def before_component(self, component, **kwargs):
"""
Called before a component is created.
Use elem_id/label fields of kwargs to figure out which component it is.
This can be useful to inject your own components somewhere in the middle of vanilla UI.
You can return created components in the ui() function to add them to the list of arguments for your processing functions
"""
pass
def after_component(self, component, **kwargs):
"""
Called after a component is created. Same as above.
"""
pass
def describe(self):
"""unused"""
return ""
@ -128,7 +160,7 @@ def list_files_with_name(filename):
continue
path = os.path.join(dirpath, filename)
if os.path.isfile(filename):
if os.path.isfile(path):
res.append(path)
return res
@ -149,13 +181,7 @@ def load_scripts():
sys.path = [scriptfile.basedir] + sys.path
current_basedir = scriptfile.basedir
with open(scriptfile.path, "r", encoding="utf8") as file:
text = file.read()
from types import ModuleType
compiled = compile(text, scriptfile.path, 'exec')
module = ModuleType(scriptfile.filename)
exec(compiled, module.__dict__)
module = script_loading.load_module(scriptfile.path)
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
@ -189,12 +215,18 @@ class ScriptRunner:
self.titles = []
self.infotext_fields = []
def setup_ui(self, is_img2img):
def initialize_scripts(self, is_img2img):
self.scripts.clear()
self.alwayson_scripts.clear()
self.selectable_scripts.clear()
for script_class, path, basedir in scripts_data:
script = script_class()
script.filename = path
script.is_txt2img = not is_img2img
script.is_img2img = is_img2img
visibility = script.show(is_img2img)
visibility = script.show(script.is_img2img)
if visibility == AlwaysVisible:
self.scripts.append(script)
@ -205,6 +237,7 @@ class ScriptRunner:
self.scripts.append(script)
self.selectable_scripts.append(script)
def setup_ui(self):
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
inputs = [None]
@ -214,7 +247,7 @@ class ScriptRunner:
script.args_from = len(inputs)
script.args_to = len(inputs)
controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
if controls is None:
return
@ -296,6 +329,15 @@ class ScriptRunner:
print(f"Error running process: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process_batch(p, *script_args, **kwargs)
except Exception:
print(f"Error running process_batch: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def postprocess(self, p, processed):
for script in self.alwayson_scripts:
try:
@ -305,33 +347,44 @@ class ScriptRunner:
print(f"Error running postprocess: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def before_component(self, component, **kwargs):
for script in self.scripts:
try:
script.before_component(component, **kwargs)
except Exception:
print(f"Error running before_component: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def after_component(self, component, **kwargs):
for script in self.scripts:
try:
script.after_component(component, **kwargs)
except Exception:
print(f"Error running after_component: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
with open(script.filename, "r", encoding="utf8") as file:
args_from = script.args_from
args_to = script.args_to
filename = script.filename
text = file.read()
args_from = script.args_from
args_to = script.args_to
filename = script.filename
from types import ModuleType
module = cache.get(filename, None)
if module is None:
module = script_loading.load_module(script.filename)
cache[filename] = module
module = cache.get(filename, None)
if module is None:
compiled = compile(text, filename, 'exec')
module = ModuleType(script.filename)
exec(compiled, module.__dict__)
cache[filename] = module
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
self.scripts[si] = script_class()
self.scripts[si].filename = filename
self.scripts[si].args_from = args_from
self.scripts[si].args_to = args_to
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
self.scripts[si] = script_class()
self.scripts[si].filename = filename
self.scripts[si].args_from = args_from
self.scripts[si].args_to = args_to
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
scripts_current: ScriptRunner = None
def reload_script_body_only():
@ -348,3 +401,22 @@ def reload_scripts():
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
def IOComponent_init(self, *args, **kwargs):
if scripts_current is not None:
scripts_current.before_component(self, **kwargs)
script_callbacks.before_component_callback(self, **kwargs)
res = original_IOComponent_init(self, *args, **kwargs)
script_callbacks.after_component_callback(self, **kwargs)
if scripts_current is not None:
scripts_current.after_component(self, **kwargs)
return res
original_IOComponent_init = gr.components.IOComponent.__init__
gr.components.IOComponent.__init__ = IOComponent_init

View file

@ -54,7 +54,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
img = devices.mps_contiguous_to(img.unsqueeze(0), device)
img = img.unsqueeze(0).to(device)
with torch.no_grad():
output = model(img)

View file

@ -8,17 +8,32 @@ from torch import einsum
from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.hypernetworks import hypernetwork
from modules.shared import opts, device, cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip
from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
# silence new console spam from SD2
ldm.modules.attention.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = lambda *args: None
def apply_optimizations():
undo_optimizations()
@ -47,16 +62,15 @@ def apply_optimizations():
def undo_optimizations():
from modules.hypernetworks import hypernetwork
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
def get_target_prompt_token_count(token_count):
return math.ceil(max(token_count, 1) / 75) * 75
def fix_checkpoint():
ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
class StableDiffusionModelHijack:
fixes = None
@ -68,14 +82,18 @@ class StableDiffusionModelHijack:
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
def hijack(self, m):
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
self.clip = m.cond_stage_model
apply_optimizations()
fix_checkpoint()
def flatten(el):
flattened = [flatten(children) for children in el.children()]
@ -87,15 +105,18 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
def undo_hijack(self, m):
if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
m.cond_stage_model = m.cond_stage_model.wrapped
self.apply_circular(False)
self.layers = None
self.circular_enabled = False
self.clip = None
def apply_circular(self, enable):
@ -112,262 +133,9 @@ class StableDiffusionModelHijack:
def tokenize(self, text):
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
self.token_mults = {}
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
if c == '[':
mult /= 1.1
if c == ']':
mult *= 1.1
if c == '(':
mult *= 1.1
if c == ')':
mult /= 1.1
if mult != 1.0:
self.token_mults[ident] = mult
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_end = self.wrapped.tokenizer.eos_token_id
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
else:
parsed = [[line, 1.0]]
tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
fixes = []
remade_tokens = []
multipliers = []
last_comma = -1
for tokens, (text, weight) in zip(tokenized, parsed):
i = 0
while i < len(tokens):
token = tokens[i]
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
if token == self.comma_token:
last_comma = len(remade_tokens)
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
last_comma += 1
reloc_tokens = remade_tokens[last_comma:]
reloc_mults = multipliers[last_comma:]
remade_tokens = remade_tokens[:last_comma]
length = len(remade_tokens)
rem = int(math.ceil(length / 75)) * 75 - length
remade_tokens += [id_end] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
i += 1
else:
emb_len = int(embedding.vec.shape[0])
iteration = len(remade_tokens) // 75
if (len(remade_tokens) + emb_len) // 75 != iteration:
rem = (75 * (iteration + 1) - len(remade_tokens))
remade_tokens += [id_end] * rem
multipliers += [1.0] * rem
iteration += 1
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
token_count = len(remade_tokens)
prompt_target_length = get_target_prompt_token_count(token_count)
tokens_to_add = prompt_target_length - len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * tokens_to_add
multipliers = multipliers + [1.0] * tokens_to_add
return remade_tokens, fixes, multipliers, token_count
def process_text(self, texts):
used_custom_terms = []
remade_batch_tokens = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {}
batch_multipliers = []
for line in texts:
if line in cache:
remade_tokens, fixes, multipliers = cache[line]
else:
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
token_count = max(current_token_count, token_count)
cache[line] = (remade_tokens, fixes, multipliers)
remade_batch_tokens.append(remade_tokens)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length # you get to stay at 77
used_custom_terms = []
remade_batch_tokens = []
overflowing_words = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {}
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
batch_multipliers = []
for tokens in batch_tokens:
tuple_tokens = tuple(tokens)
if tuple_tokens in cache:
remade_tokens, fixes, multipliers = cache[tuple_tokens]
else:
fixes = []
remade_tokens = []
multipliers = []
mult = 1.0
i = 0
while i < len(tokens):
token = tokens[i]
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
mult *= mult_change
i += 1
elif embedding is None:
remade_tokens.append(token)
multipliers.append(mult)
i += 1
else:
emb_len = int(embedding.vec.shape[0])
fixes.append((len(remade_tokens), embedding))
remade_tokens += [0] * emb_len
multipliers += [mult] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text):
use_old = opts.use_old_emphasis_implementation
if use_old:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
if use_old:
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)
z = None
i = 0
while max(map(len, remade_batch_tokens)) != 0:
rem_tokens = [x[75:] for x in remade_batch_tokens]
rem_multipliers = [x[75:] for x in batch_multipliers]
self.hijack.fixes = []
for unfiltered in hijack_fixes:
fixes = []
for fix in unfiltered:
if fix[0] == i:
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
tokens = []
multipliers = []
for j in range(len(remade_batch_tokens)):
if len(remade_batch_tokens[j]) > 0:
tokens.append(remade_batch_tokens[j][:75])
multipliers.append(batch_multipliers[j][:75])
else:
tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
multipliers.append([1.0] * 75)
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
return z
def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
if opts.CLIP_stop_at_last_layers > 1:
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
z = self.wrapped.transformer.text_model.final_layer_norm(z)
else:
z = outputs.last_hidden_state
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
z *= original_mean / new_mean
return z
class EmbeddingsWithFixes(torch.nn.Module):
def __init__(self, wrapped, embeddings):
@ -406,3 +174,19 @@ def add_circular_option_to_conv_2d():
model_hijack = StableDiffusionModelHijack()
def register_buffer(self, name, attr):
"""
Fix register buffer bug for Mac OS.
"""
if type(attr) == torch.Tensor:
if attr.device != devices.device:
attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
setattr(self, name, attr)
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer

View file

@ -0,0 +1,10 @@
from torch.utils.checkpoint import checkpoint
def BasicTransformerBlock_forward(self, x, context=None):
return checkpoint(self._forward, x, context)
def AttentionBlock_forward(self, x):
return checkpoint(self._forward, x)
def ResBlock_forward(self, x, emb):
return checkpoint(self._forward, x, emb)

301
modules/sd_hijack_clip.py Normal file
View file

@ -0,0 +1,301 @@
import math
import torch
from modules import prompt_parser, devices
from modules.shared import opts
def get_target_prompt_token_count(token_count):
return math.ceil(max(token_count, 1) / 75) * 75
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
self.hijack = hijack
def tokenize(self, texts):
raise NotImplementedError
def encode_with_transformers(self, tokens):
raise NotImplementedError
def encode_embedding_init_text(self, init_text, nvpt):
raise NotImplementedError
def tokenize_line(self, line, used_custom_terms, hijack_comments):
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
else:
parsed = [[line, 1.0]]
tokenized = self.tokenize([text for text, _ in parsed])
fixes = []
remade_tokens = []
multipliers = []
last_comma = -1
for tokens, (text, weight) in zip(tokenized, parsed):
i = 0
while i < len(tokens):
token = tokens[i]
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
if token == self.comma_token:
last_comma = len(remade_tokens)
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
last_comma += 1
reloc_tokens = remade_tokens[last_comma:]
reloc_mults = multipliers[last_comma:]
remade_tokens = remade_tokens[:last_comma]
length = len(remade_tokens)
rem = int(math.ceil(length / 75)) * 75 - length
remade_tokens += [self.id_end] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
i += 1
else:
emb_len = int(embedding.vec.shape[0])
iteration = len(remade_tokens) // 75
if (len(remade_tokens) + emb_len) // 75 != iteration:
rem = (75 * (iteration + 1) - len(remade_tokens))
remade_tokens += [self.id_end] * rem
multipliers += [1.0] * rem
iteration += 1
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
token_count = len(remade_tokens)
prompt_target_length = get_target_prompt_token_count(token_count)
tokens_to_add = prompt_target_length - len(remade_tokens)
remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
multipliers = multipliers + [1.0] * tokens_to_add
return remade_tokens, fixes, multipliers, token_count
def process_text(self, texts):
used_custom_terms = []
remade_batch_tokens = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {}
batch_multipliers = []
for line in texts:
if line in cache:
remade_tokens, fixes, multipliers = cache[line]
else:
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
token_count = max(current_token_count, token_count)
cache[line] = (remade_tokens, fixes, multipliers)
remade_batch_tokens.append(remade_tokens)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def process_text_old(self, texts):
id_start = self.id_start
id_end = self.id_end
maxlen = self.wrapped.max_length # you get to stay at 77
used_custom_terms = []
remade_batch_tokens = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {}
batch_tokens = self.tokenize(texts)
batch_multipliers = []
for tokens in batch_tokens:
tuple_tokens = tuple(tokens)
if tuple_tokens in cache:
remade_tokens, fixes, multipliers = cache[tuple_tokens]
else:
fixes = []
remade_tokens = []
multipliers = []
mult = 1.0
i = 0
while i < len(tokens):
token = tokens[i]
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
mult *= mult_change
i += 1
elif embedding is None:
remade_tokens.append(token)
multipliers.append(mult)
i += 1
else:
emb_len = int(embedding.vec.shape[0])
fixes.append((len(remade_tokens), embedding))
remade_tokens += [0] * emb_len
multipliers += [mult] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text):
use_old = opts.use_old_emphasis_implementation
if use_old:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
if use_old:
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)
z = None
i = 0
while max(map(len, remade_batch_tokens)) != 0:
rem_tokens = [x[75:] for x in remade_batch_tokens]
rem_multipliers = [x[75:] for x in batch_multipliers]
self.hijack.fixes = []
for unfiltered in hijack_fixes:
fixes = []
for fix in unfiltered:
if fix[0] == i:
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
tokens = []
multipliers = []
for j in range(len(remade_batch_tokens)):
if len(remade_batch_tokens[j]) > 0:
tokens.append(remade_batch_tokens[j][:75])
multipliers.append(batch_multipliers[j][:75])
else:
tokens.append([self.id_end] * 75)
multipliers.append([1.0] * 75)
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
return z
def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
tokens = torch.asarray(remade_batch_tokens).to(devices.device)
if self.id_end != self.id_pad:
for batch_pos in range(len(remade_batch_tokens)):
index = remade_batch_tokens[batch_pos].index(self.id_end)
tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
z = self.encode_with_transformers(tokens)
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
z *= original_mean / new_mean
return z
class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
def __init__(self, wrapped, hijack):
super().__init__(wrapped, hijack)
self.tokenizer = wrapped.tokenizer
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
self.token_mults = {}
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
if c == '[':
mult /= 1.1
if c == ']':
mult *= 1.1
if c == '(':
mult *= 1.1
if c == ')':
mult /= 1.1
if mult != 1.0:
self.token_mults[ident] = mult
self.id_start = self.wrapped.tokenizer.bos_token_id
self.id_end = self.wrapped.tokenizer.eos_token_id
self.id_pad = self.id_end
def tokenize(self, texts):
tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
return tokenized
def encode_with_transformers(self, tokens):
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
if opts.CLIP_stop_at_last_layers > 1:
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
z = self.wrapped.transformer.text_model.final_layer_norm(z)
else:
z = outputs.last_hidden_state
return z
def encode_embedding_init_text(self, init_text, nvpt):
embedding_layer = self.wrapped.transformer.text_model.embeddings
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
return embedded

View file

@ -199,8 +199,8 @@ def sample_plms(self,
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
@ -249,6 +249,8 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
@ -321,11 +323,16 @@ def should_hijack_inpainting(checkpoint_info):
def do_inpainting_hijack():
ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
# most of this stuff seems to no longer be needed because it is already included into SD2.0
# LatentInpaintDiffusion remains because SD2.0's LatentInpaintDiffusion can't be loaded without specifying a checkpoint
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
# this file should be cleaned up later if weverything tuens out to work fine
# ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
# ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
# ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
# ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms

View file

@ -0,0 +1,37 @@
import open_clip.tokenizer
import torch
from modules import sd_hijack_clip, devices
from modules.shared import opts
tokenizer = open_clip.tokenizer._tokenizer
class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
def __init__(self, wrapped, hijack):
super().__init__(wrapped, hijack)
self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
self.id_start = tokenizer.encoder["<start_of_text>"]
self.id_end = tokenizer.encoder["<end_of_text>"]
self.id_pad = 0
def tokenize(self, texts):
assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
tokenized = [tokenizer.encode(text) for text in texts]
return tokenized
def encode_with_transformers(self, tokens):
# set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
z = self.wrapped.encode_with_transformer(tokens)
return z
def encode_embedding_init_text(self, init_text, nvpt):
ids = tokenizer.encode(init_text)
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
return embedded

View file

@ -5,6 +5,7 @@ import gc
from collections import namedtuple
import torch
import re
import safetensors.torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
@ -45,7 +46,7 @@ def checkpoint_tiles():
def list_models():
checkpoints_list.clear()
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
def modeltitle(path, shorthash):
abspath = os.path.abspath(path)
@ -143,8 +144,8 @@ def transform_checkpoint_dict_key(k):
def get_state_dict_from_checkpoint(pl_sd):
if "state_dict" in pl_sd:
pl_sd = pl_sd["state_dict"]
pl_sd = pl_sd.pop("state_dict", pl_sd)
pl_sd.pop("state_dict", None)
sd = {}
for k, v in pl_sd.items():
@ -159,25 +160,41 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location)
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
if print_global_state and "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = get_state_dict_from_checkpoint(pl_sd)
return sd
def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
cache_enabled = shared.opts.sd_checkpoint_cache > 0
checkpoint_key = checkpoint_info
if checkpoint_key not in checkpoints_loaded:
if cache_enabled and checkpoint_info in checkpoints_loaded:
# use checkpoint cache
print(f"Loading weights [{sd_model_hash}] from cache")
model.load_state_dict(checkpoints_loaded[checkpoint_info])
else:
# load from file
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = get_state_dict_from_checkpoint(pl_sd)
del pl_sd
sd = read_state_dict(checkpoint_file)
model.load_state_dict(sd, strict=False)
del sd
if cache_enabled:
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
@ -197,23 +214,16 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.first_stage_model.to(devices.dtype_vae)
if shared.opts.sd_checkpoint_cache > 0:
# if PR #4035 were to get merged, restore base VAE first before caching
checkpoints_loaded[checkpoint_key] = model.state_dict().copy()
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False) # LRU
else:
vae_name = sd_vae.get_filename(vae_file) if vae_file else None
vae_message = f" with {vae_name} VAE" if vae_name else ""
print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
checkpoints_loaded.move_to_end(checkpoint_key)
model.load_state_dict(checkpoints_loaded[checkpoint_key])
# clean up cache if limit is reached
if cache_enabled:
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
checkpoints_loaded.popitem(last=False) # LRU
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
sd_vae.load_vae(model, vae_file)
@ -244,6 +254,9 @@ def load_model(checkpoint_info=None):
do_inpainting_hijack()
if shared.cmd_opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info)

View file

@ -1,4 +1,4 @@
from collections import namedtuple
from collections import namedtuple, deque
import numpy as np
from math import floor
import torch
@ -6,6 +6,7 @@ import tqdm
from PIL import Image
import inspect
import k_diffusion.sampling
import torchsde._brownian.brownian_interval
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
from modules import prompt_parser, devices, processing, images
@ -18,17 +19,23 @@ from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}),
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
('Heun', 'sample_heun', ['k_heun'], {}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
]
samplers_data_k_diffusion = [
@ -42,16 +49,24 @@ all_samplers = [
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
]
all_samplers_map = {x.name: x for x in all_samplers}
samplers = []
samplers_for_img2img = []
samplers_map = {}
def create_sampler_with_index(list_of_configs, index, model):
config = list_of_configs[index]
def create_sampler(name, model):
if name is not None:
config = all_samplers_map.get(name, None)
else:
config = all_samplers[0]
assert config is not None, f'bad sampler name: {name}'
sampler = config.constructor(model)
sampler.config = config
return sampler
@ -64,6 +79,12 @@ def set_samplers():
samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
samplers_map.clear()
for sampler in all_samplers:
samplers_map[sampler.name.lower()] = sampler.name
for alias in sampler.aliases:
samplers_map[alias.lower()] = sampler.name
set_samplers()
@ -116,7 +137,8 @@ class InterruptedException(BaseException):
class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model)
self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
self.mask = None
self.nmask = None
self.init_latent = None
@ -207,7 +229,6 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
def adjust_steps_if_invalid(self, p, num_steps):
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
valid_step = 999 / (1000 // num_steps)
@ -216,7 +237,6 @@ class VanillaStableDiffusionSampler:
return num_steps
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
steps = self.adjust_steps_if_invalid(p, steps)
@ -249,9 +269,10 @@ class VanillaStableDiffusionSampler:
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
# Wrap the conditioning models with additional image conditioning for inpainting model
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
@ -324,28 +345,55 @@ class CFGDenoiser(torch.nn.Module):
class TorchHijack:
def __init__(self, kdiff_sampler):
self.kdiff_sampler = kdiff_sampler
def __init__(self, sampler_noises):
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
# implementation.
self.sampler_noises = deque(sampler_noises)
def __getattr__(self, item):
if item == 'randn_like':
return self.kdiff_sampler.randn_like
return self.randn_like
if hasattr(torch, item):
return getattr(torch, item)
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
def randn_like(self, x):
if self.sampler_noises:
noise = self.sampler_noises.popleft()
if noise.shape == x.shape:
return noise
if x.device.type == 'mps':
return torch.randn_like(x, device=devices.cpu).to(x.device)
else:
return torch.randn_like(x)
# MPS fix for randn in torchsde
def torchsde_randn(size, dtype, device, seed):
if device.type == 'mps':
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
else:
generator = torch.Generator(device).manual_seed(int(seed))
return torch.randn(size, dtype=dtype, device=device, generator=generator)
torchsde._brownian.brownian_interval._randn = torchsde_randn
class KDiffusionSampler:
def __init__(self, funcname, sd_model):
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
self.extra_params = sampler_extra_params.get(funcname, [])
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
self.sampler_noises = None
self.sampler_noise_index = 0
self.stop_at = None
self.eta = None
self.default_eta = 1.0
@ -378,26 +426,13 @@ class KDiffusionSampler:
def number_of_needed_noises(self, p):
return p.steps
def randn_like(self, x):
noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None
if noise is not None and x.shape == noise.shape:
res = noise
else:
res = torch.randn_like(x)
self.sampler_noise_index += 1
return res
def initialize(self, p):
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap.step = 0
self.sampler_noise_index = 0
self.eta = p.eta or opts.eta_ancestral
if self.sampler_noises is not None:
k_diffusion.sampling.torch = TorchHijack(self)
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
extra_params_kwargs = {}
for param_name in self.extra_params:

View file

@ -83,47 +83,54 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
return vae_list
def resolve_vae(checkpoint_file, vae_file="auto"):
global first_load, vae_dict, vae_list
# if vae_file argument is provided, it takes priority, but not saved
if vae_file and vae_file not in default_vae_list:
if not os.path.isfile(vae_file):
vae_file = "auto"
print("VAE provided as function argument doesn't exist")
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
if first_load and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path
shared.opts.data['sd_vae'] = get_filename(vae_file)
else:
print("VAE provided as command line argument doesn't exist")
# else, we load from settings
def get_vae_from_settings(vae_file="auto"):
# else, we load from settings, if not set to be default
if vae_file == "auto" and shared.opts.sd_vae is not None:
# if saved VAE settings isn't recognized, fallback to auto
vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
# if VAE selected but not found, fallback to auto
if vae_file not in default_vae_values and not os.path.isfile(vae_file):
vae_file = "auto"
print("Selected VAE doesn't exist")
print(f"Selected VAE doesn't exist: {vae_file}")
return vae_file
def resolve_vae(checkpoint_file=None, vae_file="auto"):
global first_load, vae_dict, vae_list
# if vae_file argument is provided, it takes priority, but not saved
if vae_file and vae_file not in default_vae_list:
if not os.path.isfile(vae_file):
print(f"VAE provided as function argument doesn't exist: {vae_file}")
vae_file = "auto"
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
if first_load and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path
shared.opts.data['sd_vae'] = get_filename(vae_file)
else:
print(f"VAE provided as command line argument doesn't exist: {vae_file}")
# fallback to selector in settings, if vae selector not set to act as default fallback
if not shared.opts.sd_vae_as_default:
vae_file = get_vae_from_settings(vae_file)
# vae-path cmd arg takes priority for auto
if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path
print("Using VAE provided as command line argument")
print(f"Using VAE provided as command line argument: {vae_file}")
# if still not found, try look for ".vae.pt" beside model
model_path = os.path.splitext(checkpoint_file)[0]
if vae_file == "auto":
vae_file_try = model_path + ".vae.pt"
if os.path.isfile(vae_file_try):
vae_file = vae_file_try
print("Using VAE found beside selected model")
print(f"Using VAE found similar to selected model: {vae_file}")
# if still not found, try look for ".vae.ckpt" beside model
if vae_file == "auto":
vae_file_try = model_path + ".vae.ckpt"
if os.path.isfile(vae_file_try):
vae_file = vae_file_try
print("Using VAE found beside selected model")
print(f"Using VAE found similar to selected model: {vae_file}")
# No more fallbacks for auto
if vae_file == "auto":
vae_file = None
@ -139,6 +146,7 @@ def load_vae(model, vae_file=None):
# save_settings = False
if vae_file:
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}

View file

@ -3,7 +3,6 @@ import datetime
import json
import os
import sys
from collections import OrderedDict
import time
import gradio as gr
@ -12,17 +11,18 @@ import tqdm
import modules.artists
import modules.interrogate
import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
from modules import sd_samplers, sd_models, localization, sd_vae
from modules.hypernetworks import hypernetwork
from modules import localization, sd_vae, extensions, script_loading
from modules.paths import models_path, script_path, sd_path
demo = None
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
@ -44,6 +44,7 @@ parser.add_argument("--precision", type=str, help="evaluate at this precision",
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
@ -55,7 +56,7 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
@ -71,6 +72,7 @@ parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
@ -80,13 +82,22 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help=
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
script_loading.preload_extensions(extensions.extensions_dir, parser)
cmd_opts = parser.parse_args()
restricted_opts = {
"samples_filename_pattern",
"directories_filename_pattern",
@ -99,7 +110,7 @@ restricted_opts = {
"outdir_save",
}
cmd_opts.disable_extension_access = cmd_opts.share or cmd_opts.listen
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
@ -113,10 +124,12 @@ xformers_available = False
config_filename = cmd_opts.ui_settings_file
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
hypernetworks = {}
loaded_hypernetwork = None
def reload_hypernetworks():
from modules.hypernetworks import hypernetwork
global hypernetworks
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
@ -146,6 +159,9 @@ class State:
self.interrupted = True
def nextjob(self):
if opts.show_progress_every_n_steps == -1:
self.do_set_current_image()
self.job_no += 1
self.sampling_step = 0
self.current_image_sampling_step = 0
@ -186,17 +202,22 @@ class State:
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
def set_current_image(self):
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
self.do_set_current_image()
def do_set_current_image(self):
if not parallel_processing_allowed:
return
if self.current_latent is None:
return
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and self.current_latent is not None:
if opts.show_progress_grid:
self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
else:
self.current_image = sd_samplers.sample_to_image(self.current_latent)
self.current_image_sampling_step = self.sampling_step
import modules.sd_samplers
if opts.show_progress_grid:
self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
else:
self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
self.current_image_sampling_step = self.sampling_step
state = State()
@ -209,8 +230,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = []
localization.list_localizations(cmd_opts.localizations_dir)
def realesrgan_models_names():
import modules.realesrgan_model
@ -235,6 +254,21 @@ def options_section(section_identifier, options_dict):
return options_dict
def list_checkpoint_tiles():
import modules.sd_models
return modules.sd_models.checkpoint_tiles()
def refresh_checkpoints():
import modules.sd_models
return modules.sd_models.list_models()
def list_samplers():
import modules.sd_samplers
return modules.sd_samplers.all_samplers
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
options_templates = {}
@ -263,6 +297,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
}))
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
@ -287,7 +325,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
@ -309,6 +347,8 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
@ -317,9 +357,10 @@ options_templates.update(options_section(('training', "Training"), {
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list),
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
@ -331,7 +372,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
}))
@ -351,7 +392,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
@ -368,7 +409,7 @@ options_templates.update(options_section(('ui', "User interface"), {
}))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
@ -398,7 +439,8 @@ class Options:
if key in self.data or key in self.data_labels:
assert not cmd_opts.freeze_settings, "changing settings is disabled"
comp_args = opts.data_labels[key].component_args
info = opts.data_labels.get(key, None)
comp_args = info.component_args if info else None
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
raise RuntimeError(f"not possible to set {key} because it is restricted")
@ -420,6 +462,23 @@ class Options:
return super(Options, self).__getattribute__(item)
def set(self, key, value):
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
oldval = self.data.get(key, None)
if oldval == value:
return False
try:
setattr(self, key, value)
except RuntimeError:
return False
if self.data_labels[key].onchange is not None:
self.data_labels[key].onchange()
return True
def save(self, filename):
assert not cmd_opts.freeze_settings, "saving settings is disabled"

View file

@ -65,17 +65,6 @@ class StyleDatabase:
def apply_negative_styles_to_prompt(self, prompt, styles):
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
def apply_styles(self, p: StableDiffusionProcessing) -> None:
if isinstance(p.prompt, list):
p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt]
else:
p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles)
if isinstance(p.negative_prompt, list):
p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt]
else:
p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
def save_styles(self, path: str) -> None:
# Write to temporary file first, so we don't nuke the file if something goes wrong
fd, temp_path = tempfile.mkstemp(".csv")

View file

@ -13,10 +13,6 @@ from modules.swinir_model_arch import SwinIR as net
from modules.swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
precision_scope = (
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
)
class UpscalerSwinIR(Upscaler):
def __init__(self, dirname):
@ -111,8 +107,8 @@ def upscale(
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir)
with torch.no_grad(), precision_scope("cuda"):
img = img.unsqueeze(0).to(devices.device_swinir)
with torch.no_grad(), devices.autocast():
_, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old
w_pad = (w_old // window_size + 1) * window_size - w_old

View file

@ -276,8 +276,8 @@ def poi_average(pois, settings):
weight += poi.weight
x += poi.x * poi.weight
y += poi.y * poi.weight
avg_x = round(x / weight)
avg_y = round(y / weight)
avg_x = round(weight and x / weight)
avg_y = round(weight and y / weight)
return PointOfInterest(avg_x, avg_y)
@ -338,4 +338,4 @@ class Settings:
self.face_points_weight = face_points_weight
self.annotate_image = annotate_image
self.destop_view_image = False
self.dnn_model_path = dnn_model_path
self.dnn_model_path = dnn_model_path

View file

@ -3,7 +3,7 @@ import numpy as np
import PIL
import torch
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import random
@ -11,25 +11,28 @@ import tqdm
from modules import devices, shared
import re
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry:
def __init__(self, filename=None, latent=None, filename_text=None):
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
self.filename = filename
self.latent = latent
self.filename_text = filename_text
self.cond = None
self.cond_text = None
self.latent_dist = latent_dist
self.latent_sample = latent_sample
self.cond = cond
self.cond_text = cond_text
self.pixel_values = pixel_values
class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
self.batch_size = batch_size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@ -45,11 +48,16 @@ class PersonalizedBase(Dataset):
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
cond_model = shared.sd_model.cond_stage_model
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
if shared.state.interrupted:
raise Exception("inturrupted")
try:
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
except Exception:
@ -71,53 +79,94 @@ class PersonalizedBase(Dataset):
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
torchdata = torch.moveaxis(torchdata, 2, 0)
torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
latent_sample = None
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
init_latent = init_latent.to(devices.cpu)
with devices.autocast():
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
latent_sampling_method = "once"
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
elif latent_sampling_method == "deterministic":
# Works only for DiagonalGaussianDistribution
latent_dist.std = 0
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
elif latent_sampling_method == "random":
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
if include_cond:
if not (self.tag_drop_out != 0 or self.shuffle_tags):
entry.cond_text = self.create_text(filename_text)
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
with devices.autocast():
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
self.dataset.append(entry)
del torchdata
del latent_dist
del latent_sample
assert len(self.dataset) > 0, "No images have been found in the dataset."
self.length = len(self.dataset) * repeats // batch_size
self.dataset_length = len(self.dataset)
self.indexes = None
self.shuffle()
def shuffle(self):
self.indexes = np.random.permutation(self.dataset_length)
self.length = len(self.dataset)
assert self.length > 0, "No images have been found in the dataset."
self.batch_size = min(batch_size, self.length)
self.gradient_step = min(gradient_step, self.length // self.batch_size)
self.latent_sampling_method = latent_sampling_method
def create_text(self, filename_text):
text = random.choice(self.lines)
tags = filename_text.split(',')
if self.tag_drop_out != 0:
tags = [t for t in tags if random.random() > self.tag_drop_out]
if self.shuffle_tags:
random.shuffle(tags)
text = text.replace("[filewords]", ','.join(tags))
text = text.replace("[name]", self.placeholder_token)
text = text.replace("[filewords]", filename_text)
return text
def __len__(self):
return self.length
def __getitem__(self, i):
res = []
entry = self.dataset[i]
if self.tag_drop_out != 0 or self.shuffle_tags:
entry.cond_text = self.create_text(entry.filename_text)
if self.latent_sampling_method == "random":
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
return entry
for j in range(self.batch_size):
position = i * self.batch_size + j
if position % len(self.indexes) == 0:
self.shuffle()
class PersonalizedDataLoader(DataLoader):
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
if latent_sampling_method == "random":
self.collate_fn = collate_wrapper_random
else:
self.collate_fn = collate_wrapper
index = self.indexes[position % len(self.indexes)]
entry = self.dataset[index]
class BatchLoader:
def __init__(self, data):
self.cond_text = [entry.cond_text for entry in data]
self.cond = [entry.cond for entry in data]
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
#self.emb_index = [entry.emb_index for entry in data]
#print(self.latent_sample.device)
if entry.cond is None:
entry.cond_text = self.create_text(entry.filename_text)
def pin_memory(self):
self.latent_sample = self.latent_sample.pin_memory()
return self
res.append(entry)
def collate_wrapper(batch):
return BatchLoader(batch)
return res
class BatchLoaderRandom(BatchLoader):
def __init__(self, data):
super().__init__(data)
def pin_memory(self):
return self
def collate_wrapper_random(batch):
return BatchLoaderRandom(batch)

View file

@ -6,12 +6,10 @@ import sys
import tqdm
import time
from modules import shared, images
from modules import shared, images, deepbooru
from modules.paths import models_path
from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop
if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
@ -20,9 +18,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
shared.interrogator.load()
if process_caption_deepbooru:
db_opts = deepbooru.create_deepbooru_opts()
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
deepbooru.model.start()
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
@ -32,9 +28,87 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
shared.interrogator.send_blip_to_ram()
if process_caption_deepbooru:
deepbooru.release_process()
deepbooru.model.stop()
def listfiles(dirname):
return os.listdir(dirname)
class PreprocessParams:
src = None
dstdir = None
subindex = 0
flip = False
process_caption = False
process_caption_deepbooru = False
preprocess_txt_action = None
def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None):
caption = ""
if params.process_caption:
caption += shared.interrogator.generate_caption(image)
if params.process_caption_deepbooru:
if len(caption) > 0:
caption += ", "
caption += deepbooru.model.tag_multi(image)
filename_part = params.src
filename_part = os.path.splitext(filename_part)[0]
filename_part = os.path.basename(filename_part)
basename = f"{index:05}-{params.subindex}-{filename_part}"
image.save(os.path.join(params.dstdir, f"{basename}.png"))
if params.preprocess_txt_action == 'prepend' and existing_caption:
caption = existing_caption + ' ' + caption
elif params.preprocess_txt_action == 'append' and existing_caption:
caption = caption + ' ' + existing_caption
elif params.preprocess_txt_action == 'copy' and existing_caption:
caption = existing_caption
caption = caption.strip()
if len(caption) > 0:
with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
file.write(caption)
params.subindex += 1
def save_pic(image, index, params, existing_caption=None):
save_pic_with_caption(image, index, params, existing_caption=existing_caption)
if params.flip:
save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption)
def split_pic(image, inverse_xy, width, height, overlap_ratio):
if inverse_xy:
from_w, from_h = image.height, image.width
to_w, to_h = height, width
else:
from_w, from_h = image.width, image.height
to_w, to_h = width, height
h = from_h * to_w // from_w
if inverse_xy:
image = image.resize((h, to_w))
else:
image = image.resize((to_w, h))
split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
y_step = (h - to_h) / (split_count - 1)
for i in range(split_count):
y = int(y_step * i)
if inverse_xy:
splitted = image.crop((y, 0, y + to_h, to_w))
else:
splitted = image.crop((0, y, to_w, y + to_h))
yield splitted
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
width = process_width
@ -48,82 +122,28 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
os.makedirs(dst, exist_ok=True)
files = os.listdir(src)
files = listfiles(src)
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
def save_pic_with_caption(image, index, existing_caption=None):
caption = ""
if process_caption:
caption += shared.interrogator.generate_caption(image)
if process_caption_deepbooru:
if len(caption) > 0:
caption += ", "
caption += deepbooru.get_tags_from_process(image)
filename_part = filename
filename_part = os.path.splitext(filename_part)[0]
filename_part = os.path.basename(filename_part)
basename = f"{index:05}-{subindex[0]}-{filename_part}"
image.save(os.path.join(dst, f"{basename}.png"))
if preprocess_txt_action == 'prepend' and existing_caption:
caption = existing_caption + ' ' + caption
elif preprocess_txt_action == 'append' and existing_caption:
caption = caption + ' ' + existing_caption
elif preprocess_txt_action == 'copy' and existing_caption:
caption = existing_caption
caption = caption.strip()
if len(caption) > 0:
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
file.write(caption)
subindex[0] += 1
def save_pic(image, index, existing_caption=None):
save_pic_with_caption(image, index, existing_caption=existing_caption)
if process_flip:
save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
def split_pic(image, inverse_xy):
if inverse_xy:
from_w, from_h = image.height, image.width
to_w, to_h = height, width
else:
from_w, from_h = image.width, image.height
to_w, to_h = width, height
h = from_h * to_w // from_w
if inverse_xy:
image = image.resize((h, to_w))
else:
image = image.resize((to_w, h))
split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
y_step = (h - to_h) / (split_count - 1)
for i in range(split_count):
y = int(y_step * i)
if inverse_xy:
splitted = image.crop((y, 0, y + to_h, to_w))
else:
splitted = image.crop((0, y, to_w, y + to_h))
yield splitted
params = PreprocessParams()
params.dstdir = dst
params.flip = process_flip
params.process_caption = process_caption
params.process_caption_deepbooru = process_caption_deepbooru
params.preprocess_txt_action = preprocess_txt_action
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
params.subindex = 0
filename = os.path.join(src, imagefile)
try:
img = Image.open(filename).convert("RGB")
except Exception:
continue
params.src = filename
existing_caption = None
existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
if os.path.exists(existing_caption_filename):
@ -143,8 +163,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
process_default_resize = True
if process_split and ratio < 1.0 and ratio <= split_threshold:
for splitted in split_pic(img, inverse_xy):
save_pic(splitted, index, existing_caption=existing_caption)
for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
save_pic(splitted, index, params, existing_caption=existing_caption)
process_default_resize = False
if process_focal_crop and img.height != img.width:
@ -165,11 +185,11 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
dnn_model_path = dnn_model_path,
)
for focal in autocrop.crop_image(img, autocrop_settings):
save_pic(focal, index, existing_caption=existing_caption)
save_pic(focal, index, params, existing_caption=existing_caption)
process_default_resize = False
if process_default_resize:
img = images.resize_image(1, img, width, height)
save_pic(img, index, existing_caption=existing_caption)
save_pic(img, index, params, existing_caption=existing_caption)
shared.state.nextjob()
shared.state.nextjob()

View file

@ -10,7 +10,7 @@ import csv
from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models, images
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@ -64,7 +64,8 @@ class EmbeddingDatabase:
self.word_embeddings[embedding.name] = embedding
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
# TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
ids = model.cond_stage_model.tokenize([embedding.name])[0]
first_id = ids[0]
if first_id not in self.ids_lookup:
@ -155,13 +156,11 @@ class EmbeddingDatabase:
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
with devices.autocast():
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token):
@ -184,7 +183,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if shared.opts.training_write_csv_every == 0:
return
if (step + 1) % shared.opts.training_write_csv_every != 0:
if step % shared.opts.training_write_csv_every != 0:
return
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
@ -194,21 +193,23 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if write_csv_header:
csv_writer.writeheader()
epoch = step // epoch_len
epoch_step = step % epoch_len
epoch = (step - 1) // epoch_len
epoch_step = (step - 1) % epoch_len
csv_writer.writerow({
"step": step + 1,
"step": step,
"epoch": epoch,
"epoch_step": epoch_step + 1,
"epoch_step": epoch_step,
**values,
})
def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected"
assert learn_rate, "Learning rate is empty or 0"
assert isinstance(batch_size, int), "Batch size must be integer"
assert batch_size > 0, "Batch size must be positive"
assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
assert gradient_step > 0, "Gradient accumulation step must be positive"
assert data_root, "Dataset directory is empty"
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
@ -224,10 +225,10 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
@ -255,166 +256,203 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
else:
images_embeds_dir = None
cond_model = shared.sd_model.cond_stage_model
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
checkpoint = sd_models.select_checkpoint()
ititial_step = embedding.step or 0
if ititial_step >= steps:
initial_step = embedding.step or 0
if initial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
# dataset loading may take a while, so input validations and early returns should be done before this
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
old_parallel_processing_allowed = shared.parallel_processing_allowed
pin_memory = shared.opts.pin_memory
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
latent_sampling_method = ds.latent_sampling_method
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
if unload:
shared.parallel_processing_allowed = False
shared.sd_model.first_stage_model.to(devices.cpu)
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
scaler = torch.cuda.amp.GradScaler()
losses = torch.zeros((32,))
batch_size = ds.batch_size
gradient_step = ds.gradient_step
# n steps = batch_size * gradient_step * n image processed
steps_per_epoch = len(ds) // batch_size // gradient_step
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
loss_step = 0
_loss_step = 0 #internal
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
embedding_yet_to_be_embedded = False
pbar = tqdm.tqdm(total=steps - initial_step)
try:
for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
if shared.state.interrupted:
break
for j, batch in enumerate(dl):
# works as a drop_last=True for gradient accumulation
if j == max_steps_per_epoch:
break
scheduler.apply(optimizer, embedding.step)
if scheduler.finished:
break
if shared.state.interrupted:
break
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, entries in pbar:
embedding.step = i + ititial_step
with devices.autocast():
# c = stack_conds(batch.cond).to(devices.device)
# mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
# print(mask)
# c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text)
loss = shared.sd_model(x, c)[0] / gradient_step
del x
_loss_step += loss.item()
scaler.scale(loss).backward()
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
scaler.step(optimizer)
scaler.update()
embedding.step += 1
pbar.update()
optimizer.zero_grad(set_to_none=True)
loss_step = _loss_step
_loss_step = 0
scheduler.apply(optimizer, embedding.step)
if scheduler.finished:
break
steps_done = embedding.step + 1
if shared.state.interrupted:
break
epoch_num = embedding.step // steps_per_epoch
epoch_step = embedding.step % steps_per_epoch
with torch.autocast("cuda"):
c = cond_model([entry.cond_text for entry in entries])
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
loss = shared.sd_model(x, c)[0]
del x
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
#if shared.opts.save_optimizer_state:
#embedding.optimizer_state_dict = optimizer.state_dict()
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True
losses[embedding.step % losses.shape[0]] = loss.item()
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
"loss": f"{loss_step:.7f}",
"learn_rate": scheduler.learn_rate
})
optimizer.zero_grad()
loss.backward()
optimizer.step()
if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
steps_done = embedding.step + 1
shared.sd_model.first_stage_model.to(devices.device)
epoch_num = embedding.step // len(ds)
epoch_step = embedding.step % len(ds)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
do_not_save_samples=True,
do_not_reload_embeddings=True,
)
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = batch.cond_text[0]
p.steps = 20
p.width = training_width
p.height = training_height
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True
preview_text = p.prompt
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
"loss": f"{losses.mean():.7f}",
"learn_rate": scheduler.learn_rate
})
processed = processing.process_images(p)
image = processed.images[0] if len(processed.images) > 0 else None
if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.device)
if image is not None:
shared.state.current_image = image
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
do_not_save_samples=True,
do_not_reload_embeddings=True,
)
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_index = preview_sampler_index
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = entries[0].cond_text
p.steps = 20
p.width = training_width
p.height = training_height
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
preview_text = p.prompt
info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embedding_to_b64(data))
processed = processing.process_images(p)
image = processed.images[0]
title = "<{}>".format(data.get('name', '???'))
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
try:
vectorSize = list(data['string_to_param'].values())[0].shape[0]
except Exception as e:
vectorSize = '?'
shared.state.current_image = image
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash)
footer_right = '{}v {}s'.format(vectorSize, steps_done)
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data)
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
embedding_yet_to_be_embedded = False
info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embedding_to_b64(data))
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
title = "<{}>".format(data.get('name', '???'))
shared.state.job_no = embedding.step
try:
vectorSize = list(data['string_to_param'].values())[0].shape[0]
except Exception as e:
vectorSize = '?'
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash)
footer_right = '{}v {}s'.format(vectorSize, steps_done)
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data)
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
embedding_yet_to_be_embedded = False
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = embedding.step
shared.state.textinfo = f"""
shared.state.textinfo = f"""
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
Loss: {loss_step:.7f}<br/>
Step: {steps_done}<br/>
Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
shared.sd_model.first_stage_model.to(devices.device)
shared.parallel_processing_allowed = old_parallel_processing_allowed
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
print(traceback.format_exc(), file=sys.stderr)
pass
finally:
pbar.leave = False
pbar.close()
shared.sd_model.first_stage_model.to(devices.device)
shared.parallel_processing_allowed = old_parallel_processing_allowed
return embedding, filename

View file

@ -18,7 +18,7 @@ def create_embedding(name, initialization_text, nvpt, overwrite_old):
def preprocess(*args):
modules.textual_inversion.preprocess.preprocess(*args)
return "Preprocessing finished.", ""
return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""
def train_embedding(*args):

View file

@ -1,4 +1,5 @@
import modules.scripts
from modules import sd_samplers
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts
@ -21,7 +22,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras,
sampler_index=sampler_index,
sampler_name=sd_samplers.samplers[sampler_index].name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,

View file

@ -17,16 +17,13 @@ import gradio.routes
import gradio.utils
import numpy as np
from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
import modules.codeformer_model
import modules.generation_parameters_copypaste as parameters_copypaste
import modules.gfpgan_model
@ -69,8 +66,11 @@ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.wrap .z-20 svg { display:none!important; }
.wrap .z-20::before { content:"Loading..." }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
.meta-text-center { display:none!important; }
"""
# Using constants for these since the variation selector isn't visible.
@ -142,7 +142,7 @@ def save_files(js_data, images, do_make_zip, index):
filenames.append(os.path.basename(txt_fullfn))
fullfns.append(txt_fullfn)
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
# Make Zip
if do_make_zip:
@ -157,81 +157,7 @@ def save_files(js_data, images, do_make_zip, index):
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
def save_pil_to_file(pil_image, dir=None):
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in pil_image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
return file_obj
# override save to file function so that it also writes PNG info
gr.processing_utils.save_pil_to_file = save_pil_to_file
def wrap_gradio_call(func, extra_outputs=None):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
if run_memmon:
shared.mem_mon.monitor()
t = time.perf_counter()
try:
res = list(func(*args, **kwargs))
except Exception as e:
# When printing out our debug argument list, do not print out more than a MB of text
max_debug_str_len = 131072 # (1024*1024)/8
print("Error completing request", file=sys.stderr)
argStr = f"Arguments: {str(args)} {str(kwargs)}"
print(argStr[:max_debug_str_len], file=sys.stderr)
if len(argStr) > max_debug_str_len:
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
shared.state.job = ""
shared.state.job_count = 0
if extra_outputs_array is None:
extra_outputs_array = [None, '']
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
elapsed = time.perf_counter() - t
elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60
elapsed_text = f"{elapsed_s:.2f}s"
if (elapsed_m > 0):
elapsed_text = f"{elapsed_m}m "+elapsed_text
if run_memmon:
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
active_peak = mem_stats['active_peak']
reserved_peak = mem_stats['reserved_peak']
sys_peak = mem_stats['system_peak']
sys_total = mem_stats['total']
sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
else:
vram_html = ''
# last item is always HTML
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
shared.state.skipped = False
shared.state.interrupted = False
shared.state.job_count = 0
return tuple(res)
return f
def calc_time_left(progress, threshold, label, force_display):
@ -276,7 +202,7 @@ def check_progress_call(id_part):
image = gr_show(False)
preview_visibility = gr_show(False)
if opts.show_progress_every_n_steps > 0:
if opts.show_progress_every_n_steps != 0:
shared.state.set_current_image()
image = shared.state.current_image
@ -346,7 +272,7 @@ def interrogate(image):
def interrogate_deepbooru(image):
prompt = get_deepbooru_tags(image)
prompt = deepbooru.model.tag(image)
return gr_show(True) if prompt is None else prompt
@ -475,9 +401,7 @@ def create_toprow(is_img2img):
if is_img2img:
with gr.Column(scale=1, elem_id="interrogate_col"):
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
if cmd_opts.deepdanbooru:
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
with gr.Column(scale=1):
with gr.Row():
@ -563,6 +487,19 @@ def apply_setting(key, value):
return value
def update_generation_info(args):
generation_info, html_info, img_index = args
try:
generation_info = json.loads(generation_info)
if img_index < 0 or img_index >= len(generation_info["infotexts"]):
return html_info
return plaintext_to_html(generation_info["infotexts"][img_index])
except Exception:
pass
# if the json parse or anything else fails, just return the old html_info
return html_info
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
refresh_method()
@ -635,6 +572,15 @@ Requested path was: {f}
with gr.Group():
html_info = gr.HTML()
generation_info = gr.Textbox(visible=False)
if tabname == 'txt2img' or tabname == 'img2img':
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
generation_info_button.click(
fn=update_generation_info,
_js="(x, y) => [x, y, selected_gallery_index()]",
inputs=[generation_info, html_info],
outputs=[html_info],
preprocess=False
)
save.click(
fn=wrap_gradio_call(save_files),
@ -659,7 +605,7 @@ Requested path was: {f}
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info
def create_ui(wrap_gradio_gpu_call):
def create_ui():
import modules.img2img
import modules.txt2img
@ -667,6 +613,9 @@ def create_ui(wrap_gradio_gpu_call):
parameters_copypaste.reset()
modules.scripts.scripts_current = modules.scripts.scripts_txt2img
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
@ -709,7 +658,7 @@ def create_ui(wrap_gradio_gpu_call):
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
with gr.Group():
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples)
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
@ -816,7 +765,10 @@ def create_ui(wrap_gradio_gpu_call):
height,
]
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
modules.scripts.scripts_current = modules.scripts.scripts_img2img
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True)
@ -840,11 +792,22 @@ def create_ui(wrap_gradio_gpu_call):
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480)
with gr.TabItem('Inpaint', id='inpaint'):
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
init_img_with_mask_orig = gr.State(None)
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
def update_orig(image, state):
if image is not None:
same_size = state is not None and state.size == image.size
has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
edited = same_size and has_exact_match
return image if not edited or state is None else state
init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
show_mask_alpha = cmd_opts.gradio_inpaint_tool == "color-sketch"
mask_alpha = gr.Slider(label="Mask transparency", interactive=show_mask_alpha, visible=show_mask_alpha)
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4)
with gr.Row():
@ -888,7 +851,7 @@ def create_ui(wrap_gradio_gpu_call):
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
with gr.Group():
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
custom_inputs = modules.scripts.scripts_img2img.setup_ui()
img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples)
parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
@ -932,12 +895,14 @@ def create_ui(wrap_gradio_gpu_call):
img2img_prompt_style2,
init_img,
init_img_with_mask,
init_img_with_mask_orig,
init_img_inpaint,
init_mask_inpaint,
mask_mode,
steps,
sampler_index,
mask_blur,
mask_alpha,
inpainting_fill,
restore_faces,
tiling,
@ -973,11 +938,10 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[img2img_prompt],
)
if cmd_opts.deepdanbooru:
img2img_deepbooru.click(
fn=interrogate_deepbooru,
inputs=[init_img],
outputs=[img2img_prompt],
img2img_deepbooru.click(
fn=interrogate_deepbooru,
inputs=[init_img],
outputs=[img2img_prompt],
)
@ -1032,11 +996,14 @@ def create_ui(wrap_gradio_gpu_call):
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"),
(mask_blur, "Mask blur"),
*modules.scripts.scripts_img2img.infotext_fields
]
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
modules.scripts.scripts_current = None
with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
@ -1138,7 +1105,7 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[html, generation_info, html2],
)
with gr.Blocks() as modelmerger_interface:
with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
@ -1150,7 +1117,11 @@ def create_ui(wrap_gradio_gpu_call):
custom_name = gr.Textbox(label="Custom Name (Optional)")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
save_as_half = gr.Checkbox(value=False, label="Save as float16")
with gr.Row():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format")
save_as_half = gr.Checkbox(value=False, label="Save as float16")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'):
@ -1158,7 +1129,7 @@ def create_ui(wrap_gradio_gpu_call):
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
with gr.Blocks() as train_interface:
with gr.Blocks(analytics_enabled=False) as train_interface:
with gr.Row().style(equal_height=False):
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
@ -1180,7 +1151,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tab(label="Create hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys)
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
@ -1207,7 +1178,7 @@ def create_ui(wrap_gradio_gpu_call):
process_split = gr.Checkbox(label='Split oversized images')
process_focal_crop = gr.Checkbox(label='Auto focal point crop')
process_caption = gr.Checkbox(label='Use BLIP for caption')
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True)
with gr.Row(visible=False) as process_split_extra_row:
process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
@ -1224,6 +1195,8 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="")
with gr.Column():
with gr.Row():
interrupt_preprocessing = gr.Button("Interrupt")
run_preprocess = gr.Button(value="Preprocess", variant='primary')
process_split.change(
@ -1251,6 +1224,7 @@ def create_ui(wrap_gradio_gpu_call):
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
batch_size = gr.Number(label='Batch size', value=1, precision=0)
gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
@ -1261,12 +1235,21 @@ def create_ui(wrap_gradio_gpu_call):
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
with gr.Row():
shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False)
tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0)
with gr.Row():
latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'])
with gr.Row():
interrupt_training = gr.Button(value="Interrupt")
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
train_embedding = gr.Button(value="Train Embedding", variant='primary')
params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
script_callbacks.ui_train_tabs_callback(params)
with gr.Column():
progressbar = gr.HTML(elem_id="ti_progressbar")
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
@ -1345,11 +1328,15 @@ def create_ui(wrap_gradio_gpu_call):
train_embedding_name,
embedding_learn_rate,
batch_size,
gradient_step,
dataset_directory,
log_directory,
training_width,
training_height,
steps,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
create_image_every,
save_embedding_every,
template_file,
@ -1370,11 +1357,15 @@ def create_ui(wrap_gradio_gpu_call):
train_hypernetwork_name,
hypernetwork_learn_rate,
batch_size,
gradient_step,
dataset_directory,
log_directory,
training_width,
training_height,
steps,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
create_image_every,
save_embedding_every,
template_file,
@ -1393,6 +1384,12 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[],
)
interrupt_preprocessing.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
outputs=[],
)
def create_setting_component(key, is_quicksettings=False):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@ -1417,15 +1414,14 @@ def create_ui(wrap_gradio_gpu_call):
if info.refresh is not None:
if is_quicksettings:
res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else:
with gr.Row(variant="compact"):
res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else:
res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
return res
@ -1436,43 +1432,30 @@ def create_ui(wrap_gradio_gpu_call):
opts.reorder()
def run_settings(*args):
changed = 0
changed = []
for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp == dummy_component:
continue
oldval = opts.data.get(key, None)
if opts.set(key, value):
changed.append(key)
setattr(opts, key, value)
if oldval != value:
if opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()
changed += 1
opts.save(shared.config_filename)
return f'{changed} settings changed.', opts.dumpjson()
try:
opts.save(shared.config_filename)
except RuntimeError:
return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.'
def run_settings_single(value, key):
if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()
oldval = opts.data.get(key, None)
try:
setattr(opts, key, value)
except Exception:
return gr.update(value=oldval), opts.dumpjson()
if oldval != value:
if opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()
if not opts.set(key, value):
return gr.update(value=getattr(opts, key)), opts.dumpjson()
opts.save(shared.config_filename)
@ -1562,11 +1545,10 @@ def create_ui(wrap_gradio_gpu_call):
shared.state.need_restart = True
restart_gradio.click(
fn=request_restart,
_js='restart_reload',
inputs=[],
outputs=[],
_js='restart_reload'
)
if column is not None:
@ -1622,9 +1604,9 @@ def create_ui(wrap_gradio_gpu_call):
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click(
fn=run_settings,
fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
inputs=components,
outputs=[result, text_settings],
outputs=[text_settings, result],
)
for i, k, item in quicksettings_list:
@ -1636,6 +1618,17 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[component, text_settings],
)
component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
def get_settings_values():
return [getattr(opts, key) for key in component_keys]
demo.load(
fn=get_settings_values,
inputs=[],
outputs=[component_dict[k] for k in component_keys],
)
def modelmerger(*args):
try:
results = modules.extras.run_modelmerger(*args)
@ -1656,6 +1649,7 @@ def create_ui(wrap_gradio_gpu_call):
interp_amount,
save_as_half,
custom_name,
checkpoint_format,
],
outputs=[
submit_result,
@ -1739,7 +1733,7 @@ def create_ui(wrap_gradio_gpu_call):
return demo
def load_javascript(raw_response):
def reload_javascript():
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
javascript = f'<script>{jsfile.read()}</script>'
@ -1755,7 +1749,7 @@ def load_javascript(raw_response):
javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
def template_response(*args, **kwargs):
res = raw_response(*args, **kwargs)
res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
res.body = res.body.replace(
b'</head>', f'{javascript}</head>'.encode("utf8"))
res.init_headers()
@ -1764,4 +1758,5 @@ def load_javascript(raw_response):
gradio.routes.templates.TemplateResponse = template_response
reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
if not hasattr(shared, 'GradioTemplateResponseOriginal'):
shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse

View file

@ -17,7 +17,7 @@ available_extensions = {"extensions": []}
def check_access():
assert not shared.cmd_opts.disable_extension_access, "extension access disabed because of commandline flags"
assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"
def apply_and_restart(disable_list, update_list):
@ -36,9 +36,9 @@ def apply_and_restart(disable_list, update_list):
continue
try:
ext.pull()
ext.fetch_and_reset_hard()
except Exception:
print(f"Error pulling updates for {ext.name}:", file=sys.stderr)
print(f"Error getting updates for {ext.name}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
shared.opts.disabled_extensions = disabled
@ -86,7 +86,7 @@ def extension_table():
code += f"""
<tr>
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td><a href="{html.escape(ext.remote or '')}">{html.escape(ext.remote or '')}</a></td>
<td><a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape(ext.remote or '')}</a></td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr>
"""
@ -134,19 +134,24 @@ def install_extension_from_url(dirname, url):
os.rename(tmpdir, target_dir)
import launch
launch.run_extension_installer(target_dir)
extensions.list_extensions()
return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
finally:
shutil.rmtree(tmpdir, True)
def install_extension_from_index(url):
def install_extension_from_index(url, hide_tags):
ext_table, message = install_extension_from_url(None, url)
return refresh_available_extensions_from_data(), ext_table, message
code, _ = refresh_available_extensions_from_data(hide_tags)
return code, ext_table, message
def refresh_available_extensions(url):
def refresh_available_extensions(url, hide_tags):
global available_extensions
import urllib.request
@ -155,13 +160,25 @@ def refresh_available_extensions(url):
available_extensions = json.loads(text)
return url, refresh_available_extensions_from_data(), ''
code, tags = refresh_available_extensions_from_data(hide_tags)
return url, code, gr.CheckboxGroup.update(choices=tags), ''
def refresh_available_extensions_from_data():
def refresh_available_extensions_for_tags(hide_tags):
code, _ = refresh_available_extensions_from_data(hide_tags)
return code, ''
def refresh_available_extensions_from_data(hide_tags):
extlist = available_extensions["extensions"]
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
tags = available_extensions.get("tags", {})
tags_to_hide = set(hide_tags)
hidden = 0
code = f"""<!-- {time.time()} -->
<table id="available_extensions">
<thead>
@ -178,17 +195,24 @@ def refresh_available_extensions_from_data():
name = ext.get("name", "noname")
url = ext.get("url", None)
description = ext.get("description", "")
extension_tags = ext.get("tags", [])
if url is None:
continue
if len([x for x in extension_tags if x in tags_to_hide]) > 0:
hidden += 1
continue
existing = installed_extension_urls.get(normalize_git_url(url), None)
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
code += f"""
<tr>
<td><a href="{html.escape(url)}">{html.escape(name)}</a></td>
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
<td>{html.escape(description)}</td>
<td>{install_code}</td>
</tr>
@ -199,7 +223,10 @@ def refresh_available_extensions_from_data():
</table>
"""
return code
if hidden > 0:
code += f"<p>Extension hidden: {hidden}</p>"
return code, list(tags)
def create_ui():
@ -238,21 +265,30 @@ def create_ui():
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
with gr.Row():
hide_tags = gr.CheckboxGroup(value=["ads", "localization"], label="Hide extensions with tags", choices=["script", "ads", "localization"])
install_result = gr.HTML()
available_extensions_table = gr.HTML()
refresh_available_extensions_button.click(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update()]),
inputs=[available_extensions_index],
outputs=[available_extensions_index, available_extensions_table, install_result],
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
inputs=[available_extensions_index, hide_tags],
outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result],
)
install_extension_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
inputs=[extension_to_install],
inputs=[extension_to_install, hide_tags],
outputs=[available_extensions_table, extensions_table, install_result],
)
hide_tags.change(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
inputs=[hide_tags],
outputs=[available_extensions_table, install_result]
)
with gr.TabItem("Install from URL"):
install_url = gr.Text(label="URL for extension's git repository")
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")

62
modules/ui_tempdir.py Normal file
View file

@ -0,0 +1,62 @@
import os
import tempfile
from collections import namedtuple
import gradio as gr
from PIL import PngImagePlugin
from modules import shared
Savedfile = namedtuple("Savedfile", ["name"])
def save_pil_to_file(pil_image, dir=None):
already_saved_as = getattr(pil_image, 'already_saved_as', None)
if already_saved_as and os.path.isfile(already_saved_as):
shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(os.path.dirname(already_saved_as))}
file_obj = Savedfile(already_saved_as)
return file_obj
if shared.opts.temp_dir != "":
dir = shared.opts.temp_dir
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in pil_image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
return file_obj
# override save to file function so that it also writes PNG info
gr.processing_utils.save_pil_to_file = save_pil_to_file
def on_tmpdir_changed():
if shared.opts.temp_dir == "" or shared.demo is None:
return
os.makedirs(shared.opts.temp_dir, exist_ok=True)
shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(shared.opts.temp_dir)}
def cleanup_tmpdr():
temp_dir = shared.opts.temp_dir
if temp_dir == "" or not os.path.isdir(temp_dir):
return
for root, dirs, files in os.walk(temp_dir, topdown=False):
for name in files:
_, extension = os.path.splitext(name)
if extension != ".png":
continue
filename = os.path.join(root, name)
os.remove(filename)

View file

@ -57,10 +57,18 @@ class Upscaler:
self.scale = scale
dest_w = img.width * scale
dest_h = img.height * scale
for i in range(3):
if img.width > dest_w and img.height > dest_h:
break
shape = (img.width, img.height)
img = self.do_upscale(img, selected_model)
if shape == (img.width, img.height):
break
if img.width >= dest_w and img.height >= dest_h:
break
if img.width != dest_w or img.height != dest_h:
img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)