mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-12-24 19:01:06 -08:00
Merge branch 'master' into racecond_fix
This commit is contained in:
commit
c9a2cfdf2a
93 changed files with 3589 additions and 8465 deletions
|
|
@ -2,14 +2,22 @@ import base64
|
|||
import io
|
||||
import time
|
||||
import uvicorn
|
||||
from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from threading import Lock
|
||||
from io import BytesIO
|
||||
from gradio.processing_utils import decode_base64_to_file
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from secrets import compare_digest
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import sd_samplers, deepbooru
|
||||
from modules.api.models import *
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid
|
||||
from modules.extras import run_extras, run_pnginfo
|
||||
|
||||
from PIL import PngImagePlugin,Image
|
||||
from modules.sd_models import checkpoints_list
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from typing import List
|
||||
|
||||
def upscaler_to_index(name: str):
|
||||
try:
|
||||
|
|
@ -18,8 +26,12 @@ def upscaler_to_index(name: str):
|
|||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
|
||||
|
||||
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
|
||||
def validate_sampler_name(name):
|
||||
config = sd_samplers.all_samplers_map.get(name, None)
|
||||
if config is None:
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
|
||||
return name
|
||||
|
||||
def setUpscalers(req: dict):
|
||||
reqDict = vars(req)
|
||||
|
|
@ -29,39 +41,84 @@ def setUpscalers(req: dict):
|
|||
reqDict.pop('upscaler_2')
|
||||
return reqDict
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";")[1].split(",")[1]
|
||||
return Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
|
||||
def encode_pil_to_base64(image):
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format="png")
|
||||
return base64.b64encode(buffer.getvalue())
|
||||
with io.BytesIO() as output_bytes:
|
||||
|
||||
# Copy any text-only metadata
|
||||
use_metadata = False
|
||||
metadata = PngImagePlugin.PngInfo()
|
||||
for key, value in image.info.items():
|
||||
if isinstance(key, str) and isinstance(value, str):
|
||||
metadata.add_text(key, value)
|
||||
use_metadata = True
|
||||
|
||||
image.save(
|
||||
output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
return base64.b64encode(bytes_data)
|
||||
|
||||
|
||||
class Api:
|
||||
def __init__(self, app, queue_lock):
|
||||
def __init__(self, app: FastAPI, queue_lock: Lock):
|
||||
if shared.cmd_opts.api_auth:
|
||||
self.credenticals = dict()
|
||||
for auth in shared.cmd_opts.api_auth.split(","):
|
||||
user, password = auth.split(":")
|
||||
self.credenticals[user] = password
|
||||
|
||||
self.router = APIRouter()
|
||||
self.app = app
|
||||
self.queue_lock = queue_lock
|
||||
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
|
||||
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
|
||||
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
|
||||
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
|
||||
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
||||
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
||||
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
|
||||
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
|
||||
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
|
||||
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
|
||||
self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
||||
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
||||
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
|
||||
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
|
||||
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
|
||||
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
|
||||
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
|
||||
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
|
||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
|
||||
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
|
||||
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
|
||||
|
||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||
if shared.cmd_opts.api_auth:
|
||||
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
|
||||
return self.app.add_api_route(path, endpoint, **kwargs)
|
||||
|
||||
def auth(self, credenticals: HTTPBasicCredentials = Depends(HTTPBasic())):
|
||||
if credenticals.username in self.credenticals:
|
||||
if compare_digest(credenticals.password, self.credenticals[credenticals.username]):
|
||||
return True
|
||||
|
||||
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
|
||||
|
||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
||||
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
|
||||
|
||||
if sampler_index is None:
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
|
||||
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||
"sd_model": shared.sd_model,
|
||||
"sampler_index": sampler_index[0],
|
||||
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
|
||||
"do_not_save_samples": True,
|
||||
"do_not_save_grid": True
|
||||
}
|
||||
)
|
||||
if populate.sampler_name:
|
||||
populate.sampler_index = None # prevent a warning later on
|
||||
p = StableDiffusionProcessingTxt2Img(**vars(populate))
|
||||
# Override object param
|
||||
|
||||
|
|
@ -77,12 +134,6 @@ class Api:
|
|||
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||
|
||||
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
|
||||
sampler_index = sampler_to_index(img2imgreq.sampler_index)
|
||||
|
||||
if sampler_index is None:
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
|
||||
|
||||
init_images = img2imgreq.init_images
|
||||
if init_images is None:
|
||||
raise HTTPException(status_code=404, detail="Init image not found")
|
||||
|
|
@ -91,16 +142,20 @@ class Api:
|
|||
if mask:
|
||||
mask = decode_base64_to_image(mask)
|
||||
|
||||
|
||||
populate = img2imgreq.copy(update={ # Override __init__ params
|
||||
"sd_model": shared.sd_model,
|
||||
"sampler_index": sampler_index[0],
|
||||
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
|
||||
"do_not_save_samples": True,
|
||||
"do_not_save_grid": True,
|
||||
"mask": mask
|
||||
}
|
||||
)
|
||||
p = StableDiffusionProcessingImg2Img(**vars(populate))
|
||||
if populate.sampler_name:
|
||||
populate.sampler_index = None # prevent a warning later on
|
||||
|
||||
args = vars(populate)
|
||||
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
|
||||
p = StableDiffusionProcessingImg2Img(**args)
|
||||
|
||||
imgs = []
|
||||
for img in init_images:
|
||||
|
|
@ -118,7 +173,7 @@ class Api:
|
|||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images))
|
||||
|
||||
if (not img2imgreq.include_init_images):
|
||||
if not img2imgreq.include_init_images:
|
||||
img2imgreq.init_images = None
|
||||
img2imgreq.mask = None
|
||||
|
||||
|
|
@ -186,11 +241,92 @@ class Api:
|
|||
|
||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
|
||||
|
||||
def interrogateapi(self, interrogatereq: InterrogateRequest):
|
||||
image_b64 = interrogatereq.image
|
||||
if image_b64 is None:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
img = decode_base64_to_image(image_b64)
|
||||
img = img.convert('RGB')
|
||||
|
||||
# Override object param
|
||||
with self.queue_lock:
|
||||
if interrogatereq.model == "clip":
|
||||
processed = shared.interrogator.interrogate(img)
|
||||
elif interrogatereq.model == "deepdanbooru":
|
||||
processed = deepbooru.model.tag(img)
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
return InterrogateResponse(caption=processed)
|
||||
|
||||
def interruptapi(self):
|
||||
shared.state.interrupt()
|
||||
|
||||
return {}
|
||||
|
||||
def skip(self):
|
||||
shared.state.skip()
|
||||
|
||||
def get_config(self):
|
||||
options = {}
|
||||
for key in shared.opts.data.keys():
|
||||
metadata = shared.opts.data_labels.get(key)
|
||||
if(metadata is not None):
|
||||
options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
|
||||
else:
|
||||
options.update({key: shared.opts.data.get(key, None)})
|
||||
|
||||
return options
|
||||
|
||||
def set_config(self, req: Dict[str, Any]):
|
||||
for k, v in req.items():
|
||||
shared.opts.set(k, v)
|
||||
|
||||
shared.opts.save(shared.config_filename)
|
||||
return
|
||||
|
||||
def get_cmd_flags(self):
|
||||
return vars(shared.cmd_opts)
|
||||
|
||||
def get_samplers(self):
|
||||
return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
|
||||
|
||||
def get_upscalers(self):
|
||||
upscalers = []
|
||||
|
||||
for upscaler in shared.sd_upscalers:
|
||||
u = upscaler.scaler
|
||||
upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
|
||||
|
||||
return upscalers
|
||||
|
||||
def get_sd_models(self):
|
||||
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
|
||||
|
||||
def get_hypernetworks(self):
|
||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||
|
||||
def get_face_restorers(self):
|
||||
return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
|
||||
|
||||
def get_realesrgan_models(self):
|
||||
return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
|
||||
|
||||
def get_promp_styles(self):
|
||||
styleList = []
|
||||
for k in shared.prompt_styles.styles:
|
||||
style = shared.prompt_styles.styles[k]
|
||||
styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
|
||||
|
||||
return styleList
|
||||
|
||||
def get_artists_categories(self):
|
||||
return shared.artist_db.cats
|
||||
|
||||
def get_artists(self):
|
||||
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
|
||||
|
||||
def launch(self, server_name, port):
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(self.app, host=server_name, port=port)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
import inspect
|
||||
from click import prompt
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from typing import Any, Optional
|
||||
from typing_extensions import Literal
|
||||
from inflection import underscore
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
||||
from modules.shared import sd_upscalers
|
||||
from modules.shared import sd_upscalers, opts, parser
|
||||
from typing import Dict, List
|
||||
|
||||
API_NOT_ALLOWED = [
|
||||
"self",
|
||||
|
|
@ -65,6 +65,7 @@ class PydanticModelGenerator:
|
|||
|
||||
self._model_name = model_name
|
||||
self._class_data = merge_class_params(class_instance)
|
||||
|
||||
self._model_def = [
|
||||
ModelDef(
|
||||
field=underscore(k),
|
||||
|
|
@ -109,12 +110,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
|||
).generate_model()
|
||||
|
||||
class TextToImageResponse(BaseModel):
|
||||
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
class ImageToImageResponse(BaseModel):
|
||||
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
|
|
@ -147,10 +148,10 @@ class FileData(BaseModel):
|
|||
name: str = Field(title="File name")
|
||||
|
||||
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
|
||||
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
|
||||
imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
|
||||
|
||||
class ExtrasBatchImagesResponse(ExtraBaseResponse):
|
||||
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
|
||||
images: List[str] = Field(title="Images", description="The generated images in base64 format.")
|
||||
|
||||
class PNGInfoRequest(BaseModel):
|
||||
image: str = Field(title="Image", description="The base64 encoded PNG image")
|
||||
|
|
@ -166,3 +167,76 @@ class ProgressResponse(BaseModel):
|
|||
eta_relative: float = Field(title="ETA in secs")
|
||||
state: dict = Field(title="State", description="The current state snapshot")
|
||||
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
||||
|
||||
class InterrogateRequest(BaseModel):
|
||||
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
model: str = Field(default="clip", title="Model", description="The interrogate model used.")
|
||||
|
||||
class InterrogateResponse(BaseModel):
|
||||
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
|
||||
|
||||
fields = {}
|
||||
for key, metadata in opts.data_labels.items():
|
||||
value = opts.data.get(key)
|
||||
optType = opts.typemap.get(type(metadata.default), type(value))
|
||||
|
||||
if (metadata is not None):
|
||||
fields.update({key: (Optional[optType], Field(
|
||||
default=metadata.default ,description=metadata.label))})
|
||||
else:
|
||||
fields.update({key: (Optional[optType], Field())})
|
||||
|
||||
OptionsModel = create_model("Options", **fields)
|
||||
|
||||
flags = {}
|
||||
_options = vars(parser)['_option_string_actions']
|
||||
for key in _options:
|
||||
if(_options[key].dest != 'help'):
|
||||
flag = _options[key]
|
||||
_type = str
|
||||
if _options[key].default is not None: _type = type(_options[key].default)
|
||||
flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
|
||||
|
||||
FlagsModel = create_model("Flags", **flags)
|
||||
|
||||
class SamplerItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
aliases: List[str] = Field(title="Aliases")
|
||||
options: Dict[str, str] = Field(title="Options")
|
||||
|
||||
class UpscalerItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
model_name: Optional[str] = Field(title="Model Name")
|
||||
model_path: Optional[str] = Field(title="Path")
|
||||
model_url: Optional[str] = Field(title="URL")
|
||||
|
||||
class SDModelItem(BaseModel):
|
||||
title: str = Field(title="Title")
|
||||
model_name: str = Field(title="Model Name")
|
||||
hash: str = Field(title="Hash")
|
||||
filename: str = Field(title="Filename")
|
||||
config: str = Field(title="Config file")
|
||||
|
||||
class HypernetworkItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
path: Optional[str] = Field(title="Path")
|
||||
|
||||
class FaceRestorerItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
cmd_dir: Optional[str] = Field(title="Path")
|
||||
|
||||
class RealesrganItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
path: Optional[str] = Field(title="Path")
|
||||
scale: Optional[int] = Field(title="Scale")
|
||||
|
||||
class PromptStyleItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
prompt: Optional[str] = Field(title="Prompt")
|
||||
negative_prompt: Optional[str] = Field(title="Negative Prompt")
|
||||
|
||||
class ArtistItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
score: float = Field(title="Score")
|
||||
category: str = Field(title="Category")
|
||||
|
||||
|
|
|
|||
98
modules/call_queue.py
Normal file
98
modules/call_queue.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
import html
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
import time
|
||||
|
||||
from modules import shared
|
||||
|
||||
queue_lock = threading.Lock()
|
||||
|
||||
|
||||
def wrap_queued_call(func):
|
||||
def f(*args, **kwargs):
|
||||
with queue_lock:
|
||||
res = func(*args, **kwargs)
|
||||
|
||||
return res
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||
def f(*args, **kwargs):
|
||||
|
||||
shared.state.begin()
|
||||
|
||||
with queue_lock:
|
||||
res = func(*args, **kwargs)
|
||||
|
||||
shared.state.end()
|
||||
|
||||
return res
|
||||
|
||||
return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
|
||||
|
||||
|
||||
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
||||
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
||||
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
|
||||
if run_memmon:
|
||||
shared.mem_mon.monitor()
|
||||
t = time.perf_counter()
|
||||
|
||||
try:
|
||||
res = list(func(*args, **kwargs))
|
||||
except Exception as e:
|
||||
# When printing out our debug argument list, do not print out more than a MB of text
|
||||
max_debug_str_len = 131072 # (1024*1024)/8
|
||||
|
||||
print("Error completing request", file=sys.stderr)
|
||||
argStr = f"Arguments: {str(args)} {str(kwargs)}"
|
||||
print(argStr[:max_debug_str_len], file=sys.stderr)
|
||||
if len(argStr) > max_debug_str_len:
|
||||
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
|
||||
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
shared.state.job = ""
|
||||
shared.state.job_count = 0
|
||||
|
||||
if extra_outputs_array is None:
|
||||
extra_outputs_array = [None, '']
|
||||
|
||||
res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
|
||||
|
||||
shared.state.skipped = False
|
||||
shared.state.interrupted = False
|
||||
shared.state.job_count = 0
|
||||
|
||||
if not add_stats:
|
||||
return tuple(res)
|
||||
|
||||
elapsed = time.perf_counter() - t
|
||||
elapsed_m = int(elapsed // 60)
|
||||
elapsed_s = elapsed % 60
|
||||
elapsed_text = f"{elapsed_s:.2f}s"
|
||||
if elapsed_m > 0:
|
||||
elapsed_text = f"{elapsed_m}m "+elapsed_text
|
||||
|
||||
if run_memmon:
|
||||
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
|
||||
active_peak = mem_stats['active_peak']
|
||||
reserved_peak = mem_stats['reserved_peak']
|
||||
sys_peak = mem_stats['system_peak']
|
||||
sys_total = mem_stats['total']
|
||||
sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
|
||||
|
||||
vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
|
||||
else:
|
||||
vram_html = ''
|
||||
|
||||
# last item is always HTML
|
||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
|
||||
|
||||
return tuple(res)
|
||||
|
||||
return f
|
||||
|
||||
|
|
@ -36,6 +36,7 @@ def setup_model(dirname):
|
|||
from basicsr.utils.download_util import load_file_from_url
|
||||
from basicsr.utils import imwrite, img2tensor, tensor2img
|
||||
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from facelib.detection.retinaface import retinaface
|
||||
from modules.shared import cmd_opts
|
||||
|
||||
net_class = CodeFormer
|
||||
|
|
@ -65,6 +66,8 @@ def setup_model(dirname):
|
|||
net.load_state_dict(checkpoint)
|
||||
net.eval()
|
||||
|
||||
if hasattr(retinaface, 'device'):
|
||||
retinaface.device = devices.device_codeformer
|
||||
face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
|
||||
|
||||
self.net = net
|
||||
|
|
|
|||
|
|
@ -1,173 +1,97 @@
|
|||
import os.path
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
import multiprocessing
|
||||
import time
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from modules import modelloader, paths, deepbooru_model, devices, images, shared
|
||||
|
||||
re_special = re.compile(r'([\\()])')
|
||||
|
||||
def get_deepbooru_tags(pil_image):
|
||||
"""
|
||||
This method is for running only one image at a time for simple use. Used to the img2img interrogate.
|
||||
"""
|
||||
from modules import shared # prevents circular reference
|
||||
|
||||
try:
|
||||
create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
|
||||
return get_tags_from_process(pil_image)
|
||||
finally:
|
||||
release_process()
|
||||
class DeepDanbooru:
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
|
||||
def load(self):
|
||||
if self.model is not None:
|
||||
return
|
||||
|
||||
OPT_INCLUDE_RANKS = "include_ranks"
|
||||
def create_deepbooru_opts():
|
||||
from modules import shared
|
||||
files = modelloader.load_models(
|
||||
model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
|
||||
model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
|
||||
ext_filter=".pt",
|
||||
download_name='model-resnet_custom_v3.pt',
|
||||
)
|
||||
|
||||
return {
|
||||
"use_spaces": shared.opts.deepbooru_use_spaces,
|
||||
"use_escape": shared.opts.deepbooru_escape,
|
||||
"alpha_sort": shared.opts.deepbooru_sort_alpha,
|
||||
OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks,
|
||||
}
|
||||
self.model = deepbooru_model.DeepDanbooruModel()
|
||||
self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
|
||||
|
||||
self.model.eval()
|
||||
self.model.to(devices.cpu, devices.dtype)
|
||||
|
||||
def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
|
||||
model, tags = get_deepbooru_tags_model()
|
||||
while True: # while process is running, keep monitoring queue for new image
|
||||
pil_image = queue.get()
|
||||
if pil_image == "QUIT":
|
||||
break
|
||||
else:
|
||||
deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
|
||||
def start(self):
|
||||
self.load()
|
||||
self.model.to(devices.device)
|
||||
|
||||
def stop(self):
|
||||
if not shared.opts.interrogate_keep_models_in_memory:
|
||||
self.model.to(devices.cpu)
|
||||
devices.torch_gc()
|
||||
|
||||
def create_deepbooru_process(threshold, deepbooru_opts):
|
||||
"""
|
||||
Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
|
||||
to be processed in a row without reloading the model or creating a new process. To return the data, a shared
|
||||
dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned
|
||||
to the dictionary and the method adding the image to the queue should wait for this value to be updated with
|
||||
the tags.
|
||||
"""
|
||||
from modules import shared # prevents circular reference
|
||||
context = multiprocessing.get_context("spawn")
|
||||
shared.deepbooru_process_manager = context.Manager()
|
||||
shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
|
||||
shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
|
||||
shared.deepbooru_process_return["value"] = -1
|
||||
shared.deepbooru_process = context.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
|
||||
shared.deepbooru_process.start()
|
||||
def tag(self, pil_image):
|
||||
self.start()
|
||||
res = self.tag_multi(pil_image)
|
||||
self.stop()
|
||||
|
||||
return res
|
||||
|
||||
def get_tags_from_process(image):
|
||||
from modules import shared
|
||||
def tag_multi(self, pil_image, force_disable_ranks=False):
|
||||
threshold = shared.opts.interrogate_deepbooru_score_threshold
|
||||
use_spaces = shared.opts.deepbooru_use_spaces
|
||||
use_escape = shared.opts.deepbooru_escape
|
||||
alpha_sort = shared.opts.deepbooru_sort_alpha
|
||||
include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
|
||||
|
||||
shared.deepbooru_process_return["value"] = -1
|
||||
shared.deepbooru_process_queue.put(image)
|
||||
while shared.deepbooru_process_return["value"] == -1:
|
||||
time.sleep(0.2)
|
||||
caption = shared.deepbooru_process_return["value"]
|
||||
shared.deepbooru_process_return["value"] = -1
|
||||
pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
|
||||
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
|
||||
|
||||
return caption
|
||||
with torch.no_grad(), devices.autocast():
|
||||
x = torch.from_numpy(a).to(devices.device)
|
||||
y = self.model(x)[0].detach().cpu().numpy()
|
||||
|
||||
probability_dict = {}
|
||||
|
||||
def release_process():
|
||||
"""
|
||||
Stops the deepbooru process to return used memory
|
||||
"""
|
||||
from modules import shared # prevents circular reference
|
||||
shared.deepbooru_process_queue.put("QUIT")
|
||||
shared.deepbooru_process.join()
|
||||
shared.deepbooru_process_queue = None
|
||||
shared.deepbooru_process = None
|
||||
shared.deepbooru_process_return = None
|
||||
shared.deepbooru_process_manager = None
|
||||
for tag, probability in zip(self.model.tags, y):
|
||||
if probability < threshold:
|
||||
continue
|
||||
|
||||
def get_deepbooru_tags_model():
|
||||
import deepdanbooru as dd
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
this_folder = os.path.dirname(__file__)
|
||||
model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
|
||||
if not os.path.exists(os.path.join(model_path, 'project.json')):
|
||||
# there is no point importing these every time
|
||||
import zipfile
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
load_file_from_url(
|
||||
r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
|
||||
model_path)
|
||||
with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
|
||||
zip_ref.extractall(model_path)
|
||||
os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
|
||||
|
||||
tags = dd.project.load_tags_from_project(model_path)
|
||||
model = dd.project.load_model_from_project(
|
||||
model_path, compile_model=False
|
||||
)
|
||||
return model, tags
|
||||
|
||||
|
||||
def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
|
||||
import deepdanbooru as dd
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
alpha_sort = deepbooru_opts['alpha_sort']
|
||||
use_spaces = deepbooru_opts['use_spaces']
|
||||
use_escape = deepbooru_opts['use_escape']
|
||||
include_ranks = deepbooru_opts['include_ranks']
|
||||
|
||||
width = model.input_shape[2]
|
||||
height = model.input_shape[1]
|
||||
image = np.array(pil_image)
|
||||
image = tf.image.resize(
|
||||
image,
|
||||
size=(height, width),
|
||||
method=tf.image.ResizeMethod.AREA,
|
||||
preserve_aspect_ratio=True,
|
||||
)
|
||||
image = image.numpy() # EagerTensor to np.array
|
||||
image = dd.image.transform_and_pad_image(image, width, height)
|
||||
image = image / 255.0
|
||||
image_shape = image.shape
|
||||
image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
|
||||
|
||||
y = model.predict(image)[0]
|
||||
|
||||
result_dict = {}
|
||||
|
||||
for i, tag in enumerate(tags):
|
||||
result_dict[tag] = y[i]
|
||||
|
||||
unsorted_tags_in_theshold = []
|
||||
result_tags_print = []
|
||||
for tag in tags:
|
||||
if result_dict[tag] >= threshold:
|
||||
if tag.startswith("rating:"):
|
||||
continue
|
||||
unsorted_tags_in_theshold.append((result_dict[tag], tag))
|
||||
result_tags_print.append(f'{result_dict[tag]} {tag}')
|
||||
|
||||
# sort tags
|
||||
result_tags_out = []
|
||||
sort_ndx = 0
|
||||
if alpha_sort:
|
||||
sort_ndx = 1
|
||||
probability_dict[tag] = probability
|
||||
|
||||
# sort by reverse by likelihood and normal for alpha, and format tag text as requested
|
||||
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
|
||||
for weight, tag in unsorted_tags_in_theshold:
|
||||
tag_outformat = tag
|
||||
if use_spaces:
|
||||
tag_outformat = tag_outformat.replace('_', ' ')
|
||||
if use_escape:
|
||||
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
|
||||
if include_ranks:
|
||||
tag_outformat = f"({tag_outformat}:{weight:.3f})"
|
||||
if alpha_sort:
|
||||
tags = sorted(probability_dict)
|
||||
else:
|
||||
tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
|
||||
|
||||
result_tags_out.append(tag_outformat)
|
||||
res = []
|
||||
|
||||
print('\n'.join(sorted(result_tags_print, reverse=True)))
|
||||
for tag in tags:
|
||||
probability = probability_dict[tag]
|
||||
tag_outformat = tag
|
||||
if use_spaces:
|
||||
tag_outformat = tag_outformat.replace('_', ' ')
|
||||
if use_escape:
|
||||
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
|
||||
if include_ranks:
|
||||
tag_outformat = f"({tag_outformat}:{probability:.3f})"
|
||||
|
||||
return ', '.join(result_tags_out)
|
||||
res.append(tag_outformat)
|
||||
|
||||
return ", ".join(res)
|
||||
|
||||
|
||||
model = DeepDanbooru()
|
||||
|
|
|
|||
676
modules/deepbooru_model.py
Normal file
676
modules/deepbooru_model.py
Normal file
|
|
@ -0,0 +1,676 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
|
||||
|
||||
|
||||
class DeepDanbooruModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(DeepDanbooruModel, self).__init__()
|
||||
|
||||
self.tags = []
|
||||
|
||||
self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
|
||||
self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
|
||||
self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
|
||||
self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
||||
self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
|
||||
self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
||||
self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
|
||||
self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
||||
self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
|
||||
self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
|
||||
self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
|
||||
self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
|
||||
self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
|
||||
self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
|
||||
self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
|
||||
self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
|
||||
self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
|
||||
self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
|
||||
self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
|
||||
self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
||||
self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
||||
self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
||||
self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
||||
self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
||||
self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
||||
self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
||||
self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
|
||||
self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
|
||||
self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
|
||||
self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
||||
self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
||||
self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
||||
self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
||||
self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
||||
self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
||||
self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
||||
self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
|
||||
|
||||
def forward(self, *inputs):
|
||||
t_358, = inputs
|
||||
t_359 = t_358.permute(*[0, 3, 1, 2])
|
||||
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
|
||||
t_360 = self.n_Conv_0(t_359_padded)
|
||||
t_361 = F.relu(t_360)
|
||||
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
|
||||
t_362 = self.n_MaxPool_0(t_361)
|
||||
t_363 = self.n_Conv_1(t_362)
|
||||
t_364 = self.n_Conv_2(t_362)
|
||||
t_365 = F.relu(t_364)
|
||||
t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
|
||||
t_366 = self.n_Conv_3(t_365_padded)
|
||||
t_367 = F.relu(t_366)
|
||||
t_368 = self.n_Conv_4(t_367)
|
||||
t_369 = torch.add(t_368, t_363)
|
||||
t_370 = F.relu(t_369)
|
||||
t_371 = self.n_Conv_5(t_370)
|
||||
t_372 = F.relu(t_371)
|
||||
t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
|
||||
t_373 = self.n_Conv_6(t_372_padded)
|
||||
t_374 = F.relu(t_373)
|
||||
t_375 = self.n_Conv_7(t_374)
|
||||
t_376 = torch.add(t_375, t_370)
|
||||
t_377 = F.relu(t_376)
|
||||
t_378 = self.n_Conv_8(t_377)
|
||||
t_379 = F.relu(t_378)
|
||||
t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
|
||||
t_380 = self.n_Conv_9(t_379_padded)
|
||||
t_381 = F.relu(t_380)
|
||||
t_382 = self.n_Conv_10(t_381)
|
||||
t_383 = torch.add(t_382, t_377)
|
||||
t_384 = F.relu(t_383)
|
||||
t_385 = self.n_Conv_11(t_384)
|
||||
t_386 = self.n_Conv_12(t_384)
|
||||
t_387 = F.relu(t_386)
|
||||
t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
|
||||
t_388 = self.n_Conv_13(t_387_padded)
|
||||
t_389 = F.relu(t_388)
|
||||
t_390 = self.n_Conv_14(t_389)
|
||||
t_391 = torch.add(t_390, t_385)
|
||||
t_392 = F.relu(t_391)
|
||||
t_393 = self.n_Conv_15(t_392)
|
||||
t_394 = F.relu(t_393)
|
||||
t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
|
||||
t_395 = self.n_Conv_16(t_394_padded)
|
||||
t_396 = F.relu(t_395)
|
||||
t_397 = self.n_Conv_17(t_396)
|
||||
t_398 = torch.add(t_397, t_392)
|
||||
t_399 = F.relu(t_398)
|
||||
t_400 = self.n_Conv_18(t_399)
|
||||
t_401 = F.relu(t_400)
|
||||
t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
|
||||
t_402 = self.n_Conv_19(t_401_padded)
|
||||
t_403 = F.relu(t_402)
|
||||
t_404 = self.n_Conv_20(t_403)
|
||||
t_405 = torch.add(t_404, t_399)
|
||||
t_406 = F.relu(t_405)
|
||||
t_407 = self.n_Conv_21(t_406)
|
||||
t_408 = F.relu(t_407)
|
||||
t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
|
||||
t_409 = self.n_Conv_22(t_408_padded)
|
||||
t_410 = F.relu(t_409)
|
||||
t_411 = self.n_Conv_23(t_410)
|
||||
t_412 = torch.add(t_411, t_406)
|
||||
t_413 = F.relu(t_412)
|
||||
t_414 = self.n_Conv_24(t_413)
|
||||
t_415 = F.relu(t_414)
|
||||
t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
|
||||
t_416 = self.n_Conv_25(t_415_padded)
|
||||
t_417 = F.relu(t_416)
|
||||
t_418 = self.n_Conv_26(t_417)
|
||||
t_419 = torch.add(t_418, t_413)
|
||||
t_420 = F.relu(t_419)
|
||||
t_421 = self.n_Conv_27(t_420)
|
||||
t_422 = F.relu(t_421)
|
||||
t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
|
||||
t_423 = self.n_Conv_28(t_422_padded)
|
||||
t_424 = F.relu(t_423)
|
||||
t_425 = self.n_Conv_29(t_424)
|
||||
t_426 = torch.add(t_425, t_420)
|
||||
t_427 = F.relu(t_426)
|
||||
t_428 = self.n_Conv_30(t_427)
|
||||
t_429 = F.relu(t_428)
|
||||
t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
|
||||
t_430 = self.n_Conv_31(t_429_padded)
|
||||
t_431 = F.relu(t_430)
|
||||
t_432 = self.n_Conv_32(t_431)
|
||||
t_433 = torch.add(t_432, t_427)
|
||||
t_434 = F.relu(t_433)
|
||||
t_435 = self.n_Conv_33(t_434)
|
||||
t_436 = F.relu(t_435)
|
||||
t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
|
||||
t_437 = self.n_Conv_34(t_436_padded)
|
||||
t_438 = F.relu(t_437)
|
||||
t_439 = self.n_Conv_35(t_438)
|
||||
t_440 = torch.add(t_439, t_434)
|
||||
t_441 = F.relu(t_440)
|
||||
t_442 = self.n_Conv_36(t_441)
|
||||
t_443 = self.n_Conv_37(t_441)
|
||||
t_444 = F.relu(t_443)
|
||||
t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
|
||||
t_445 = self.n_Conv_38(t_444_padded)
|
||||
t_446 = F.relu(t_445)
|
||||
t_447 = self.n_Conv_39(t_446)
|
||||
t_448 = torch.add(t_447, t_442)
|
||||
t_449 = F.relu(t_448)
|
||||
t_450 = self.n_Conv_40(t_449)
|
||||
t_451 = F.relu(t_450)
|
||||
t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
|
||||
t_452 = self.n_Conv_41(t_451_padded)
|
||||
t_453 = F.relu(t_452)
|
||||
t_454 = self.n_Conv_42(t_453)
|
||||
t_455 = torch.add(t_454, t_449)
|
||||
t_456 = F.relu(t_455)
|
||||
t_457 = self.n_Conv_43(t_456)
|
||||
t_458 = F.relu(t_457)
|
||||
t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
|
||||
t_459 = self.n_Conv_44(t_458_padded)
|
||||
t_460 = F.relu(t_459)
|
||||
t_461 = self.n_Conv_45(t_460)
|
||||
t_462 = torch.add(t_461, t_456)
|
||||
t_463 = F.relu(t_462)
|
||||
t_464 = self.n_Conv_46(t_463)
|
||||
t_465 = F.relu(t_464)
|
||||
t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
|
||||
t_466 = self.n_Conv_47(t_465_padded)
|
||||
t_467 = F.relu(t_466)
|
||||
t_468 = self.n_Conv_48(t_467)
|
||||
t_469 = torch.add(t_468, t_463)
|
||||
t_470 = F.relu(t_469)
|
||||
t_471 = self.n_Conv_49(t_470)
|
||||
t_472 = F.relu(t_471)
|
||||
t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
|
||||
t_473 = self.n_Conv_50(t_472_padded)
|
||||
t_474 = F.relu(t_473)
|
||||
t_475 = self.n_Conv_51(t_474)
|
||||
t_476 = torch.add(t_475, t_470)
|
||||
t_477 = F.relu(t_476)
|
||||
t_478 = self.n_Conv_52(t_477)
|
||||
t_479 = F.relu(t_478)
|
||||
t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
|
||||
t_480 = self.n_Conv_53(t_479_padded)
|
||||
t_481 = F.relu(t_480)
|
||||
t_482 = self.n_Conv_54(t_481)
|
||||
t_483 = torch.add(t_482, t_477)
|
||||
t_484 = F.relu(t_483)
|
||||
t_485 = self.n_Conv_55(t_484)
|
||||
t_486 = F.relu(t_485)
|
||||
t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
|
||||
t_487 = self.n_Conv_56(t_486_padded)
|
||||
t_488 = F.relu(t_487)
|
||||
t_489 = self.n_Conv_57(t_488)
|
||||
t_490 = torch.add(t_489, t_484)
|
||||
t_491 = F.relu(t_490)
|
||||
t_492 = self.n_Conv_58(t_491)
|
||||
t_493 = F.relu(t_492)
|
||||
t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
|
||||
t_494 = self.n_Conv_59(t_493_padded)
|
||||
t_495 = F.relu(t_494)
|
||||
t_496 = self.n_Conv_60(t_495)
|
||||
t_497 = torch.add(t_496, t_491)
|
||||
t_498 = F.relu(t_497)
|
||||
t_499 = self.n_Conv_61(t_498)
|
||||
t_500 = F.relu(t_499)
|
||||
t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
|
||||
t_501 = self.n_Conv_62(t_500_padded)
|
||||
t_502 = F.relu(t_501)
|
||||
t_503 = self.n_Conv_63(t_502)
|
||||
t_504 = torch.add(t_503, t_498)
|
||||
t_505 = F.relu(t_504)
|
||||
t_506 = self.n_Conv_64(t_505)
|
||||
t_507 = F.relu(t_506)
|
||||
t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
|
||||
t_508 = self.n_Conv_65(t_507_padded)
|
||||
t_509 = F.relu(t_508)
|
||||
t_510 = self.n_Conv_66(t_509)
|
||||
t_511 = torch.add(t_510, t_505)
|
||||
t_512 = F.relu(t_511)
|
||||
t_513 = self.n_Conv_67(t_512)
|
||||
t_514 = F.relu(t_513)
|
||||
t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
|
||||
t_515 = self.n_Conv_68(t_514_padded)
|
||||
t_516 = F.relu(t_515)
|
||||
t_517 = self.n_Conv_69(t_516)
|
||||
t_518 = torch.add(t_517, t_512)
|
||||
t_519 = F.relu(t_518)
|
||||
t_520 = self.n_Conv_70(t_519)
|
||||
t_521 = F.relu(t_520)
|
||||
t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
|
||||
t_522 = self.n_Conv_71(t_521_padded)
|
||||
t_523 = F.relu(t_522)
|
||||
t_524 = self.n_Conv_72(t_523)
|
||||
t_525 = torch.add(t_524, t_519)
|
||||
t_526 = F.relu(t_525)
|
||||
t_527 = self.n_Conv_73(t_526)
|
||||
t_528 = F.relu(t_527)
|
||||
t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
|
||||
t_529 = self.n_Conv_74(t_528_padded)
|
||||
t_530 = F.relu(t_529)
|
||||
t_531 = self.n_Conv_75(t_530)
|
||||
t_532 = torch.add(t_531, t_526)
|
||||
t_533 = F.relu(t_532)
|
||||
t_534 = self.n_Conv_76(t_533)
|
||||
t_535 = F.relu(t_534)
|
||||
t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
|
||||
t_536 = self.n_Conv_77(t_535_padded)
|
||||
t_537 = F.relu(t_536)
|
||||
t_538 = self.n_Conv_78(t_537)
|
||||
t_539 = torch.add(t_538, t_533)
|
||||
t_540 = F.relu(t_539)
|
||||
t_541 = self.n_Conv_79(t_540)
|
||||
t_542 = F.relu(t_541)
|
||||
t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
|
||||
t_543 = self.n_Conv_80(t_542_padded)
|
||||
t_544 = F.relu(t_543)
|
||||
t_545 = self.n_Conv_81(t_544)
|
||||
t_546 = torch.add(t_545, t_540)
|
||||
t_547 = F.relu(t_546)
|
||||
t_548 = self.n_Conv_82(t_547)
|
||||
t_549 = F.relu(t_548)
|
||||
t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
|
||||
t_550 = self.n_Conv_83(t_549_padded)
|
||||
t_551 = F.relu(t_550)
|
||||
t_552 = self.n_Conv_84(t_551)
|
||||
t_553 = torch.add(t_552, t_547)
|
||||
t_554 = F.relu(t_553)
|
||||
t_555 = self.n_Conv_85(t_554)
|
||||
t_556 = F.relu(t_555)
|
||||
t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
|
||||
t_557 = self.n_Conv_86(t_556_padded)
|
||||
t_558 = F.relu(t_557)
|
||||
t_559 = self.n_Conv_87(t_558)
|
||||
t_560 = torch.add(t_559, t_554)
|
||||
t_561 = F.relu(t_560)
|
||||
t_562 = self.n_Conv_88(t_561)
|
||||
t_563 = F.relu(t_562)
|
||||
t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
|
||||
t_564 = self.n_Conv_89(t_563_padded)
|
||||
t_565 = F.relu(t_564)
|
||||
t_566 = self.n_Conv_90(t_565)
|
||||
t_567 = torch.add(t_566, t_561)
|
||||
t_568 = F.relu(t_567)
|
||||
t_569 = self.n_Conv_91(t_568)
|
||||
t_570 = F.relu(t_569)
|
||||
t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
|
||||
t_571 = self.n_Conv_92(t_570_padded)
|
||||
t_572 = F.relu(t_571)
|
||||
t_573 = self.n_Conv_93(t_572)
|
||||
t_574 = torch.add(t_573, t_568)
|
||||
t_575 = F.relu(t_574)
|
||||
t_576 = self.n_Conv_94(t_575)
|
||||
t_577 = F.relu(t_576)
|
||||
t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
|
||||
t_578 = self.n_Conv_95(t_577_padded)
|
||||
t_579 = F.relu(t_578)
|
||||
t_580 = self.n_Conv_96(t_579)
|
||||
t_581 = torch.add(t_580, t_575)
|
||||
t_582 = F.relu(t_581)
|
||||
t_583 = self.n_Conv_97(t_582)
|
||||
t_584 = F.relu(t_583)
|
||||
t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
|
||||
t_585 = self.n_Conv_98(t_584_padded)
|
||||
t_586 = F.relu(t_585)
|
||||
t_587 = self.n_Conv_99(t_586)
|
||||
t_588 = self.n_Conv_100(t_582)
|
||||
t_589 = torch.add(t_587, t_588)
|
||||
t_590 = F.relu(t_589)
|
||||
t_591 = self.n_Conv_101(t_590)
|
||||
t_592 = F.relu(t_591)
|
||||
t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
|
||||
t_593 = self.n_Conv_102(t_592_padded)
|
||||
t_594 = F.relu(t_593)
|
||||
t_595 = self.n_Conv_103(t_594)
|
||||
t_596 = torch.add(t_595, t_590)
|
||||
t_597 = F.relu(t_596)
|
||||
t_598 = self.n_Conv_104(t_597)
|
||||
t_599 = F.relu(t_598)
|
||||
t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
|
||||
t_600 = self.n_Conv_105(t_599_padded)
|
||||
t_601 = F.relu(t_600)
|
||||
t_602 = self.n_Conv_106(t_601)
|
||||
t_603 = torch.add(t_602, t_597)
|
||||
t_604 = F.relu(t_603)
|
||||
t_605 = self.n_Conv_107(t_604)
|
||||
t_606 = F.relu(t_605)
|
||||
t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
|
||||
t_607 = self.n_Conv_108(t_606_padded)
|
||||
t_608 = F.relu(t_607)
|
||||
t_609 = self.n_Conv_109(t_608)
|
||||
t_610 = torch.add(t_609, t_604)
|
||||
t_611 = F.relu(t_610)
|
||||
t_612 = self.n_Conv_110(t_611)
|
||||
t_613 = F.relu(t_612)
|
||||
t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
|
||||
t_614 = self.n_Conv_111(t_613_padded)
|
||||
t_615 = F.relu(t_614)
|
||||
t_616 = self.n_Conv_112(t_615)
|
||||
t_617 = torch.add(t_616, t_611)
|
||||
t_618 = F.relu(t_617)
|
||||
t_619 = self.n_Conv_113(t_618)
|
||||
t_620 = F.relu(t_619)
|
||||
t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
|
||||
t_621 = self.n_Conv_114(t_620_padded)
|
||||
t_622 = F.relu(t_621)
|
||||
t_623 = self.n_Conv_115(t_622)
|
||||
t_624 = torch.add(t_623, t_618)
|
||||
t_625 = F.relu(t_624)
|
||||
t_626 = self.n_Conv_116(t_625)
|
||||
t_627 = F.relu(t_626)
|
||||
t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
|
||||
t_628 = self.n_Conv_117(t_627_padded)
|
||||
t_629 = F.relu(t_628)
|
||||
t_630 = self.n_Conv_118(t_629)
|
||||
t_631 = torch.add(t_630, t_625)
|
||||
t_632 = F.relu(t_631)
|
||||
t_633 = self.n_Conv_119(t_632)
|
||||
t_634 = F.relu(t_633)
|
||||
t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
|
||||
t_635 = self.n_Conv_120(t_634_padded)
|
||||
t_636 = F.relu(t_635)
|
||||
t_637 = self.n_Conv_121(t_636)
|
||||
t_638 = torch.add(t_637, t_632)
|
||||
t_639 = F.relu(t_638)
|
||||
t_640 = self.n_Conv_122(t_639)
|
||||
t_641 = F.relu(t_640)
|
||||
t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
|
||||
t_642 = self.n_Conv_123(t_641_padded)
|
||||
t_643 = F.relu(t_642)
|
||||
t_644 = self.n_Conv_124(t_643)
|
||||
t_645 = torch.add(t_644, t_639)
|
||||
t_646 = F.relu(t_645)
|
||||
t_647 = self.n_Conv_125(t_646)
|
||||
t_648 = F.relu(t_647)
|
||||
t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
|
||||
t_649 = self.n_Conv_126(t_648_padded)
|
||||
t_650 = F.relu(t_649)
|
||||
t_651 = self.n_Conv_127(t_650)
|
||||
t_652 = torch.add(t_651, t_646)
|
||||
t_653 = F.relu(t_652)
|
||||
t_654 = self.n_Conv_128(t_653)
|
||||
t_655 = F.relu(t_654)
|
||||
t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
|
||||
t_656 = self.n_Conv_129(t_655_padded)
|
||||
t_657 = F.relu(t_656)
|
||||
t_658 = self.n_Conv_130(t_657)
|
||||
t_659 = torch.add(t_658, t_653)
|
||||
t_660 = F.relu(t_659)
|
||||
t_661 = self.n_Conv_131(t_660)
|
||||
t_662 = F.relu(t_661)
|
||||
t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
|
||||
t_663 = self.n_Conv_132(t_662_padded)
|
||||
t_664 = F.relu(t_663)
|
||||
t_665 = self.n_Conv_133(t_664)
|
||||
t_666 = torch.add(t_665, t_660)
|
||||
t_667 = F.relu(t_666)
|
||||
t_668 = self.n_Conv_134(t_667)
|
||||
t_669 = F.relu(t_668)
|
||||
t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
|
||||
t_670 = self.n_Conv_135(t_669_padded)
|
||||
t_671 = F.relu(t_670)
|
||||
t_672 = self.n_Conv_136(t_671)
|
||||
t_673 = torch.add(t_672, t_667)
|
||||
t_674 = F.relu(t_673)
|
||||
t_675 = self.n_Conv_137(t_674)
|
||||
t_676 = F.relu(t_675)
|
||||
t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
|
||||
t_677 = self.n_Conv_138(t_676_padded)
|
||||
t_678 = F.relu(t_677)
|
||||
t_679 = self.n_Conv_139(t_678)
|
||||
t_680 = torch.add(t_679, t_674)
|
||||
t_681 = F.relu(t_680)
|
||||
t_682 = self.n_Conv_140(t_681)
|
||||
t_683 = F.relu(t_682)
|
||||
t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
|
||||
t_684 = self.n_Conv_141(t_683_padded)
|
||||
t_685 = F.relu(t_684)
|
||||
t_686 = self.n_Conv_142(t_685)
|
||||
t_687 = torch.add(t_686, t_681)
|
||||
t_688 = F.relu(t_687)
|
||||
t_689 = self.n_Conv_143(t_688)
|
||||
t_690 = F.relu(t_689)
|
||||
t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
|
||||
t_691 = self.n_Conv_144(t_690_padded)
|
||||
t_692 = F.relu(t_691)
|
||||
t_693 = self.n_Conv_145(t_692)
|
||||
t_694 = torch.add(t_693, t_688)
|
||||
t_695 = F.relu(t_694)
|
||||
t_696 = self.n_Conv_146(t_695)
|
||||
t_697 = F.relu(t_696)
|
||||
t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
|
||||
t_698 = self.n_Conv_147(t_697_padded)
|
||||
t_699 = F.relu(t_698)
|
||||
t_700 = self.n_Conv_148(t_699)
|
||||
t_701 = torch.add(t_700, t_695)
|
||||
t_702 = F.relu(t_701)
|
||||
t_703 = self.n_Conv_149(t_702)
|
||||
t_704 = F.relu(t_703)
|
||||
t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
|
||||
t_705 = self.n_Conv_150(t_704_padded)
|
||||
t_706 = F.relu(t_705)
|
||||
t_707 = self.n_Conv_151(t_706)
|
||||
t_708 = torch.add(t_707, t_702)
|
||||
t_709 = F.relu(t_708)
|
||||
t_710 = self.n_Conv_152(t_709)
|
||||
t_711 = F.relu(t_710)
|
||||
t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
|
||||
t_712 = self.n_Conv_153(t_711_padded)
|
||||
t_713 = F.relu(t_712)
|
||||
t_714 = self.n_Conv_154(t_713)
|
||||
t_715 = torch.add(t_714, t_709)
|
||||
t_716 = F.relu(t_715)
|
||||
t_717 = self.n_Conv_155(t_716)
|
||||
t_718 = F.relu(t_717)
|
||||
t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
|
||||
t_719 = self.n_Conv_156(t_718_padded)
|
||||
t_720 = F.relu(t_719)
|
||||
t_721 = self.n_Conv_157(t_720)
|
||||
t_722 = torch.add(t_721, t_716)
|
||||
t_723 = F.relu(t_722)
|
||||
t_724 = self.n_Conv_158(t_723)
|
||||
t_725 = self.n_Conv_159(t_723)
|
||||
t_726 = F.relu(t_725)
|
||||
t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
|
||||
t_727 = self.n_Conv_160(t_726_padded)
|
||||
t_728 = F.relu(t_727)
|
||||
t_729 = self.n_Conv_161(t_728)
|
||||
t_730 = torch.add(t_729, t_724)
|
||||
t_731 = F.relu(t_730)
|
||||
t_732 = self.n_Conv_162(t_731)
|
||||
t_733 = F.relu(t_732)
|
||||
t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
|
||||
t_734 = self.n_Conv_163(t_733_padded)
|
||||
t_735 = F.relu(t_734)
|
||||
t_736 = self.n_Conv_164(t_735)
|
||||
t_737 = torch.add(t_736, t_731)
|
||||
t_738 = F.relu(t_737)
|
||||
t_739 = self.n_Conv_165(t_738)
|
||||
t_740 = F.relu(t_739)
|
||||
t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
|
||||
t_741 = self.n_Conv_166(t_740_padded)
|
||||
t_742 = F.relu(t_741)
|
||||
t_743 = self.n_Conv_167(t_742)
|
||||
t_744 = torch.add(t_743, t_738)
|
||||
t_745 = F.relu(t_744)
|
||||
t_746 = self.n_Conv_168(t_745)
|
||||
t_747 = self.n_Conv_169(t_745)
|
||||
t_748 = F.relu(t_747)
|
||||
t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
|
||||
t_749 = self.n_Conv_170(t_748_padded)
|
||||
t_750 = F.relu(t_749)
|
||||
t_751 = self.n_Conv_171(t_750)
|
||||
t_752 = torch.add(t_751, t_746)
|
||||
t_753 = F.relu(t_752)
|
||||
t_754 = self.n_Conv_172(t_753)
|
||||
t_755 = F.relu(t_754)
|
||||
t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
|
||||
t_756 = self.n_Conv_173(t_755_padded)
|
||||
t_757 = F.relu(t_756)
|
||||
t_758 = self.n_Conv_174(t_757)
|
||||
t_759 = torch.add(t_758, t_753)
|
||||
t_760 = F.relu(t_759)
|
||||
t_761 = self.n_Conv_175(t_760)
|
||||
t_762 = F.relu(t_761)
|
||||
t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
|
||||
t_763 = self.n_Conv_176(t_762_padded)
|
||||
t_764 = F.relu(t_763)
|
||||
t_765 = self.n_Conv_177(t_764)
|
||||
t_766 = torch.add(t_765, t_760)
|
||||
t_767 = F.relu(t_766)
|
||||
t_768 = self.n_Conv_178(t_767)
|
||||
t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
|
||||
t_770 = torch.squeeze(t_769, 3)
|
||||
t_770 = torch.squeeze(t_770, 2)
|
||||
t_771 = torch.sigmoid(t_770)
|
||||
return t_771
|
||||
|
||||
def load_state_dict(self, state_dict, **kwargs):
|
||||
self.tags = state_dict.get('tags', [])
|
||||
|
||||
super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
|
||||
|
||||
|
|
@ -2,30 +2,43 @@ import sys, os, shlex
|
|||
import contextlib
|
||||
import torch
|
||||
from modules import errors
|
||||
from packaging import version
|
||||
|
||||
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
|
||||
has_mps = getattr(torch, 'has_mps', False)
|
||||
|
||||
cpu = torch.device("cpu")
|
||||
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
|
||||
# check `getattr` and try it for compatibility
|
||||
def has_mps() -> bool:
|
||||
if not getattr(torch, 'has_mps', False):
|
||||
return False
|
||||
try:
|
||||
torch.zeros(1).to(torch.device("mps"))
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def extract_device_id(args, name):
|
||||
for x in range(len(args)):
|
||||
if name in args[x]: return args[x+1]
|
||||
if name in args[x]:
|
||||
return args[x + 1]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_cuda_device_string():
|
||||
from modules import shared
|
||||
|
||||
if shared.cmd_opts.device_id is not None:
|
||||
return f"cuda:{shared.cmd_opts.device_id}"
|
||||
|
||||
return "cuda"
|
||||
|
||||
|
||||
def get_optimal_device():
|
||||
if torch.cuda.is_available():
|
||||
from modules import shared
|
||||
return torch.device(get_cuda_device_string())
|
||||
|
||||
device_id = shared.cmd_opts.device_id
|
||||
|
||||
if device_id is not None:
|
||||
cuda_device = f"cuda:{device_id}"
|
||||
return torch.device(cuda_device)
|
||||
else:
|
||||
return torch.device("cuda")
|
||||
|
||||
if has_mps:
|
||||
if has_mps():
|
||||
return torch.device("mps")
|
||||
|
||||
return cpu
|
||||
|
|
@ -33,8 +46,9 @@ def get_optimal_device():
|
|||
|
||||
def torch_gc():
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
with torch.cuda.device(get_cuda_device_string()):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
def enable_tf32():
|
||||
|
|
@ -45,29 +59,22 @@ def enable_tf32():
|
|||
|
||||
errors.run(enable_tf32, "Enabling TF32")
|
||||
|
||||
cpu = torch.device("cpu")
|
||||
device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
|
||||
dtype = torch.float16
|
||||
dtype_vae = torch.float16
|
||||
|
||||
def randn(seed, shape):
|
||||
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
||||
if device.type == 'mps':
|
||||
generator = torch.Generator(device=cpu)
|
||||
generator.manual_seed(seed)
|
||||
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
|
||||
return noise
|
||||
|
||||
def randn(seed, shape):
|
||||
torch.manual_seed(seed)
|
||||
if device.type == 'mps':
|
||||
return torch.randn(shape, device=cpu).to(device)
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
def randn_without_seed(shape):
|
||||
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
||||
if device.type == 'mps':
|
||||
generator = torch.Generator(device=cpu)
|
||||
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
|
||||
return noise
|
||||
|
||||
return torch.randn(shape, device=cpu).to(device)
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
|
|
@ -82,6 +89,27 @@ def autocast(disable=False):
|
|||
|
||||
return torch.autocast("cuda")
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||
def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
|
||||
def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)
|
||||
orig_tensor_to = torch.Tensor.to
|
||||
def tensor_to_fix(self, *args, **kwargs):
|
||||
if self.device.type != 'mps' and \
|
||||
((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
|
||||
(isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
|
||||
self = self.contiguous()
|
||||
return orig_tensor_to(self, *args, **kwargs)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
||||
orig_layer_norm = torch.nn.functional.layer_norm
|
||||
def layer_norm_fix(*args, **kwargs):
|
||||
if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
|
||||
args = list(args)
|
||||
args[0] = args[0].contiguous()
|
||||
return orig_layer_norm(*args, **kwargs)
|
||||
|
||||
|
||||
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
||||
if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
|
||||
torch.Tensor.to = tensor_to_fix
|
||||
torch.nn.functional.layer_norm = layer_norm_fix
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ def upscale_without_tiling(model, img):
|
|||
img = img[:, :, ::-1]
|
||||
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_esrgan)
|
||||
img = img.unsqueeze(0).to(devices.device_esrgan)
|
||||
with torch.no_grad():
|
||||
output = model(img)
|
||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ import git
|
|||
|
||||
from modules import paths, shared
|
||||
|
||||
|
||||
extensions = []
|
||||
extensions_dir = os.path.join(paths.script_path, "extensions")
|
||||
|
||||
|
|
@ -34,8 +33,11 @@ class Extension:
|
|||
if repo is None or repo.bare:
|
||||
self.remote = None
|
||||
else:
|
||||
self.remote = next(repo.remote().urls, None)
|
||||
self.status = 'unknown'
|
||||
try:
|
||||
self.remote = next(repo.remote().urls, None)
|
||||
self.status = 'unknown'
|
||||
except Exception:
|
||||
self.remote = None
|
||||
|
||||
def list_files(self, subdir, extension):
|
||||
from modules import scripts
|
||||
|
|
@ -63,9 +65,12 @@ class Extension:
|
|||
self.can_update = False
|
||||
self.status = "latest"
|
||||
|
||||
def pull(self):
|
||||
def fetch_and_reset_hard(self):
|
||||
repo = git.Repo(self.path)
|
||||
repo.remotes.origin.pull()
|
||||
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
|
||||
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
|
||||
repo.git.fetch('--all')
|
||||
repo.git.reset('--hard', 'origin')
|
||||
|
||||
|
||||
def list_extensions():
|
||||
|
|
@ -81,3 +86,4 @@ def list_extensions():
|
|||
|
||||
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
|
||||
extensions.append(extension)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from __future__ import annotations
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
|
@ -12,7 +14,7 @@ from typing import Callable, List, OrderedDict, Tuple
|
|||
from functools import partial
|
||||
from dataclasses import dataclass
|
||||
|
||||
from modules import processing, shared, images, devices, sd_models
|
||||
from modules import processing, shared, images, devices, sd_models, sd_samplers
|
||||
from modules.shared import opts
|
||||
import modules.gfpgan_model
|
||||
from modules.ui import plaintext_to_html
|
||||
|
|
@ -20,7 +22,7 @@ import modules.codeformer_model
|
|||
import piexif
|
||||
import piexif.helper
|
||||
import gradio as gr
|
||||
|
||||
import safetensors.torch
|
||||
|
||||
class LruCache(OrderedDict):
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -136,12 +138,13 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
|||
|
||||
def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
blended_result: Image.Image = None
|
||||
image_hash: str = hash(np.array(image.getdata()).tobytes())
|
||||
for upscaler in params:
|
||||
upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
|
||||
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
|
||||
cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
|
||||
cache_key = LruCache.Key(image_hash=image_hash,
|
||||
info_hash=hash(info),
|
||||
args_hash=hash((upscale_args, upscale_first)))
|
||||
args_hash=hash(upscale_args))
|
||||
cached_entry = cached_images.get(cache_key)
|
||||
if cached_entry is None:
|
||||
res = upscale(image, *upscale_args)
|
||||
|
|
@ -212,25 +215,8 @@ def run_pnginfo(image):
|
|||
if image is None:
|
||||
return '', '', ''
|
||||
|
||||
items = image.info
|
||||
geninfo = ''
|
||||
|
||||
if "exif" in image.info:
|
||||
exif = piexif.load(image.info["exif"])
|
||||
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
|
||||
try:
|
||||
exif_comment = piexif.helper.UserComment.load(exif_comment)
|
||||
except ValueError:
|
||||
exif_comment = exif_comment.decode('utf8', errors="ignore")
|
||||
|
||||
items['exif comment'] = exif_comment
|
||||
geninfo = exif_comment
|
||||
|
||||
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
|
||||
'loop', 'background', 'timestamp', 'duration']:
|
||||
items.pop(field, None)
|
||||
|
||||
geninfo = items.get('parameters', geninfo)
|
||||
geninfo, items = images.read_info_from_image(image)
|
||||
items = {**{'parameters': geninfo}, **items}
|
||||
|
||||
info = ''
|
||||
for key, text in items.items():
|
||||
|
|
@ -248,7 +234,7 @@ def run_pnginfo(image):
|
|||
return '', geninfo, info
|
||||
|
||||
|
||||
def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
|
||||
def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
|
||||
def weighted_sum(theta0, theta1, alpha):
|
||||
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||
|
||||
|
|
@ -263,19 +249,15 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
|
|||
teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
|
||||
|
||||
print(f"Loading {primary_model_info.filename}...")
|
||||
primary_model = torch.load(primary_model_info.filename, map_location='cpu')
|
||||
theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
|
||||
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
||||
|
||||
print(f"Loading {secondary_model_info.filename}...")
|
||||
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
|
||||
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
|
||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||
|
||||
if teritary_model_info is not None:
|
||||
print(f"Loading {teritary_model_info.filename}...")
|
||||
teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
|
||||
theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
|
||||
theta_2 = sd_models.read_state_dict(teritary_model_info.filename, map_location='cpu')
|
||||
else:
|
||||
teritary_model = None
|
||||
theta_2 = None
|
||||
|
||||
theta_funcs = {
|
||||
|
|
@ -294,7 +276,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
|
|||
theta_1[key] = theta_func1(theta_1[key], t2)
|
||||
else:
|
||||
theta_1[key] = torch.zeros_like(theta_1[key])
|
||||
del theta_2, teritary_model
|
||||
del theta_2
|
||||
|
||||
for key in tqdm.tqdm(theta_0.keys()):
|
||||
if 'model' in key and key in theta_1:
|
||||
|
|
@ -313,12 +295,17 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
|
|||
|
||||
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
||||
|
||||
filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
|
||||
filename = filename if custom_name == '' else (custom_name + '.ckpt')
|
||||
filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.' + checkpoint_format
|
||||
filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
|
||||
output_modelname = os.path.join(ckpt_dir, filename)
|
||||
|
||||
print(f"Saving to {output_modelname}...")
|
||||
torch.save(primary_model, output_modelname)
|
||||
|
||||
_, extension = os.path.splitext(output_modelname)
|
||||
if extension.lower() == ".safetensors":
|
||||
safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
|
||||
else:
|
||||
torch.save(theta_0, output_modelname)
|
||||
|
||||
sd_models.list_models()
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ import base64
|
|||
import io
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
from modules.shared import script_path
|
||||
from modules import shared
|
||||
|
|
@ -35,9 +37,8 @@ def quote(text):
|
|||
def image_from_url_text(filedata):
|
||||
if type(filedata) == dict and filedata["is_file"]:
|
||||
filename = filedata["name"]
|
||||
tempdir = os.path.normpath(tempfile.gettempdir())
|
||||
normfn = os.path.normpath(filename)
|
||||
assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
|
||||
is_in_right_dir = any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in shared.demo.temp_dirs)
|
||||
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
|
||||
|
||||
return Image.open(filename)
|
||||
|
||||
|
|
@ -73,7 +74,9 @@ def integrate_settings_paste_fields(component_dict):
|
|||
'sd_hypernetwork': 'Hypernet',
|
||||
'sd_hypernetwork_strength': 'Hypernet strength',
|
||||
'CLIP_stop_at_last_layers': 'Clip skip',
|
||||
'inpainting_mask_weight': 'Conditional mask weight',
|
||||
'sd_model_checkpoint': 'Model hash',
|
||||
'eta_noise_seed_delta': 'ENSD',
|
||||
}
|
||||
settings_paste_fields = [
|
||||
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
|
||||
|
|
@ -181,6 +184,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||
else:
|
||||
res[k] = v
|
||||
|
||||
# Missing CLIP skip means it was set to 1 (the default)
|
||||
if "Clip skip" not in res:
|
||||
res["Clip skip"] = "1"
|
||||
|
||||
return res
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -36,7 +36,9 @@ def gfpgann():
|
|||
else:
|
||||
print("Unable to load gfpgan model!")
|
||||
return None
|
||||
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
||||
if hasattr(facexlib.detection.retinaface, 'device'):
|
||||
facexlib.detection.retinaface.device = devices.device_gfpgan
|
||||
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
|
||||
loaded_gfpgan_model = model
|
||||
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import torch
|
|||
import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from ldm.util import default
|
||||
from modules import devices, processing, sd_models, shared
|
||||
from modules import devices, processing, sd_models, shared, sd_samplers
|
||||
from modules.textual_inversion import textual_inversion
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
from torch import einsum
|
||||
|
|
@ -22,6 +22,8 @@ from collections import defaultdict, deque
|
|||
from statistics import stdev, mean
|
||||
|
||||
|
||||
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 1.0
|
||||
activation_dict = {
|
||||
|
|
@ -36,7 +38,7 @@ class HypernetworkModule(torch.nn.Module):
|
|||
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
||||
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
|
||||
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True):
|
||||
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
|
||||
super().__init__()
|
||||
|
||||
assert layer_structure is not None, "layer_structure must not be None"
|
||||
|
|
@ -142,6 +144,8 @@ class Hypernetwork:
|
|||
self.use_dropout = use_dropout
|
||||
self.activate_output = activate_output
|
||||
self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
|
||||
self.optimizer_name = None
|
||||
self.optimizer_state_dict = None
|
||||
|
||||
for size in enable_sizes or []:
|
||||
self.layers[size] = (
|
||||
|
|
@ -150,19 +154,32 @@ class Hypernetwork:
|
|||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||
)
|
||||
self.eval_mode()
|
||||
|
||||
def weights(self):
|
||||
res = []
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
res += layer.parameters()
|
||||
return res
|
||||
|
||||
def train_mode(self):
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.train()
|
||||
res += layer.trainables()
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
return res
|
||||
def eval_mode(self):
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.eval()
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def save(self, filename):
|
||||
state_dict = {}
|
||||
optimizer_saved_dict = {}
|
||||
|
||||
for k, v in self.layers.items():
|
||||
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
|
||||
|
|
@ -178,8 +195,15 @@ class Hypernetwork:
|
|||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||
state_dict['activate_output'] = self.activate_output
|
||||
state_dict['last_layer_dropout'] = self.last_layer_dropout
|
||||
|
||||
|
||||
if self.optimizer_name is not None:
|
||||
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
|
||||
|
||||
torch.save(state_dict, filename)
|
||||
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
|
||||
optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
|
||||
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
|
||||
torch.save(optimizer_saved_dict, filename + '.optim')
|
||||
|
||||
def load(self, filename):
|
||||
self.filename = filename
|
||||
|
|
@ -202,6 +226,18 @@ class Hypernetwork:
|
|||
print(f"Activate last layer is set to {self.activate_output}")
|
||||
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
|
||||
|
||||
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
|
||||
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
|
||||
print(f"Optimizer name is {self.optimizer_name}")
|
||||
if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
|
||||
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
||||
else:
|
||||
self.optimizer_state_dict = None
|
||||
if self.optimizer_state_dict:
|
||||
print("Loaded existing optimizer from checkpoint")
|
||||
else:
|
||||
print("No saved optimizer exists in checkpoint")
|
||||
|
||||
for size, sd in state_dict.items():
|
||||
if type(size) == int:
|
||||
self.layers[size] = (
|
||||
|
|
@ -219,11 +255,11 @@ class Hypernetwork:
|
|||
|
||||
def list_hypernetworks(path):
|
||||
res = {}
|
||||
for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
|
||||
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
|
||||
name = os.path.splitext(os.path.basename(filename))[0]
|
||||
# Prevent a hypothetical "None.pt" from being listed.
|
||||
if name != "None":
|
||||
res[name] = filename
|
||||
res[name + f"({sd_models.model_hash(filename)})"] = filename
|
||||
return res
|
||||
|
||||
|
||||
|
|
@ -343,13 +379,13 @@ def report_statistics(loss_info:dict):
|
|||
|
||||
|
||||
|
||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||
from modules import images
|
||||
|
||||
save_hypernetwork_every = save_hypernetwork_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
|
||||
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
|
||||
|
||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||
shared.loaded_hypernetwork = Hypernetwork()
|
||||
|
|
@ -358,6 +394,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||
shared.state.textinfo = "Initializing hypernetwork training..."
|
||||
shared.state.job_count = steps
|
||||
|
||||
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||
|
||||
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
|
||||
|
|
@ -378,17 +415,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||
hypernetwork = shared.loaded_hypernetwork
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
ititial_step = hypernetwork.step or 0
|
||||
if ititial_step >= steps:
|
||||
initial_step = hypernetwork.step or 0
|
||||
if initial_step >= steps:
|
||||
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
|
||||
return hypernetwork, filename
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
|
||||
|
||||
pin_memory = shared.opts.pin_memory
|
||||
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
|
||||
|
||||
latent_sampling_method = ds.latent_sampling_method
|
||||
|
||||
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
|
||||
|
||||
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
||||
|
||||
|
|
@ -396,19 +439,41 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||
shared.parallel_processing_allowed = False
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
size = len(ds.indexes)
|
||||
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
||||
losses = torch.zeros((size,))
|
||||
previous_mean_losses = [0]
|
||||
previous_mean_loss = 0
|
||||
print("Mean loss of {} elements".format(size))
|
||||
|
||||
weights = hypernetwork.weights()
|
||||
for weight in weights:
|
||||
weight.requires_grad = True
|
||||
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
|
||||
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
||||
hypernetwork.train_mode()
|
||||
|
||||
# Here we use optimizer from saved HN, or we can specify as UI option.
|
||||
if hypernetwork.optimizer_name in optimizer_dict:
|
||||
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
|
||||
optimizer_name = hypernetwork.optimizer_name
|
||||
else:
|
||||
print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
|
||||
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
|
||||
optimizer_name = 'AdamW'
|
||||
|
||||
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
|
||||
try:
|
||||
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
|
||||
except RuntimeError as e:
|
||||
print("Cannot resume from saved optimizer!")
|
||||
print(e)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
batch_size = ds.batch_size
|
||||
gradient_step = ds.gradient_step
|
||||
# n steps = batch_size * gradient_step * n image processed
|
||||
steps_per_epoch = len(ds) // batch_size // gradient_step
|
||||
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
|
||||
loss_step = 0
|
||||
_loss_step = 0 #internal
|
||||
# size = len(ds.indexes)
|
||||
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
||||
# losses = torch.zeros((size,))
|
||||
# previous_mean_losses = [0]
|
||||
# previous_mean_loss = 0
|
||||
# print("Mean loss of {} elements".format(size))
|
||||
|
||||
steps_without_grad = 0
|
||||
|
||||
|
|
@ -416,124 +481,145 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||
for i, entries in pbar:
|
||||
hypernetwork.step = i + ititial_step
|
||||
if len(loss_dict) > 0:
|
||||
previous_mean_losses = [i[-1] for i in loss_dict.values()]
|
||||
previous_mean_loss = mean(previous_mean_losses)
|
||||
|
||||
scheduler.apply(optimizer, hypernetwork.step)
|
||||
if scheduler.finished:
|
||||
break
|
||||
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||
try:
|
||||
for i in range((steps-initial_step) * gradient_step):
|
||||
if scheduler.finished:
|
||||
break
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
for j, batch in enumerate(dl):
|
||||
# works as a drop_last=True for gradient accumulation
|
||||
if j == max_steps_per_epoch:
|
||||
break
|
||||
scheduler.apply(optimizer, hypernetwork.step)
|
||||
if scheduler.finished:
|
||||
break
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
with devices.autocast():
|
||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||
if tag_drop_out != 0 or shuffle_tags:
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
else:
|
||||
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
|
||||
loss = shared.sd_model(x, c)[0] / gradient_step
|
||||
del x
|
||||
del c
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
|
||||
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
|
||||
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
||||
loss = shared.sd_model(x, c)[0]
|
||||
del x
|
||||
del c
|
||||
_loss_step += loss.item()
|
||||
scaler.scale(loss).backward()
|
||||
# go back until we reach gradient accumulation steps
|
||||
if (j + 1) % gradient_step != 0:
|
||||
continue
|
||||
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
|
||||
# scaler.unscale_(optimizer)
|
||||
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
|
||||
# torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
|
||||
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
hypernetwork.step += 1
|
||||
pbar.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
loss_step = _loss_step
|
||||
_loss_step = 0
|
||||
|
||||
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
||||
for entry in entries:
|
||||
loss_dict[entry.filename].append(loss.item())
|
||||
steps_done = hypernetwork.step + 1
|
||||
|
||||
optimizer.zero_grad()
|
||||
weights[0].grad = None
|
||||
loss.backward()
|
||||
epoch_num = hypernetwork.step // steps_per_epoch
|
||||
epoch_step = hypernetwork.step % steps_per_epoch
|
||||
|
||||
if weights[0].grad is None:
|
||||
steps_without_grad += 1
|
||||
else:
|
||||
steps_without_grad = 0
|
||||
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
|
||||
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
|
||||
hypernetwork.optimizer_name = optimizer_name
|
||||
if shared.opts.save_optimizer_state:
|
||||
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||
|
||||
optimizer.step()
|
||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
|
||||
"loss": f"{loss_step:.7f}",
|
||||
"learn_rate": scheduler.learn_rate
|
||||
})
|
||||
|
||||
steps_done = hypernetwork.step + 1
|
||||
if images_dir is not None and steps_done % create_image_every == 0:
|
||||
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
hypernetwork.eval_mode()
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
|
||||
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
|
||||
raise RuntimeError("Loss diverged.")
|
||||
|
||||
if len(previous_mean_losses) > 1:
|
||||
std = stdev(previous_mean_losses)
|
||||
else:
|
||||
std = 0
|
||||
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
|
||||
pbar.set_description(dataset_loss_info)
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
do_not_save_grid=True,
|
||||
do_not_save_samples=True,
|
||||
)
|
||||
|
||||
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
|
||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
||||
if preview_from_txt2img:
|
||||
p.prompt = preview_prompt
|
||||
p.negative_prompt = preview_negative_prompt
|
||||
p.steps = preview_steps
|
||||
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
||||
p.cfg_scale = preview_cfg_scale
|
||||
p.seed = preview_seed
|
||||
p.width = preview_width
|
||||
p.height = preview_height
|
||||
else:
|
||||
p.prompt = batch.cond_text[0]
|
||||
p.steps = 20
|
||||
p.width = training_width
|
||||
p.height = training_height
|
||||
|
||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
||||
"loss": f"{previous_mean_loss:.7f}",
|
||||
"learn_rate": scheduler.learn_rate
|
||||
})
|
||||
preview_text = p.prompt
|
||||
|
||||
if images_dir is not None and steps_done % create_image_every == 0:
|
||||
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0] if len(processed.images) > 0 else None
|
||||
|
||||
optimizer.zero_grad()
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
if unload:
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
hypernetwork.train_mode()
|
||||
if image is not None:
|
||||
shared.state.current_image = image
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
do_not_save_grid=True,
|
||||
do_not_save_samples=True,
|
||||
)
|
||||
shared.state.job_no = hypernetwork.step
|
||||
|
||||
if preview_from_txt2img:
|
||||
p.prompt = preview_prompt
|
||||
p.negative_prompt = preview_negative_prompt
|
||||
p.steps = preview_steps
|
||||
p.sampler_index = preview_sampler_index
|
||||
p.cfg_scale = preview_cfg_scale
|
||||
p.seed = preview_seed
|
||||
p.width = preview_width
|
||||
p.height = preview_height
|
||||
else:
|
||||
p.prompt = entries[0].cond_text
|
||||
p.steps = 20
|
||||
|
||||
preview_text = p.prompt
|
||||
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0] if len(processed.images)>0 else None
|
||||
|
||||
if unload:
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
if image is not None:
|
||||
shared.state.current_image = image
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
shared.state.job_no = hypernetwork.step
|
||||
|
||||
shared.state.textinfo = f"""
|
||||
shared.state.textinfo = f"""
|
||||
<p>
|
||||
Loss: {previous_mean_loss:.7f}<br/>
|
||||
Step: {hypernetwork.step}<br/>
|
||||
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
||||
Loss: {loss_step:.7f}<br/>
|
||||
Step: {steps_done}<br/>
|
||||
Last prompt: {html.escape(batch.cond_text[0])}<br/>
|
||||
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
"""
|
||||
|
||||
report_statistics(loss_dict)
|
||||
except Exception:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
finally:
|
||||
pbar.leave = False
|
||||
pbar.close()
|
||||
hypernetwork.eval_mode()
|
||||
#report_statistics(loss_dict)
|
||||
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||
hypernetwork.optimizer_name = optimizer_name
|
||||
if shared.opts.save_optimizer_state:
|
||||
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
|
||||
|
||||
del optimizer
|
||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||
|
||||
return hypernetwork, filename
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from modules import devices, sd_hijack, shared
|
|||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
not_available = ["hardswish", "multiheadattention"]
|
||||
keys = ["linear"] + list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
|
||||
keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
|
||||
# Remove illegal characters from name.
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ import piexif.helper
|
|||
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
||||
from fonts.ttf import Roboto
|
||||
import string
|
||||
import json
|
||||
|
||||
from modules import sd_samplers, shared, script_callbacks
|
||||
from modules.shared import opts, cmd_opts
|
||||
|
|
@ -303,8 +304,9 @@ class FilenameGenerator:
|
|||
'width': lambda self: self.image.width,
|
||||
'height': lambda self: self.image.height,
|
||||
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
|
||||
'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
|
||||
'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
|
||||
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
|
||||
'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
|
||||
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
|
||||
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
|
||||
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
|
||||
|
|
@ -524,6 +526,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||
else:
|
||||
image.save(fullfn, quality=opts.jpeg_quality)
|
||||
|
||||
image.already_saved_as = fullfn
|
||||
|
||||
target_side_length = 4000
|
||||
oversize = image.width > target_side_length or image.height > target_side_length
|
||||
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024):
|
||||
|
|
@ -550,10 +554,45 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||
return fullfn, txt_fullfn
|
||||
|
||||
|
||||
def read_info_from_image(image):
|
||||
items = image.info or {}
|
||||
|
||||
geninfo = items.pop('parameters', None)
|
||||
|
||||
if "exif" in items:
|
||||
exif = piexif.load(items["exif"])
|
||||
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
|
||||
try:
|
||||
exif_comment = piexif.helper.UserComment.load(exif_comment)
|
||||
except ValueError:
|
||||
exif_comment = exif_comment.decode('utf8', errors="ignore")
|
||||
|
||||
items['exif comment'] = exif_comment
|
||||
geninfo = exif_comment
|
||||
|
||||
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
|
||||
'loop', 'background', 'timestamp', 'duration']:
|
||||
items.pop(field, None)
|
||||
|
||||
if items.get("Software", None) == "NovelAI":
|
||||
try:
|
||||
json_info = json.loads(items["Comment"])
|
||||
sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
|
||||
|
||||
geninfo = f"""{items["Description"]}
|
||||
Negative prompt: {json_info["uc"]}
|
||||
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
|
||||
except Exception:
|
||||
print(f"Error parsing NovelAI iamge generation parameters:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
return geninfo, items
|
||||
|
||||
|
||||
def image_data(data):
|
||||
try:
|
||||
image = Image.open(io.BytesIO(data))
|
||||
textinfo = image.text["parameters"]
|
||||
textinfo, _ = read_info_from_image(image)
|
||||
return textinfo, None
|
||||
except Exception:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@ import sys
|
|||
import traceback
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps, ImageChops
|
||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance
|
||||
|
||||
from modules import devices
|
||||
from modules import devices, sd_samplers
|
||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
|
|
@ -40,7 +40,7 @@ def process_batch(p, input_dir, output_dir, args):
|
|||
|
||||
img = Image.open(image)
|
||||
# Use the EXIF orientation of photos taken by smartphones.
|
||||
img = ImageOps.exif_transpose(img)
|
||||
img = ImageOps.exif_transpose(img)
|
||||
p.init_images = [img] * p.batch_size
|
||||
|
||||
proc = modules.scripts.scripts_img2img.run(p, *args)
|
||||
|
|
@ -59,18 +59,30 @@ def process_batch(p, input_dir, output_dir, args):
|
|||
processed_image.save(os.path.join(output_dir, filename))
|
||||
|
||||
|
||||
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
|
||||
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
|
||||
is_inpaint = mode == 1
|
||||
is_batch = mode == 2
|
||||
|
||||
if is_inpaint:
|
||||
# Drawn mask
|
||||
if mask_mode == 0:
|
||||
image = init_img_with_mask['image']
|
||||
mask = init_img_with_mask['mask']
|
||||
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
||||
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
|
||||
image = image.convert('RGB')
|
||||
image = init_img_with_mask
|
||||
is_mask_sketch = isinstance(image, dict)
|
||||
is_mask_paint = not is_mask_sketch
|
||||
if is_mask_sketch:
|
||||
# Sketch: mask iff. not transparent
|
||||
image, mask = image["image"], image["mask"]
|
||||
pred = np.array(mask)[..., -1] > 0
|
||||
else:
|
||||
# Color-sketch: mask iff. painted over
|
||||
orig = init_img_with_mask_orig or image
|
||||
pred = np.any(np.array(image) != np.array(orig), axis=-1)
|
||||
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
|
||||
if is_mask_paint:
|
||||
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
||||
blur = ImageFilter.GaussianBlur(mask_blur)
|
||||
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
||||
image = image.convert("RGB")
|
||||
# Uploaded mask
|
||||
else:
|
||||
image = init_img_inpaint
|
||||
|
|
@ -82,7 +94,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
|||
|
||||
# Use the EXIF orientation of photos taken by smartphones.
|
||||
if image is not None:
|
||||
image = ImageOps.exif_transpose(image)
|
||||
image = ImageOps.exif_transpose(image)
|
||||
|
||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
|
||||
|
|
@ -99,7 +111,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
|||
seed_resize_from_h=seed_resize_from_h,
|
||||
seed_resize_from_w=seed_resize_from_w,
|
||||
seed_enable_extras=seed_enable_extras,
|
||||
sampler_index=sampler_index,
|
||||
sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
steps=steps,
|
||||
|
|
|
|||
|
|
@ -148,8 +148,7 @@ class InterrogateModels:
|
|||
|
||||
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
|
||||
|
||||
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||
with torch.no_grad(), precision_scope("cuda"):
|
||||
with torch.no_grad(), devices.autocast():
|
||||
image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
|
||||
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
|
|
|||
|
|
@ -101,8 +101,8 @@ class LDSR:
|
|||
down_sample_rate = target_scale / 4
|
||||
wd = width_og * down_sample_rate
|
||||
hd = height_og * down_sample_rate
|
||||
width_downsampled_pre = int(wd)
|
||||
height_downsampled_pre = int(hd)
|
||||
width_downsampled_pre = int(np.ceil(wd))
|
||||
height_downsampled_pre = int(np.ceil(hd))
|
||||
|
||||
if down_sample_rate != 1:
|
||||
print(
|
||||
|
|
@ -110,7 +110,12 @@ class LDSR:
|
|||
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
||||
else:
|
||||
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
|
||||
logs = self.run(model["model"], im_og, diffusion_steps, eta)
|
||||
|
||||
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
|
||||
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
|
||||
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
|
||||
|
||||
logs = self.run(model["model"], im_padded, diffusion_steps, eta)
|
||||
|
||||
sample = logs["sample"]
|
||||
sample = sample.detach().cpu()
|
||||
|
|
@ -120,6 +125,9 @@ class LDSR:
|
|||
sample = np.transpose(sample, (0, 2, 3, 1))
|
||||
a = Image.fromarray(sample[0])
|
||||
|
||||
# remove padding
|
||||
a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import os
|
|||
import sys
|
||||
import traceback
|
||||
|
||||
|
||||
localizations = {}
|
||||
|
||||
|
||||
|
|
@ -16,6 +17,11 @@ def list_localizations(dirname):
|
|||
|
||||
localizations[fn] = os.path.join(dirname, file)
|
||||
|
||||
from modules import scripts
|
||||
for file in scripts.list_scripts("localizations", ".json"):
|
||||
fn, ext = os.path.splitext(file.filename)
|
||||
localizations[fn] = file.path
|
||||
|
||||
|
||||
def localization_js(current_localization_name):
|
||||
fn = localizations.get(current_localization_name, None)
|
||||
|
|
|
|||
|
|
@ -51,6 +51,10 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||
send_me_to_gpu(first_stage_model, None)
|
||||
return first_stage_model_decode(z)
|
||||
|
||||
# for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
|
||||
if hasattr(sd_model.cond_stage_model, 'model'):
|
||||
sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
|
||||
|
||||
# remove three big modules, cond, first_stage, and unet from the model and then
|
||||
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
||||
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
|
||||
|
|
@ -65,6 +69,10 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
||||
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
||||
|
||||
if hasattr(sd_model.cond_stage_model, 'model'):
|
||||
sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
|
||||
del sd_model.cond_stage_model.transformer
|
||||
|
||||
if use_medvram:
|
||||
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -82,6 +82,7 @@ def cleanup_models():
|
|||
src_path = models_path
|
||||
dest_path = os.path.join(models_path, "Stable-diffusion")
|
||||
move_files(src_path, dest_path, ".ckpt")
|
||||
move_files(src_path, dest_path, ".safetensors")
|
||||
src_path = os.path.join(root_path, "ESRGAN")
|
||||
dest_path = os.path.join(models_path, "ESRGAN")
|
||||
move_files(src_path, dest_path)
|
||||
|
|
|
|||
|
|
@ -1,14 +1,23 @@
|
|||
from pyngrok import ngrok, conf, exception
|
||||
|
||||
|
||||
def connect(token, port, region):
|
||||
account = None
|
||||
if token == None:
|
||||
token = 'None'
|
||||
else:
|
||||
if ':' in token:
|
||||
# token = authtoken:username:password
|
||||
account = token.split(':')[1] + ':' + token.split(':')[-1]
|
||||
token = token.split(':')[0]
|
||||
|
||||
config = conf.PyngrokConfig(
|
||||
auth_token=token, region=region
|
||||
)
|
||||
try:
|
||||
public_url = ngrok.connect(port, pyngrok_config=config).public_url
|
||||
if account == None:
|
||||
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
|
||||
else:
|
||||
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
|
||||
except exception.PyngrokNgrokError:
|
||||
print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
|
||||
f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ sys.path.insert(0, script_path)
|
|||
|
||||
# search for directory of stable diffusion in following places
|
||||
sd_path = None
|
||||
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
|
||||
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
|
||||
for possible_sd_path in possible_sd_paths:
|
||||
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
|
||||
sd_path = os.path.abspath(possible_sd_path)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import json
|
|||
import math
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
|
@ -66,19 +67,15 @@ def apply_overlay(image, paste_loc, index, overlays):
|
|||
|
||||
return image
|
||||
|
||||
def get_correct_sampler(p):
|
||||
if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
|
||||
return sd_samplers.samplers
|
||||
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
|
||||
return sd_samplers.samplers_for_img2img
|
||||
elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
|
||||
return sd_samplers.samplers
|
||||
|
||||
class StableDiffusionProcessing():
|
||||
"""
|
||||
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
||||
"""
|
||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None):
|
||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None):
|
||||
if sampler_index is not None:
|
||||
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
||||
|
||||
self.sd_model = sd_model
|
||||
self.outpath_samples: str = outpath_samples
|
||||
self.outpath_grids: str = outpath_grids
|
||||
|
|
@ -91,7 +88,7 @@ class StableDiffusionProcessing():
|
|||
self.subseed_strength: float = subseed_strength
|
||||
self.seed_resize_from_h: int = seed_resize_from_h
|
||||
self.seed_resize_from_w: int = seed_resize_from_w
|
||||
self.sampler_index: int = sampler_index
|
||||
self.sampler_name: str = sampler_name
|
||||
self.batch_size: int = batch_size
|
||||
self.n_iter: int = n_iter
|
||||
self.steps: int = steps
|
||||
|
|
@ -116,6 +113,7 @@ class StableDiffusionProcessing():
|
|||
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
|
||||
self.s_noise = s_noise or opts.s_noise
|
||||
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
|
||||
self.is_using_inpainting_conditioning = False
|
||||
|
||||
if not seed_enable_extras:
|
||||
self.subseed = -1
|
||||
|
|
@ -126,6 +124,7 @@ class StableDiffusionProcessing():
|
|||
self.scripts = None
|
||||
self.script_args = None
|
||||
self.all_prompts = None
|
||||
self.all_negative_prompts = None
|
||||
self.all_seeds = None
|
||||
self.all_subseeds = None
|
||||
|
||||
|
|
@ -136,6 +135,8 @@ class StableDiffusionProcessing():
|
|||
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
||||
return x.new_zeros(x.shape[0], 5, 1, 1)
|
||||
|
||||
self.is_using_inpainting_conditioning = True
|
||||
|
||||
height = height or self.height
|
||||
width = width or self.width
|
||||
|
||||
|
|
@ -154,6 +155,8 @@ class StableDiffusionProcessing():
|
|||
# Dummy zero conditioning if we're not using inpainting model.
|
||||
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
|
||||
|
||||
self.is_using_inpainting_conditioning = True
|
||||
|
||||
# Handle the different mask inputs
|
||||
if image_mask is not None:
|
||||
if torch.is_tensor(image_mask):
|
||||
|
|
@ -200,7 +203,7 @@ class StableDiffusionProcessing():
|
|||
|
||||
|
||||
class Processed:
|
||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
|
||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
|
||||
self.images = images_list
|
||||
self.prompt = p.prompt
|
||||
self.negative_prompt = p.negative_prompt
|
||||
|
|
@ -210,8 +213,7 @@ class Processed:
|
|||
self.info = info
|
||||
self.width = p.width
|
||||
self.height = p.height
|
||||
self.sampler_index = p.sampler_index
|
||||
self.sampler = sd_samplers.samplers[p.sampler_index].name
|
||||
self.sampler_name = p.sampler_name
|
||||
self.cfg_scale = p.cfg_scale
|
||||
self.steps = p.steps
|
||||
self.batch_size = p.batch_size
|
||||
|
|
@ -238,17 +240,20 @@ class Processed:
|
|||
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
||||
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
|
||||
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
|
||||
self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
|
||||
|
||||
self.all_prompts = all_prompts or [self.prompt]
|
||||
self.all_seeds = all_seeds or [self.seed]
|
||||
self.all_subseeds = all_subseeds or [self.subseed]
|
||||
self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
|
||||
self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
|
||||
self.all_seeds = all_seeds or p.all_seeds or [self.seed]
|
||||
self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
|
||||
self.infotexts = infotexts or [info]
|
||||
|
||||
def js(self):
|
||||
obj = {
|
||||
"prompt": self.prompt,
|
||||
"prompt": self.all_prompts[0],
|
||||
"all_prompts": self.all_prompts,
|
||||
"negative_prompt": self.negative_prompt,
|
||||
"negative_prompt": self.all_negative_prompts[0],
|
||||
"all_negative_prompts": self.all_negative_prompts,
|
||||
"seed": self.seed,
|
||||
"all_seeds": self.all_seeds,
|
||||
"subseed": self.subseed,
|
||||
|
|
@ -256,8 +261,7 @@ class Processed:
|
|||
"subseed_strength": self.subseed_strength,
|
||||
"width": self.width,
|
||||
"height": self.height,
|
||||
"sampler_index": self.sampler_index,
|
||||
"sampler": self.sampler,
|
||||
"sampler_name": self.sampler_name,
|
||||
"cfg_scale": self.cfg_scale,
|
||||
"steps": self.steps,
|
||||
"batch_size": self.batch_size,
|
||||
|
|
@ -273,6 +277,7 @@ class Processed:
|
|||
"styles": self.styles,
|
||||
"job_timestamp": self.job_timestamp,
|
||||
"clip_skip": self.clip_skip,
|
||||
"is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
|
||||
}
|
||||
|
||||
return json.dumps(obj)
|
||||
|
|
@ -384,7 +389,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
|||
|
||||
generation_params = {
|
||||
"Steps": p.steps,
|
||||
"Sampler": get_correct_sampler(p)[p.sampler_index].name,
|
||||
"Sampler": p.sampler_name,
|
||||
"CFG scale": p.cfg_scale,
|
||||
"Seed": all_seeds[index],
|
||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||
|
|
@ -399,6 +404,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
|||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
||||
|
|
@ -408,7 +414,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
|||
|
||||
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||
|
||||
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
|
||||
negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[0] if p.all_negative_prompts[0] else ""
|
||||
|
||||
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
||||
|
||||
|
|
@ -418,13 +424,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
|
||||
try:
|
||||
for k, v in p.override_settings.items():
|
||||
setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model, hypernet impossible
|
||||
setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model impossible
|
||||
if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet since it is relatively fast to load on-change, while SD models are not
|
||||
|
||||
res = process_images_inner(p)
|
||||
|
||||
finally:
|
||||
finally: # restore opts to original state
|
||||
for k, v in stored_opts.items():
|
||||
setattr(opts, k, v)
|
||||
if k == 'sd_hypernetwork': shared.reload_hypernetworks()
|
||||
|
||||
return res
|
||||
|
||||
|
|
@ -437,10 +445,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
else:
|
||||
assert p.prompt is not None
|
||||
|
||||
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(processed.infotext(p, 0))
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
seed = get_fixed_seed(p.seed)
|
||||
|
|
@ -451,12 +455,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
|
||||
comments = {}
|
||||
|
||||
shared.prompt_styles.apply_styles(p)
|
||||
|
||||
if type(p.prompt) == list:
|
||||
p.all_prompts = p.prompt
|
||||
p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
|
||||
else:
|
||||
p.all_prompts = p.batch_size * p.n_iter * [p.prompt]
|
||||
p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
|
||||
|
||||
if type(p.negative_prompt) == list:
|
||||
p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
|
||||
else:
|
||||
p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
|
||||
|
||||
if type(seed) == list:
|
||||
p.all_seeds = seed
|
||||
|
|
@ -471,6 +478,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
def infotext(iteration=0, position_in_batch=0):
|
||||
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
|
||||
|
||||
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(processed.infotext(p, 0))
|
||||
|
||||
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
|
|
@ -495,14 +506,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
break
|
||||
|
||||
prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
|
||||
if len(prompts) == 0:
|
||||
break
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
|
||||
|
||||
with devices.autocast():
|
||||
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
|
||||
uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
|
||||
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
|
||||
|
||||
if len(model_hijack.comments) > 0:
|
||||
|
|
@ -515,8 +530,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
with devices.autocast():
|
||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
|
||||
|
||||
samples_ddim = samples_ddim.to(devices.dtype_vae)
|
||||
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
|
||||
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
|
||||
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
del samples_ddim
|
||||
|
|
@ -588,7 +603,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
|
||||
devices.torch_gc()
|
||||
|
||||
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.postprocess(p, res)
|
||||
|
|
@ -642,7 +657,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
|
||||
if not self.enable_hr:
|
||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
|
|
@ -665,17 +680,17 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
|
||||
|
||||
if opts.use_scale_latent_for_hires_fix:
|
||||
for i in range(samples.shape[0]):
|
||||
save_intermediate(samples, i)
|
||||
|
||||
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
||||
|
||||
|
||||
# Avoid making the inpainting conditioning unless necessary as
|
||||
# this does need some extra compute to decode / encode the image again.
|
||||
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
|
||||
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
|
||||
else:
|
||||
image_conditioning = self.txt2img_image_conditioning(samples)
|
||||
|
||||
for i in range(samples.shape[0]):
|
||||
save_intermediate(samples, i)
|
||||
else:
|
||||
decoded_samples = decode_first_stage(self.sd_model, samples)
|
||||
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
|
@ -703,7 +718,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
|
||||
shared.state.nextjob()
|
||||
|
||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
|
||||
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
|
||||
|
|
@ -727,7 +742,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||
self.denoising_strength: float = denoising_strength
|
||||
self.init_latent = None
|
||||
self.image_mask = mask
|
||||
#self.image_unblurred_mask = None
|
||||
self.latent_mask = None
|
||||
self.mask_for_overlay = None
|
||||
self.mask_blur = mask_blur
|
||||
|
|
@ -740,39 +754,39 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||
self.image_conditioning = None
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
crop_region = None
|
||||
|
||||
if self.image_mask is not None:
|
||||
self.image_mask = self.image_mask.convert('L')
|
||||
image_mask = self.image_mask
|
||||
|
||||
if image_mask is not None:
|
||||
image_mask = image_mask.convert('L')
|
||||
|
||||
if self.inpainting_mask_invert:
|
||||
self.image_mask = ImageOps.invert(self.image_mask)
|
||||
|
||||
#self.image_unblurred_mask = self.image_mask
|
||||
image_mask = ImageOps.invert(image_mask)
|
||||
|
||||
if self.mask_blur > 0:
|
||||
self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
|
||||
image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
|
||||
|
||||
if self.inpaint_full_res:
|
||||
self.mask_for_overlay = self.image_mask
|
||||
mask = self.image_mask.convert('L')
|
||||
self.mask_for_overlay = image_mask
|
||||
mask = image_mask.convert('L')
|
||||
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
|
||||
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
|
||||
x1, y1, x2, y2 = crop_region
|
||||
|
||||
mask = mask.crop(crop_region)
|
||||
self.image_mask = images.resize_image(2, mask, self.width, self.height)
|
||||
image_mask = images.resize_image(2, mask, self.width, self.height)
|
||||
self.paste_to = (x1, y1, x2-x1, y2-y1)
|
||||
else:
|
||||
self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
|
||||
np_mask = np.array(self.image_mask)
|
||||
image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
|
||||
np_mask = np.array(image_mask)
|
||||
np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
|
||||
self.mask_for_overlay = Image.fromarray(np_mask)
|
||||
|
||||
self.overlay_images = []
|
||||
|
||||
latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask
|
||||
latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
|
||||
|
||||
add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
|
||||
if add_color_corrections:
|
||||
|
|
@ -784,7 +798,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||
if crop_region is None:
|
||||
image = images.resize_image(self.resize_mode, image, self.width, self.height)
|
||||
|
||||
if self.image_mask is not None:
|
||||
if image_mask is not None:
|
||||
image_masked = Image.new('RGBa', (image.width, image.height))
|
||||
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
|
||||
|
||||
|
|
@ -794,7 +808,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||
image = image.crop(crop_region)
|
||||
image = images.resize_image(2, image, self.width, self.height)
|
||||
|
||||
if self.image_mask is not None:
|
||||
if image_mask is not None:
|
||||
if self.inpainting_fill != 1:
|
||||
image = masking.fill(image, latent_mask)
|
||||
|
||||
|
|
@ -826,7 +840,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||
|
||||
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
|
||||
|
||||
if self.image_mask is not None:
|
||||
if image_mask is not None:
|
||||
init_mask = latent_mask
|
||||
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
|
||||
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
|
||||
|
|
@ -843,7 +857,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||
elif self.inpainting_fill == 3:
|
||||
self.init_latent = self.init_latent * self.mask
|
||||
|
||||
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
|
||||
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
|
|
|
|||
|
|
@ -23,11 +23,18 @@ def encode(*args):
|
|||
|
||||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
|
||||
extra_handler = None
|
||||
|
||||
def persistent_load(self, saved_id):
|
||||
assert saved_id[0] == 'storage'
|
||||
return TypedStorage()
|
||||
|
||||
def find_class(self, module, name):
|
||||
if self.extra_handler is not None:
|
||||
res = self.extra_handler(module, name)
|
||||
if res is not None:
|
||||
return res
|
||||
|
||||
if module == 'collections' and name == 'OrderedDict':
|
||||
return getattr(collections, name)
|
||||
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
|
||||
|
|
@ -52,32 +59,37 @@ class RestrictedUnpickler(pickle.Unpickler):
|
|||
return set
|
||||
|
||||
# Forbid everything else.
|
||||
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
|
||||
raise Exception(f"global '{module}/{name}' is forbidden")
|
||||
|
||||
|
||||
allowed_zip_names = ["archive/data.pkl", "archive/version"]
|
||||
allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
|
||||
|
||||
# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
|
||||
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
|
||||
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
|
||||
|
||||
def check_zip_filenames(filename, names):
|
||||
for name in names:
|
||||
if name in allowed_zip_names:
|
||||
continue
|
||||
if allowed_zip_names_re.match(name):
|
||||
continue
|
||||
|
||||
raise Exception(f"bad file inside {filename}: {name}")
|
||||
|
||||
|
||||
def check_pt(filename):
|
||||
def check_pt(filename, extra_handler):
|
||||
try:
|
||||
|
||||
# new pytorch format is a zip file
|
||||
with zipfile.ZipFile(filename) as z:
|
||||
check_zip_filenames(filename, z.namelist())
|
||||
|
||||
with z.open('archive/data.pkl') as file:
|
||||
|
||||
# find filename of data.pkl in zip file: '<directory name>/data.pkl'
|
||||
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
|
||||
if len(data_pkl_filenames) == 0:
|
||||
raise Exception(f"data.pkl not found in {filename}")
|
||||
if len(data_pkl_filenames) > 1:
|
||||
raise Exception(f"Multiple data.pkl found in {filename}")
|
||||
with z.open(data_pkl_filenames[0]) as file:
|
||||
unpickler = RestrictedUnpickler(file)
|
||||
unpickler.extra_handler = extra_handler
|
||||
unpickler.load()
|
||||
|
||||
except zipfile.BadZipfile:
|
||||
|
|
@ -85,16 +97,42 @@ def check_pt(filename):
|
|||
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
|
||||
with open(filename, "rb") as file:
|
||||
unpickler = RestrictedUnpickler(file)
|
||||
unpickler.extra_handler = extra_handler
|
||||
for i in range(5):
|
||||
unpickler.load()
|
||||
|
||||
|
||||
def load(filename, *args, **kwargs):
|
||||
return load_with_extra(filename, *args, **kwargs)
|
||||
|
||||
|
||||
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
||||
"""
|
||||
this functon is intended to be used by extensions that want to load models with
|
||||
some extra classes in them that the usual unpickler would find suspicious.
|
||||
|
||||
Use the extra_handler argument to specify a function that takes module and field name as text,
|
||||
and returns that field's value:
|
||||
|
||||
```python
|
||||
def extra(module, name):
|
||||
if module == 'collections' and name == 'OrderedDict':
|
||||
return collections.OrderedDict
|
||||
|
||||
return None
|
||||
|
||||
safe.load_with_extra('model.pt', extra_handler=extra)
|
||||
```
|
||||
|
||||
The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
|
||||
definitely unsafe.
|
||||
"""
|
||||
|
||||
from modules import shared
|
||||
|
||||
try:
|
||||
if not shared.cmd_opts.disable_safe_unpickle:
|
||||
check_pt(filename)
|
||||
check_pt(filename, extra_handler)
|
||||
|
||||
except pickle.UnpicklingError:
|
||||
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from typing import Optional
|
|||
from fastapi import FastAPI
|
||||
from gradio import Blocks
|
||||
|
||||
|
||||
def report_exception(c, job):
|
||||
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
|
@ -45,26 +46,33 @@ class CFGDenoiserParams:
|
|||
"""Total number of sampling steps planned"""
|
||||
|
||||
|
||||
class UiTrainTabParams:
|
||||
def __init__(self, txt2img_preview_params):
|
||||
self.txt2img_preview_params = txt2img_preview_params
|
||||
|
||||
|
||||
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
|
||||
callbacks_app_started = []
|
||||
callbacks_model_loaded = []
|
||||
callbacks_ui_tabs = []
|
||||
callbacks_ui_settings = []
|
||||
callbacks_before_image_saved = []
|
||||
callbacks_image_saved = []
|
||||
callbacks_cfg_denoiser = []
|
||||
callback_map = dict(
|
||||
callbacks_app_started=[],
|
||||
callbacks_model_loaded=[],
|
||||
callbacks_ui_tabs=[],
|
||||
callbacks_ui_train_tabs=[],
|
||||
callbacks_ui_settings=[],
|
||||
callbacks_before_image_saved=[],
|
||||
callbacks_image_saved=[],
|
||||
callbacks_cfg_denoiser=[],
|
||||
callbacks_before_component=[],
|
||||
callbacks_after_component=[],
|
||||
)
|
||||
|
||||
|
||||
def clear_callbacks():
|
||||
callbacks_model_loaded.clear()
|
||||
callbacks_ui_tabs.clear()
|
||||
callbacks_ui_settings.clear()
|
||||
callbacks_before_image_saved.clear()
|
||||
callbacks_image_saved.clear()
|
||||
callbacks_cfg_denoiser.clear()
|
||||
for callback_list in callback_map.values():
|
||||
callback_list.clear()
|
||||
|
||||
|
||||
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
|
||||
for c in callbacks_app_started:
|
||||
for c in callback_map['callbacks_app_started']:
|
||||
try:
|
||||
c.callback(demo, app)
|
||||
except Exception:
|
||||
|
|
@ -72,7 +80,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
|
|||
|
||||
|
||||
def model_loaded_callback(sd_model):
|
||||
for c in callbacks_model_loaded:
|
||||
for c in callback_map['callbacks_model_loaded']:
|
||||
try:
|
||||
c.callback(sd_model)
|
||||
except Exception:
|
||||
|
|
@ -81,8 +89,8 @@ def model_loaded_callback(sd_model):
|
|||
|
||||
def ui_tabs_callback():
|
||||
res = []
|
||||
|
||||
for c in callbacks_ui_tabs:
|
||||
|
||||
for c in callback_map['callbacks_ui_tabs']:
|
||||
try:
|
||||
res += c.callback() or []
|
||||
except Exception:
|
||||
|
|
@ -91,8 +99,16 @@ def ui_tabs_callback():
|
|||
return res
|
||||
|
||||
|
||||
def ui_train_tabs_callback(params: UiTrainTabParams):
|
||||
for c in callback_map['callbacks_ui_train_tabs']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'callbacks_ui_train_tabs')
|
||||
|
||||
|
||||
def ui_settings_callback():
|
||||
for c in callbacks_ui_settings:
|
||||
for c in callback_map['callbacks_ui_settings']:
|
||||
try:
|
||||
c.callback()
|
||||
except Exception:
|
||||
|
|
@ -100,7 +116,7 @@ def ui_settings_callback():
|
|||
|
||||
|
||||
def before_image_saved_callback(params: ImageSaveParams):
|
||||
for c in callbacks_before_image_saved:
|
||||
for c in callback_map['callbacks_before_image_saved']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
|
|
@ -108,7 +124,7 @@ def before_image_saved_callback(params: ImageSaveParams):
|
|||
|
||||
|
||||
def image_saved_callback(params: ImageSaveParams):
|
||||
for c in callbacks_image_saved:
|
||||
for c in callback_map['callbacks_image_saved']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
|
|
@ -116,30 +132,62 @@ def image_saved_callback(params: ImageSaveParams):
|
|||
|
||||
|
||||
def cfg_denoiser_callback(params: CFGDenoiserParams):
|
||||
for c in callbacks_cfg_denoiser:
|
||||
for c in callback_map['callbacks_cfg_denoiser']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'cfg_denoiser_callback')
|
||||
|
||||
|
||||
def before_component_callback(component, **kwargs):
|
||||
for c in callback_map['callbacks_before_component']:
|
||||
try:
|
||||
c.callback(component, **kwargs)
|
||||
except Exception:
|
||||
report_exception(c, 'before_component_callback')
|
||||
|
||||
|
||||
def after_component_callback(component, **kwargs):
|
||||
for c in callback_map['callbacks_after_component']:
|
||||
try:
|
||||
c.callback(component, **kwargs)
|
||||
except Exception:
|
||||
report_exception(c, 'after_component_callback')
|
||||
|
||||
|
||||
def add_callback(callbacks, fun):
|
||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
||||
|
||||
callbacks.append(ScriptCallback(filename, fun))
|
||||
|
||||
|
||||
def remove_current_script_callbacks():
|
||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
||||
if filename == 'unknown file':
|
||||
return
|
||||
for callback_list in callback_map.values():
|
||||
for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
|
||||
callback_list.remove(callback_to_remove)
|
||||
|
||||
|
||||
def remove_callbacks_for_function(callback_func):
|
||||
for callback_list in callback_map.values():
|
||||
for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
|
||||
callback_list.remove(callback_to_remove)
|
||||
|
||||
|
||||
def on_app_started(callback):
|
||||
"""register a function to be called when the webui started, the gradio `Block` component and
|
||||
fastapi `FastAPI` object are passed as the arguments"""
|
||||
add_callback(callbacks_app_started, callback)
|
||||
add_callback(callback_map['callbacks_app_started'], callback)
|
||||
|
||||
|
||||
def on_model_loaded(callback):
|
||||
"""register a function to be called when the stable diffusion model is created; the model is
|
||||
passed as an argument"""
|
||||
add_callback(callbacks_model_loaded, callback)
|
||||
add_callback(callback_map['callbacks_model_loaded'], callback)
|
||||
|
||||
|
||||
def on_ui_tabs(callback):
|
||||
|
|
@ -152,13 +200,20 @@ def on_ui_tabs(callback):
|
|||
title is tab text displayed to user in the UI
|
||||
elem_id is HTML id for the tab
|
||||
"""
|
||||
add_callback(callbacks_ui_tabs, callback)
|
||||
add_callback(callback_map['callbacks_ui_tabs'], callback)
|
||||
|
||||
|
||||
def on_ui_train_tabs(callback):
|
||||
"""register a function to be called when the UI is creating new tabs for the train tab.
|
||||
Create your new tabs with gr.Tab.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_ui_train_tabs'], callback)
|
||||
|
||||
|
||||
def on_ui_settings(callback):
|
||||
"""register a function to be called before UI settings are populated; add your settings
|
||||
by using shared.opts.add_option(shared.OptionInfo(...)) """
|
||||
add_callback(callbacks_ui_settings, callback)
|
||||
add_callback(callback_map['callbacks_ui_settings'], callback)
|
||||
|
||||
|
||||
def on_before_image_saved(callback):
|
||||
|
|
@ -166,7 +221,7 @@ def on_before_image_saved(callback):
|
|||
The callback is called with one argument:
|
||||
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
|
||||
"""
|
||||
add_callback(callbacks_before_image_saved, callback)
|
||||
add_callback(callback_map['callbacks_before_image_saved'], callback)
|
||||
|
||||
|
||||
def on_image_saved(callback):
|
||||
|
|
@ -174,7 +229,7 @@ def on_image_saved(callback):
|
|||
The callback is called with one argument:
|
||||
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
|
||||
"""
|
||||
add_callback(callbacks_image_saved, callback)
|
||||
add_callback(callback_map['callbacks_image_saved'], callback)
|
||||
|
||||
|
||||
def on_cfg_denoiser(callback):
|
||||
|
|
@ -182,5 +237,21 @@ def on_cfg_denoiser(callback):
|
|||
The callback is called with one argument:
|
||||
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
|
||||
"""
|
||||
add_callback(callbacks_cfg_denoiser, callback)
|
||||
add_callback(callback_map['callbacks_cfg_denoiser'], callback)
|
||||
|
||||
|
||||
def on_before_component(callback):
|
||||
"""register a function to be called before a component is created.
|
||||
The callback is called with arguments:
|
||||
- component - gradio component that is about to be created.
|
||||
- **kwargs - args to gradio.components.IOComponent.__init__ function
|
||||
|
||||
Use elem_id/label fields of kwargs to figure out which component it is.
|
||||
This can be useful to inject your own components somewhere in the middle of vanilla UI.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_before_component'], callback)
|
||||
|
||||
|
||||
def on_after_component(callback):
|
||||
"""register a function to be called after a component is created. See on_before_component for more."""
|
||||
add_callback(callback_map['callbacks_after_component'], callback)
|
||||
|
|
|
|||
34
modules/script_loading.py
Normal file
34
modules/script_loading.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
def load_module(path):
|
||||
with open(path, "r", encoding="utf8") as file:
|
||||
text = file.read()
|
||||
|
||||
compiled = compile(text, path, 'exec')
|
||||
module = ModuleType(os.path.basename(path))
|
||||
exec(compiled, module.__dict__)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def preload_extensions(extensions_dir, parser):
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
for dirname in sorted(os.listdir(extensions_dir)):
|
||||
preload_script = os.path.join(extensions_dir, dirname, "preload.py")
|
||||
if not os.path.isfile(preload_script):
|
||||
continue
|
||||
|
||||
try:
|
||||
module = load_module(preload_script)
|
||||
if hasattr(module, 'preload'):
|
||||
module.preload(parser)
|
||||
|
||||
except Exception:
|
||||
print(f"Error running preload() for {preload_script}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
|
@ -3,11 +3,10 @@ import sys
|
|||
import traceback
|
||||
from collections import namedtuple
|
||||
|
||||
import modules.ui as ui
|
||||
import gradio as gr
|
||||
|
||||
from modules.processing import StableDiffusionProcessing
|
||||
from modules import shared, paths, script_callbacks, extensions
|
||||
from modules import shared, paths, script_callbacks, extensions, script_loading
|
||||
|
||||
AlwaysVisible = object()
|
||||
|
||||
|
|
@ -18,6 +17,9 @@ class Script:
|
|||
args_to = None
|
||||
alwayson = False
|
||||
|
||||
is_txt2img = False
|
||||
is_img2img = False
|
||||
|
||||
"""A gr.Group component that has all script's UI inside it"""
|
||||
group = None
|
||||
|
||||
|
|
@ -73,6 +75,19 @@ class Script:
|
|||
|
||||
pass
|
||||
|
||||
def process_batch(self, p, *args, **kwargs):
|
||||
"""
|
||||
Same as process(), but called for every batch.
|
||||
|
||||
**kwargs will have those items:
|
||||
- batch_number - index of current batch, from 0 to number of batches-1
|
||||
- prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
|
||||
- seeds - list of seeds for current batch
|
||||
- subseeds - list of subseeds for current batch
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
"""
|
||||
This function is called after processing ends for AlwaysVisible scripts.
|
||||
|
|
@ -81,6 +96,23 @@ class Script:
|
|||
|
||||
pass
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
"""
|
||||
Called before a component is created.
|
||||
Use elem_id/label fields of kwargs to figure out which component it is.
|
||||
This can be useful to inject your own components somewhere in the middle of vanilla UI.
|
||||
You can return created components in the ui() function to add them to the list of arguments for your processing functions
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def after_component(self, component, **kwargs):
|
||||
"""
|
||||
Called after a component is created. Same as above.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def describe(self):
|
||||
"""unused"""
|
||||
return ""
|
||||
|
|
@ -128,7 +160,7 @@ def list_files_with_name(filename):
|
|||
continue
|
||||
|
||||
path = os.path.join(dirpath, filename)
|
||||
if os.path.isfile(filename):
|
||||
if os.path.isfile(path):
|
||||
res.append(path)
|
||||
|
||||
return res
|
||||
|
|
@ -149,13 +181,7 @@ def load_scripts():
|
|||
sys.path = [scriptfile.basedir] + sys.path
|
||||
current_basedir = scriptfile.basedir
|
||||
|
||||
with open(scriptfile.path, "r", encoding="utf8") as file:
|
||||
text = file.read()
|
||||
|
||||
from types import ModuleType
|
||||
compiled = compile(text, scriptfile.path, 'exec')
|
||||
module = ModuleType(scriptfile.filename)
|
||||
exec(compiled, module.__dict__)
|
||||
module = script_loading.load_module(scriptfile.path)
|
||||
|
||||
for key, script_class in module.__dict__.items():
|
||||
if type(script_class) == type and issubclass(script_class, Script):
|
||||
|
|
@ -189,12 +215,18 @@ class ScriptRunner:
|
|||
self.titles = []
|
||||
self.infotext_fields = []
|
||||
|
||||
def setup_ui(self, is_img2img):
|
||||
def initialize_scripts(self, is_img2img):
|
||||
self.scripts.clear()
|
||||
self.alwayson_scripts.clear()
|
||||
self.selectable_scripts.clear()
|
||||
|
||||
for script_class, path, basedir in scripts_data:
|
||||
script = script_class()
|
||||
script.filename = path
|
||||
script.is_txt2img = not is_img2img
|
||||
script.is_img2img = is_img2img
|
||||
|
||||
visibility = script.show(is_img2img)
|
||||
visibility = script.show(script.is_img2img)
|
||||
|
||||
if visibility == AlwaysVisible:
|
||||
self.scripts.append(script)
|
||||
|
|
@ -205,6 +237,7 @@ class ScriptRunner:
|
|||
self.scripts.append(script)
|
||||
self.selectable_scripts.append(script)
|
||||
|
||||
def setup_ui(self):
|
||||
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
||||
|
||||
inputs = [None]
|
||||
|
|
@ -214,7 +247,7 @@ class ScriptRunner:
|
|||
script.args_from = len(inputs)
|
||||
script.args_to = len(inputs)
|
||||
|
||||
controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
|
||||
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
|
||||
|
||||
if controls is None:
|
||||
return
|
||||
|
|
@ -296,6 +329,15 @@ class ScriptRunner:
|
|||
print(f"Error running process: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def process_batch(self, p, **kwargs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.process_batch(p, *script_args, **kwargs)
|
||||
except Exception:
|
||||
print(f"Error running process_batch: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def postprocess(self, p, processed):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
|
|
@ -305,33 +347,44 @@ class ScriptRunner:
|
|||
print(f"Error running postprocess: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
for script in self.scripts:
|
||||
try:
|
||||
script.before_component(component, **kwargs)
|
||||
except Exception:
|
||||
print(f"Error running before_component: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def after_component(self, component, **kwargs):
|
||||
for script in self.scripts:
|
||||
try:
|
||||
script.after_component(component, **kwargs)
|
||||
except Exception:
|
||||
print(f"Error running after_component: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def reload_sources(self, cache):
|
||||
for si, script in list(enumerate(self.scripts)):
|
||||
with open(script.filename, "r", encoding="utf8") as file:
|
||||
args_from = script.args_from
|
||||
args_to = script.args_to
|
||||
filename = script.filename
|
||||
text = file.read()
|
||||
args_from = script.args_from
|
||||
args_to = script.args_to
|
||||
filename = script.filename
|
||||
|
||||
from types import ModuleType
|
||||
module = cache.get(filename, None)
|
||||
if module is None:
|
||||
module = script_loading.load_module(script.filename)
|
||||
cache[filename] = module
|
||||
|
||||
module = cache.get(filename, None)
|
||||
if module is None:
|
||||
compiled = compile(text, filename, 'exec')
|
||||
module = ModuleType(script.filename)
|
||||
exec(compiled, module.__dict__)
|
||||
cache[filename] = module
|
||||
|
||||
for key, script_class in module.__dict__.items():
|
||||
if type(script_class) == type and issubclass(script_class, Script):
|
||||
self.scripts[si] = script_class()
|
||||
self.scripts[si].filename = filename
|
||||
self.scripts[si].args_from = args_from
|
||||
self.scripts[si].args_to = args_to
|
||||
for key, script_class in module.__dict__.items():
|
||||
if type(script_class) == type and issubclass(script_class, Script):
|
||||
self.scripts[si] = script_class()
|
||||
self.scripts[si].filename = filename
|
||||
self.scripts[si].args_from = args_from
|
||||
self.scripts[si].args_to = args_to
|
||||
|
||||
|
||||
scripts_txt2img = ScriptRunner()
|
||||
scripts_img2img = ScriptRunner()
|
||||
scripts_current: ScriptRunner = None
|
||||
|
||||
|
||||
def reload_script_body_only():
|
||||
|
|
@ -348,3 +401,22 @@ def reload_scripts():
|
|||
scripts_txt2img = ScriptRunner()
|
||||
scripts_img2img = ScriptRunner()
|
||||
|
||||
|
||||
def IOComponent_init(self, *args, **kwargs):
|
||||
if scripts_current is not None:
|
||||
scripts_current.before_component(self, **kwargs)
|
||||
|
||||
script_callbacks.before_component_callback(self, **kwargs)
|
||||
|
||||
res = original_IOComponent_init(self, *args, **kwargs)
|
||||
|
||||
script_callbacks.after_component_callback(self, **kwargs)
|
||||
|
||||
if scripts_current is not None:
|
||||
scripts_current.after_component(self, **kwargs)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
original_IOComponent_init = gr.components.IOComponent.__init__
|
||||
gr.components.IOComponent.__init__ = IOComponent_init
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
|||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = devices.mps_contiguous_to(img.unsqueeze(0), device)
|
||||
img = img.unsqueeze(0).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(img)
|
||||
|
|
|
|||
|
|
@ -8,17 +8,32 @@ from torch import einsum
|
|||
from torch.nn.functional import silu
|
||||
|
||||
import modules.textual_inversion.textual_inversion
|
||||
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
|
||||
from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules.shared import opts, device, cmd_opts
|
||||
from modules import sd_hijack_clip, sd_hijack_open_clip
|
||||
|
||||
from modules.sd_hijack_optimizations import invokeAI_mps_available
|
||||
|
||||
import ldm.modules.attention
|
||||
import ldm.modules.diffusionmodules.model
|
||||
import ldm.modules.diffusionmodules.openaimodel
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
import ldm.modules.encoders.modules
|
||||
|
||||
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
||||
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
|
||||
# new memory efficient cross attention blocks do not support hypernets and we already
|
||||
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
|
||||
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
|
||||
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
||||
|
||||
# silence new console spam from SD2
|
||||
ldm.modules.attention.print = lambda *args: None
|
||||
ldm.modules.diffusionmodules.model.print = lambda *args: None
|
||||
|
||||
def apply_optimizations():
|
||||
undo_optimizations()
|
||||
|
|
@ -47,16 +62,15 @@ def apply_optimizations():
|
|||
|
||||
|
||||
def undo_optimizations():
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||
|
||||
|
||||
def get_target_prompt_token_count(token_count):
|
||||
return math.ceil(max(token_count, 1) / 75) * 75
|
||||
|
||||
def fix_checkpoint():
|
||||
ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
|
||||
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
|
||||
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
|
||||
|
||||
class StableDiffusionModelHijack:
|
||||
fixes = None
|
||||
|
|
@ -68,14 +82,18 @@ class StableDiffusionModelHijack:
|
|||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
||||
|
||||
def hijack(self, m):
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
|
||||
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
|
||||
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
|
||||
self.clip = m.cond_stage_model
|
||||
|
||||
apply_optimizations()
|
||||
fix_checkpoint()
|
||||
|
||||
def flatten(el):
|
||||
flattened = [flatten(children) for children in el.children()]
|
||||
|
|
@ -87,15 +105,18 @@ class StableDiffusionModelHijack:
|
|||
self.layers = flatten(m)
|
||||
|
||||
def undo_hijack(self, m):
|
||||
if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
|
||||
if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
|
||||
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
|
||||
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
|
||||
elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
|
||||
m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
self.apply_circular(False)
|
||||
self.layers = None
|
||||
self.circular_enabled = False
|
||||
self.clip = None
|
||||
|
||||
def apply_circular(self, enable):
|
||||
|
|
@ -112,262 +133,9 @@ class StableDiffusionModelHijack:
|
|||
|
||||
def tokenize(self, text):
|
||||
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
||||
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
|
||||
return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__()
|
||||
self.wrapped = wrapped
|
||||
self.hijack: StableDiffusionModelHijack = hijack
|
||||
self.tokenizer = wrapped.tokenizer
|
||||
self.token_mults = {}
|
||||
|
||||
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
|
||||
|
||||
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
||||
for text, ident in tokens_with_parens:
|
||||
mult = 1.0
|
||||
for c in text:
|
||||
if c == '[':
|
||||
mult /= 1.1
|
||||
if c == ']':
|
||||
mult *= 1.1
|
||||
if c == '(':
|
||||
mult *= 1.1
|
||||
if c == ')':
|
||||
mult /= 1.1
|
||||
|
||||
if mult != 1.0:
|
||||
self.token_mults[ident] = mult
|
||||
|
||||
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
||||
id_end = self.wrapped.tokenizer.eos_token_id
|
||||
|
||||
if opts.enable_emphasis:
|
||||
parsed = prompt_parser.parse_prompt_attention(line)
|
||||
else:
|
||||
parsed = [[line, 1.0]]
|
||||
|
||||
tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
|
||||
|
||||
fixes = []
|
||||
remade_tokens = []
|
||||
multipliers = []
|
||||
last_comma = -1
|
||||
|
||||
for tokens, (text, weight) in zip(tokenized, parsed):
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
if token == self.comma_token:
|
||||
last_comma = len(remade_tokens)
|
||||
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
|
||||
last_comma += 1
|
||||
reloc_tokens = remade_tokens[last_comma:]
|
||||
reloc_mults = multipliers[last_comma:]
|
||||
|
||||
remade_tokens = remade_tokens[:last_comma]
|
||||
length = len(remade_tokens)
|
||||
|
||||
rem = int(math.ceil(length / 75)) * 75 - length
|
||||
remade_tokens += [id_end] * rem + reloc_tokens
|
||||
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
|
||||
|
||||
if embedding is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(weight)
|
||||
i += 1
|
||||
else:
|
||||
emb_len = int(embedding.vec.shape[0])
|
||||
iteration = len(remade_tokens) // 75
|
||||
if (len(remade_tokens) + emb_len) // 75 != iteration:
|
||||
rem = (75 * (iteration + 1) - len(remade_tokens))
|
||||
remade_tokens += [id_end] * rem
|
||||
multipliers += [1.0] * rem
|
||||
iteration += 1
|
||||
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [weight] * emb_len
|
||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||
i += embedding_length_in_tokens
|
||||
|
||||
token_count = len(remade_tokens)
|
||||
prompt_target_length = get_target_prompt_token_count(token_count)
|
||||
tokens_to_add = prompt_target_length - len(remade_tokens)
|
||||
|
||||
remade_tokens = remade_tokens + [id_end] * tokens_to_add
|
||||
multipliers = multipliers + [1.0] * tokens_to_add
|
||||
|
||||
return remade_tokens, fixes, multipliers, token_count
|
||||
|
||||
def process_text(self, texts):
|
||||
used_custom_terms = []
|
||||
remade_batch_tokens = []
|
||||
hijack_comments = []
|
||||
hijack_fixes = []
|
||||
token_count = 0
|
||||
|
||||
cache = {}
|
||||
batch_multipliers = []
|
||||
for line in texts:
|
||||
if line in cache:
|
||||
remade_tokens, fixes, multipliers = cache[line]
|
||||
else:
|
||||
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
|
||||
token_count = max(current_token_count, token_count)
|
||||
|
||||
cache[line] = (remade_tokens, fixes, multipliers)
|
||||
|
||||
remade_batch_tokens.append(remade_tokens)
|
||||
hijack_fixes.append(fixes)
|
||||
batch_multipliers.append(multipliers)
|
||||
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
def process_text_old(self, text):
|
||||
id_start = self.wrapped.tokenizer.bos_token_id
|
||||
id_end = self.wrapped.tokenizer.eos_token_id
|
||||
maxlen = self.wrapped.max_length # you get to stay at 77
|
||||
used_custom_terms = []
|
||||
remade_batch_tokens = []
|
||||
overflowing_words = []
|
||||
hijack_comments = []
|
||||
hijack_fixes = []
|
||||
token_count = 0
|
||||
|
||||
cache = {}
|
||||
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
batch_multipliers = []
|
||||
for tokens in batch_tokens:
|
||||
tuple_tokens = tuple(tokens)
|
||||
|
||||
if tuple_tokens in cache:
|
||||
remade_tokens, fixes, multipliers = cache[tuple_tokens]
|
||||
else:
|
||||
fixes = []
|
||||
remade_tokens = []
|
||||
multipliers = []
|
||||
mult = 1.0
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
|
||||
if mult_change is not None:
|
||||
mult *= mult_change
|
||||
i += 1
|
||||
elif embedding is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(mult)
|
||||
i += 1
|
||||
else:
|
||||
emb_len = int(embedding.vec.shape[0])
|
||||
fixes.append((len(remade_tokens), embedding))
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [mult] * emb_len
|
||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||
i += embedding_length_in_tokens
|
||||
|
||||
if len(remade_tokens) > maxlen - 2:
|
||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||
ovf = remade_tokens[maxlen - 2:]
|
||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||
|
||||
token_count = len(remade_tokens)
|
||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
||||
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
||||
|
||||
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
||||
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
||||
|
||||
remade_batch_tokens.append(remade_tokens)
|
||||
hijack_fixes.append(fixes)
|
||||
batch_multipliers.append(multipliers)
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
def forward(self, text):
|
||||
use_old = opts.use_old_emphasis_implementation
|
||||
if use_old:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
||||
else:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
||||
|
||||
self.hijack.comments += hijack_comments
|
||||
|
||||
if len(used_custom_terms) > 0:
|
||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||
|
||||
if use_old:
|
||||
self.hijack.fixes = hijack_fixes
|
||||
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
||||
|
||||
z = None
|
||||
i = 0
|
||||
while max(map(len, remade_batch_tokens)) != 0:
|
||||
rem_tokens = [x[75:] for x in remade_batch_tokens]
|
||||
rem_multipliers = [x[75:] for x in batch_multipliers]
|
||||
|
||||
self.hijack.fixes = []
|
||||
for unfiltered in hijack_fixes:
|
||||
fixes = []
|
||||
for fix in unfiltered:
|
||||
if fix[0] == i:
|
||||
fixes.append(fix[1])
|
||||
self.hijack.fixes.append(fixes)
|
||||
|
||||
tokens = []
|
||||
multipliers = []
|
||||
for j in range(len(remade_batch_tokens)):
|
||||
if len(remade_batch_tokens[j]) > 0:
|
||||
tokens.append(remade_batch_tokens[j][:75])
|
||||
multipliers.append(batch_multipliers[j][:75])
|
||||
else:
|
||||
tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
|
||||
multipliers.append([1.0] * 75)
|
||||
|
||||
z1 = self.process_tokens(tokens, multipliers)
|
||||
z = z1 if z is None else torch.cat((z, z1), axis=-2)
|
||||
|
||||
remade_batch_tokens = rem_tokens
|
||||
batch_multipliers = rem_multipliers
|
||||
i += 1
|
||||
|
||||
return z
|
||||
|
||||
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
||||
if not opts.use_old_emphasis_implementation:
|
||||
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
|
||||
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
|
||||
|
||||
tokens = torch.asarray(remade_batch_tokens).to(device)
|
||||
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||
|
||||
if opts.CLIP_stop_at_last_layers > 1:
|
||||
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
|
||||
z = self.wrapped.transformer.text_model.final_layer_norm(z)
|
||||
else:
|
||||
z = outputs.last_hidden_state
|
||||
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
|
||||
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
|
||||
original_mean = z.mean()
|
||||
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
new_mean = z.mean()
|
||||
z *= original_mean / new_mean
|
||||
|
||||
return z
|
||||
|
||||
|
||||
class EmbeddingsWithFixes(torch.nn.Module):
|
||||
def __init__(self, wrapped, embeddings):
|
||||
|
|
@ -406,3 +174,19 @@ def add_circular_option_to_conv_2d():
|
|||
|
||||
|
||||
model_hijack = StableDiffusionModelHijack()
|
||||
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
"""
|
||||
Fix register buffer bug for Mac OS.
|
||||
"""
|
||||
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != devices.device:
|
||||
attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
|
||||
|
||||
setattr(self, name, attr)
|
||||
|
||||
|
||||
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
|
||||
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
|
||||
|
|
|
|||
10
modules/sd_hijack_checkpoint.py
Normal file
10
modules/sd_hijack_checkpoint.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
def BasicTransformerBlock_forward(self, x, context=None):
|
||||
return checkpoint(self._forward, x, context)
|
||||
|
||||
def AttentionBlock_forward(self, x):
|
||||
return checkpoint(self._forward, x)
|
||||
|
||||
def ResBlock_forward(self, x, emb):
|
||||
return checkpoint(self._forward, x, emb)
|
||||
301
modules/sd_hijack_clip.py
Normal file
301
modules/sd_hijack_clip.py
Normal file
|
|
@ -0,0 +1,301 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from modules import prompt_parser, devices
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
def get_target_prompt_token_count(token_count):
|
||||
return math.ceil(max(token_count, 1) / 75) * 75
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__()
|
||||
self.wrapped = wrapped
|
||||
self.hijack = hijack
|
||||
|
||||
def tokenize(self, texts):
|
||||
raise NotImplementedError
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
raise NotImplementedError
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
||||
if opts.enable_emphasis:
|
||||
parsed = prompt_parser.parse_prompt_attention(line)
|
||||
else:
|
||||
parsed = [[line, 1.0]]
|
||||
|
||||
tokenized = self.tokenize([text for text, _ in parsed])
|
||||
|
||||
fixes = []
|
||||
remade_tokens = []
|
||||
multipliers = []
|
||||
last_comma = -1
|
||||
|
||||
for tokens, (text, weight) in zip(tokenized, parsed):
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
if token == self.comma_token:
|
||||
last_comma = len(remade_tokens)
|
||||
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
|
||||
last_comma += 1
|
||||
reloc_tokens = remade_tokens[last_comma:]
|
||||
reloc_mults = multipliers[last_comma:]
|
||||
|
||||
remade_tokens = remade_tokens[:last_comma]
|
||||
length = len(remade_tokens)
|
||||
|
||||
rem = int(math.ceil(length / 75)) * 75 - length
|
||||
remade_tokens += [self.id_end] * rem + reloc_tokens
|
||||
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
|
||||
|
||||
if embedding is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(weight)
|
||||
i += 1
|
||||
else:
|
||||
emb_len = int(embedding.vec.shape[0])
|
||||
iteration = len(remade_tokens) // 75
|
||||
if (len(remade_tokens) + emb_len) // 75 != iteration:
|
||||
rem = (75 * (iteration + 1) - len(remade_tokens))
|
||||
remade_tokens += [self.id_end] * rem
|
||||
multipliers += [1.0] * rem
|
||||
iteration += 1
|
||||
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [weight] * emb_len
|
||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||
i += embedding_length_in_tokens
|
||||
|
||||
token_count = len(remade_tokens)
|
||||
prompt_target_length = get_target_prompt_token_count(token_count)
|
||||
tokens_to_add = prompt_target_length - len(remade_tokens)
|
||||
|
||||
remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
|
||||
multipliers = multipliers + [1.0] * tokens_to_add
|
||||
|
||||
return remade_tokens, fixes, multipliers, token_count
|
||||
|
||||
def process_text(self, texts):
|
||||
used_custom_terms = []
|
||||
remade_batch_tokens = []
|
||||
hijack_comments = []
|
||||
hijack_fixes = []
|
||||
token_count = 0
|
||||
|
||||
cache = {}
|
||||
batch_multipliers = []
|
||||
for line in texts:
|
||||
if line in cache:
|
||||
remade_tokens, fixes, multipliers = cache[line]
|
||||
else:
|
||||
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
|
||||
token_count = max(current_token_count, token_count)
|
||||
|
||||
cache[line] = (remade_tokens, fixes, multipliers)
|
||||
|
||||
remade_batch_tokens.append(remade_tokens)
|
||||
hijack_fixes.append(fixes)
|
||||
batch_multipliers.append(multipliers)
|
||||
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
def process_text_old(self, texts):
|
||||
id_start = self.id_start
|
||||
id_end = self.id_end
|
||||
maxlen = self.wrapped.max_length # you get to stay at 77
|
||||
used_custom_terms = []
|
||||
remade_batch_tokens = []
|
||||
hijack_comments = []
|
||||
hijack_fixes = []
|
||||
token_count = 0
|
||||
|
||||
cache = {}
|
||||
batch_tokens = self.tokenize(texts)
|
||||
batch_multipliers = []
|
||||
for tokens in batch_tokens:
|
||||
tuple_tokens = tuple(tokens)
|
||||
|
||||
if tuple_tokens in cache:
|
||||
remade_tokens, fixes, multipliers = cache[tuple_tokens]
|
||||
else:
|
||||
fixes = []
|
||||
remade_tokens = []
|
||||
multipliers = []
|
||||
mult = 1.0
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
|
||||
if mult_change is not None:
|
||||
mult *= mult_change
|
||||
i += 1
|
||||
elif embedding is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(mult)
|
||||
i += 1
|
||||
else:
|
||||
emb_len = int(embedding.vec.shape[0])
|
||||
fixes.append((len(remade_tokens), embedding))
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [mult] * emb_len
|
||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||
i += embedding_length_in_tokens
|
||||
|
||||
if len(remade_tokens) > maxlen - 2:
|
||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||
ovf = remade_tokens[maxlen - 2:]
|
||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||
|
||||
token_count = len(remade_tokens)
|
||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
||||
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
||||
|
||||
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
||||
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
||||
|
||||
remade_batch_tokens.append(remade_tokens)
|
||||
hijack_fixes.append(fixes)
|
||||
batch_multipliers.append(multipliers)
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
def forward(self, text):
|
||||
use_old = opts.use_old_emphasis_implementation
|
||||
if use_old:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
||||
else:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
||||
|
||||
self.hijack.comments += hijack_comments
|
||||
|
||||
if len(used_custom_terms) > 0:
|
||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||
|
||||
if use_old:
|
||||
self.hijack.fixes = hijack_fixes
|
||||
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
||||
|
||||
z = None
|
||||
i = 0
|
||||
while max(map(len, remade_batch_tokens)) != 0:
|
||||
rem_tokens = [x[75:] for x in remade_batch_tokens]
|
||||
rem_multipliers = [x[75:] for x in batch_multipliers]
|
||||
|
||||
self.hijack.fixes = []
|
||||
for unfiltered in hijack_fixes:
|
||||
fixes = []
|
||||
for fix in unfiltered:
|
||||
if fix[0] == i:
|
||||
fixes.append(fix[1])
|
||||
self.hijack.fixes.append(fixes)
|
||||
|
||||
tokens = []
|
||||
multipliers = []
|
||||
for j in range(len(remade_batch_tokens)):
|
||||
if len(remade_batch_tokens[j]) > 0:
|
||||
tokens.append(remade_batch_tokens[j][:75])
|
||||
multipliers.append(batch_multipliers[j][:75])
|
||||
else:
|
||||
tokens.append([self.id_end] * 75)
|
||||
multipliers.append([1.0] * 75)
|
||||
|
||||
z1 = self.process_tokens(tokens, multipliers)
|
||||
z = z1 if z is None else torch.cat((z, z1), axis=-2)
|
||||
|
||||
remade_batch_tokens = rem_tokens
|
||||
batch_multipliers = rem_multipliers
|
||||
i += 1
|
||||
|
||||
return z
|
||||
|
||||
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
||||
if not opts.use_old_emphasis_implementation:
|
||||
remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
|
||||
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
|
||||
|
||||
tokens = torch.asarray(remade_batch_tokens).to(devices.device)
|
||||
|
||||
if self.id_end != self.id_pad:
|
||||
for batch_pos in range(len(remade_batch_tokens)):
|
||||
index = remade_batch_tokens[batch_pos].index(self.id_end)
|
||||
tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
|
||||
|
||||
z = self.encode_with_transformers(tokens)
|
||||
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
|
||||
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
|
||||
original_mean = z.mean()
|
||||
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
new_mean = z.mean()
|
||||
z *= original_mean / new_mean
|
||||
|
||||
return z
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__(wrapped, hijack)
|
||||
self.tokenizer = wrapped.tokenizer
|
||||
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
|
||||
|
||||
self.token_mults = {}
|
||||
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
||||
for text, ident in tokens_with_parens:
|
||||
mult = 1.0
|
||||
for c in text:
|
||||
if c == '[':
|
||||
mult /= 1.1
|
||||
if c == ']':
|
||||
mult *= 1.1
|
||||
if c == '(':
|
||||
mult *= 1.1
|
||||
if c == ')':
|
||||
mult /= 1.1
|
||||
|
||||
if mult != 1.0:
|
||||
self.token_mults[ident] = mult
|
||||
|
||||
self.id_start = self.wrapped.tokenizer.bos_token_id
|
||||
self.id_end = self.wrapped.tokenizer.eos_token_id
|
||||
self.id_pad = self.id_end
|
||||
|
||||
def tokenize(self, texts):
|
||||
tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
|
||||
return tokenized
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||
|
||||
if opts.CLIP_stop_at_last_layers > 1:
|
||||
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
|
||||
z = self.wrapped.transformer.text_model.final_layer_norm(z)
|
||||
else:
|
||||
z = outputs.last_hidden_state
|
||||
|
||||
return z
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
embedding_layer = self.wrapped.transformer.text_model.embeddings
|
||||
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
|
||||
|
||||
return embedded
|
||||
|
|
@ -199,8 +199,8 @@ def sample_plms(self,
|
|||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
|
|
@ -249,6 +249,8 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
|||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
if dynamic_threshold is not None:
|
||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
|
|
@ -321,11 +323,16 @@ def should_hijack_inpainting(checkpoint_info):
|
|||
|
||||
|
||||
def do_inpainting_hijack():
|
||||
ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
|
||||
# most of this stuff seems to no longer be needed because it is already included into SD2.0
|
||||
# LatentInpaintDiffusion remains because SD2.0's LatentInpaintDiffusion can't be loaded without specifying a checkpoint
|
||||
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
||||
# this file should be cleaned up later if weverything tuens out to work fine
|
||||
|
||||
# ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
|
||||
ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
|
||||
|
||||
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
|
||||
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
|
||||
# ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
|
||||
# ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
|
||||
|
||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
||||
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
|
||||
# ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
|
||||
|
|
|
|||
37
modules/sd_hijack_open_clip.py
Normal file
37
modules/sd_hijack_open_clip.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import open_clip.tokenizer
|
||||
import torch
|
||||
|
||||
from modules import sd_hijack_clip, devices
|
||||
from modules.shared import opts
|
||||
|
||||
tokenizer = open_clip.tokenizer._tokenizer
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__(wrapped, hijack)
|
||||
|
||||
self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
|
||||
self.id_start = tokenizer.encoder["<start_of_text>"]
|
||||
self.id_end = tokenizer.encoder["<end_of_text>"]
|
||||
self.id_pad = 0
|
||||
|
||||
def tokenize(self, texts):
|
||||
assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
|
||||
|
||||
tokenized = [tokenizer.encode(text) for text in texts]
|
||||
|
||||
return tokenized
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
# set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
|
||||
z = self.wrapped.encode_with_transformer(tokens)
|
||||
|
||||
return z
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
ids = tokenizer.encode(init_text)
|
||||
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
|
||||
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
|
||||
|
||||
return embedded
|
||||
|
|
@ -5,6 +5,7 @@ import gc
|
|||
from collections import namedtuple
|
||||
import torch
|
||||
import re
|
||||
import safetensors.torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
|
@ -45,7 +46,7 @@ def checkpoint_tiles():
|
|||
|
||||
def list_models():
|
||||
checkpoints_list.clear()
|
||||
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
|
||||
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
|
||||
|
||||
def modeltitle(path, shorthash):
|
||||
abspath = os.path.abspath(path)
|
||||
|
|
@ -143,8 +144,8 @@ def transform_checkpoint_dict_key(k):
|
|||
|
||||
|
||||
def get_state_dict_from_checkpoint(pl_sd):
|
||||
if "state_dict" in pl_sd:
|
||||
pl_sd = pl_sd["state_dict"]
|
||||
pl_sd = pl_sd.pop("state_dict", pl_sd)
|
||||
pl_sd.pop("state_dict", None)
|
||||
|
||||
sd = {}
|
||||
for k, v in pl_sd.items():
|
||||
|
|
@ -159,25 +160,41 @@ def get_state_dict_from_checkpoint(pl_sd):
|
|||
return pl_sd
|
||||
|
||||
|
||||
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
||||
_, extension = os.path.splitext(checkpoint_file)
|
||||
if extension.lower() == ".safetensors":
|
||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location)
|
||||
else:
|
||||
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
||||
|
||||
if print_global_state and "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
|
||||
sd = get_state_dict_from_checkpoint(pl_sd)
|
||||
return sd
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info, vae_file="auto"):
|
||||
checkpoint_file = checkpoint_info.filename
|
||||
sd_model_hash = checkpoint_info.hash
|
||||
|
||||
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
|
||||
cache_enabled = shared.opts.sd_checkpoint_cache > 0
|
||||
|
||||
checkpoint_key = checkpoint_info
|
||||
|
||||
if checkpoint_key not in checkpoints_loaded:
|
||||
if cache_enabled and checkpoint_info in checkpoints_loaded:
|
||||
# use checkpoint cache
|
||||
print(f"Loading weights [{sd_model_hash}] from cache")
|
||||
model.load_state_dict(checkpoints_loaded[checkpoint_info])
|
||||
else:
|
||||
# load from file
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
|
||||
|
||||
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
|
||||
sd = get_state_dict_from_checkpoint(pl_sd)
|
||||
del pl_sd
|
||||
sd = read_state_dict(checkpoint_file)
|
||||
model.load_state_dict(sd, strict=False)
|
||||
del sd
|
||||
|
||||
if cache_enabled:
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
model.to(memory_format=torch.channels_last)
|
||||
|
|
@ -197,23 +214,16 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
|
|||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
||||
if shared.opts.sd_checkpoint_cache > 0:
|
||||
# if PR #4035 were to get merged, restore base VAE first before caching
|
||||
checkpoints_loaded[checkpoint_key] = model.state_dict().copy()
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||
checkpoints_loaded.popitem(last=False) # LRU
|
||||
|
||||
else:
|
||||
vae_name = sd_vae.get_filename(vae_file) if vae_file else None
|
||||
vae_message = f" with {vae_name} VAE" if vae_name else ""
|
||||
print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
|
||||
checkpoints_loaded.move_to_end(checkpoint_key)
|
||||
model.load_state_dict(checkpoints_loaded[checkpoint_key])
|
||||
# clean up cache if limit is reached
|
||||
if cache_enabled:
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
|
||||
checkpoints_loaded.popitem(last=False) # LRU
|
||||
|
||||
model.sd_model_hash = sd_model_hash
|
||||
model.sd_model_checkpoint = checkpoint_file
|
||||
model.sd_checkpoint_info = checkpoint_info
|
||||
|
||||
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
|
||||
sd_vae.load_vae(model, vae_file)
|
||||
|
||||
|
||||
|
|
@ -244,6 +254,9 @@ def load_model(checkpoint_info=None):
|
|||
|
||||
do_inpainting_hijack()
|
||||
|
||||
if shared.cmd_opts.no_half:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from collections import namedtuple
|
||||
from collections import namedtuple, deque
|
||||
import numpy as np
|
||||
from math import floor
|
||||
import torch
|
||||
|
|
@ -6,6 +6,7 @@ import tqdm
|
|||
from PIL import Image
|
||||
import inspect
|
||||
import k_diffusion.sampling
|
||||
import torchsde._brownian.brownian_interval
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
from modules import prompt_parser, devices, processing, images
|
||||
|
|
@ -18,17 +19,23 @@ from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
|||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
||||
samplers_k_diffusion = [
|
||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}),
|
||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
|
||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||
('Heun', 'sample_heun', ['k_heun'], {}),
|
||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
|
||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
|
||||
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
|
||||
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
|
||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
|
||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
|
||||
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
|
||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
|
||||
]
|
||||
|
||||
samplers_data_k_diffusion = [
|
||||
|
|
@ -42,16 +49,24 @@ all_samplers = [
|
|||
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
||||
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
||||
]
|
||||
all_samplers_map = {x.name: x for x in all_samplers}
|
||||
|
||||
samplers = []
|
||||
samplers_for_img2img = []
|
||||
samplers_map = {}
|
||||
|
||||
|
||||
def create_sampler_with_index(list_of_configs, index, model):
|
||||
config = list_of_configs[index]
|
||||
def create_sampler(name, model):
|
||||
if name is not None:
|
||||
config = all_samplers_map.get(name, None)
|
||||
else:
|
||||
config = all_samplers[0]
|
||||
|
||||
assert config is not None, f'bad sampler name: {name}'
|
||||
|
||||
sampler = config.constructor(model)
|
||||
sampler.config = config
|
||||
|
||||
|
||||
return sampler
|
||||
|
||||
|
||||
|
|
@ -64,6 +79,12 @@ def set_samplers():
|
|||
samplers = [x for x in all_samplers if x.name not in hidden]
|
||||
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
|
||||
|
||||
samplers_map.clear()
|
||||
for sampler in all_samplers:
|
||||
samplers_map[sampler.name.lower()] = sampler.name
|
||||
for alias in sampler.aliases:
|
||||
samplers_map[alias.lower()] = sampler.name
|
||||
|
||||
|
||||
set_samplers()
|
||||
|
||||
|
|
@ -116,7 +137,8 @@ class InterruptedException(BaseException):
|
|||
class VanillaStableDiffusionSampler:
|
||||
def __init__(self, constructor, sd_model):
|
||||
self.sampler = constructor(sd_model)
|
||||
self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms
|
||||
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
||||
self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
|
|
@ -207,7 +229,6 @@ class VanillaStableDiffusionSampler:
|
|||
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
|
||||
|
||||
def adjust_steps_if_invalid(self, p, num_steps):
|
||||
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
||||
valid_step = 999 / (1000 // num_steps)
|
||||
|
|
@ -216,7 +237,6 @@ class VanillaStableDiffusionSampler:
|
|||
|
||||
return num_steps
|
||||
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = setup_img2img_steps(p, steps)
|
||||
steps = self.adjust_steps_if_invalid(p, steps)
|
||||
|
|
@ -249,9 +269,10 @@ class VanillaStableDiffusionSampler:
|
|||
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
|
||||
if image_conditioning is not None:
|
||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
|
||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
|
||||
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
|
||||
|
|
@ -324,28 +345,55 @@ class CFGDenoiser(torch.nn.Module):
|
|||
|
||||
|
||||
class TorchHijack:
|
||||
def __init__(self, kdiff_sampler):
|
||||
self.kdiff_sampler = kdiff_sampler
|
||||
def __init__(self, sampler_noises):
|
||||
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
|
||||
# implementation.
|
||||
self.sampler_noises = deque(sampler_noises)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'randn_like':
|
||||
return self.kdiff_sampler.randn_like
|
||||
return self.randn_like
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
|
||||
|
||||
def randn_like(self, x):
|
||||
if self.sampler_noises:
|
||||
noise = self.sampler_noises.popleft()
|
||||
if noise.shape == x.shape:
|
||||
return noise
|
||||
|
||||
if x.device.type == 'mps':
|
||||
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
||||
else:
|
||||
return torch.randn_like(x)
|
||||
|
||||
|
||||
# MPS fix for randn in torchsde
|
||||
def torchsde_randn(size, dtype, device, seed):
|
||||
if device.type == 'mps':
|
||||
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
|
||||
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
|
||||
else:
|
||||
generator = torch.Generator(device).manual_seed(int(seed))
|
||||
return torch.randn(size, dtype=dtype, device=device, generator=generator)
|
||||
|
||||
|
||||
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
||||
|
||||
|
||||
class KDiffusionSampler:
|
||||
def __init__(self, funcname, sd_model):
|
||||
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||
|
||||
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||
self.funcname = funcname
|
||||
self.func = getattr(k_diffusion.sampling, self.funcname)
|
||||
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||
self.sampler_noises = None
|
||||
self.sampler_noise_index = 0
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.default_eta = 1.0
|
||||
|
|
@ -378,26 +426,13 @@ class KDiffusionSampler:
|
|||
def number_of_needed_noises(self, p):
|
||||
return p.steps
|
||||
|
||||
def randn_like(self, x):
|
||||
noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None
|
||||
|
||||
if noise is not None and x.shape == noise.shape:
|
||||
res = noise
|
||||
else:
|
||||
res = torch.randn_like(x)
|
||||
|
||||
self.sampler_noise_index += 1
|
||||
return res
|
||||
|
||||
def initialize(self, p):
|
||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap.step = 0
|
||||
self.sampler_noise_index = 0
|
||||
self.eta = p.eta or opts.eta_ancestral
|
||||
|
||||
if self.sampler_noises is not None:
|
||||
k_diffusion.sampling.torch = TorchHijack(self)
|
||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||
|
||||
extra_params_kwargs = {}
|
||||
for param_name in self.extra_params:
|
||||
|
|
|
|||
|
|
@ -83,47 +83,54 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
|
|||
return vae_list
|
||||
|
||||
|
||||
def resolve_vae(checkpoint_file, vae_file="auto"):
|
||||
global first_load, vae_dict, vae_list
|
||||
|
||||
# if vae_file argument is provided, it takes priority, but not saved
|
||||
if vae_file and vae_file not in default_vae_list:
|
||||
if not os.path.isfile(vae_file):
|
||||
vae_file = "auto"
|
||||
print("VAE provided as function argument doesn't exist")
|
||||
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
|
||||
if first_load and shared.cmd_opts.vae_path is not None:
|
||||
if os.path.isfile(shared.cmd_opts.vae_path):
|
||||
vae_file = shared.cmd_opts.vae_path
|
||||
shared.opts.data['sd_vae'] = get_filename(vae_file)
|
||||
else:
|
||||
print("VAE provided as command line argument doesn't exist")
|
||||
# else, we load from settings
|
||||
def get_vae_from_settings(vae_file="auto"):
|
||||
# else, we load from settings, if not set to be default
|
||||
if vae_file == "auto" and shared.opts.sd_vae is not None:
|
||||
# if saved VAE settings isn't recognized, fallback to auto
|
||||
vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
|
||||
# if VAE selected but not found, fallback to auto
|
||||
if vae_file not in default_vae_values and not os.path.isfile(vae_file):
|
||||
vae_file = "auto"
|
||||
print("Selected VAE doesn't exist")
|
||||
print(f"Selected VAE doesn't exist: {vae_file}")
|
||||
return vae_file
|
||||
|
||||
|
||||
def resolve_vae(checkpoint_file=None, vae_file="auto"):
|
||||
global first_load, vae_dict, vae_list
|
||||
|
||||
# if vae_file argument is provided, it takes priority, but not saved
|
||||
if vae_file and vae_file not in default_vae_list:
|
||||
if not os.path.isfile(vae_file):
|
||||
print(f"VAE provided as function argument doesn't exist: {vae_file}")
|
||||
vae_file = "auto"
|
||||
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
|
||||
if first_load and shared.cmd_opts.vae_path is not None:
|
||||
if os.path.isfile(shared.cmd_opts.vae_path):
|
||||
vae_file = shared.cmd_opts.vae_path
|
||||
shared.opts.data['sd_vae'] = get_filename(vae_file)
|
||||
else:
|
||||
print(f"VAE provided as command line argument doesn't exist: {vae_file}")
|
||||
# fallback to selector in settings, if vae selector not set to act as default fallback
|
||||
if not shared.opts.sd_vae_as_default:
|
||||
vae_file = get_vae_from_settings(vae_file)
|
||||
# vae-path cmd arg takes priority for auto
|
||||
if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
|
||||
if os.path.isfile(shared.cmd_opts.vae_path):
|
||||
vae_file = shared.cmd_opts.vae_path
|
||||
print("Using VAE provided as command line argument")
|
||||
print(f"Using VAE provided as command line argument: {vae_file}")
|
||||
# if still not found, try look for ".vae.pt" beside model
|
||||
model_path = os.path.splitext(checkpoint_file)[0]
|
||||
if vae_file == "auto":
|
||||
vae_file_try = model_path + ".vae.pt"
|
||||
if os.path.isfile(vae_file_try):
|
||||
vae_file = vae_file_try
|
||||
print("Using VAE found beside selected model")
|
||||
print(f"Using VAE found similar to selected model: {vae_file}")
|
||||
# if still not found, try look for ".vae.ckpt" beside model
|
||||
if vae_file == "auto":
|
||||
vae_file_try = model_path + ".vae.ckpt"
|
||||
if os.path.isfile(vae_file_try):
|
||||
vae_file = vae_file_try
|
||||
print("Using VAE found beside selected model")
|
||||
print(f"Using VAE found similar to selected model: {vae_file}")
|
||||
# No more fallbacks for auto
|
||||
if vae_file == "auto":
|
||||
vae_file = None
|
||||
|
|
@ -139,6 +146,7 @@ def load_vae(model, vae_file=None):
|
|||
# save_settings = False
|
||||
|
||||
if vae_file:
|
||||
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
|
||||
print(f"Loading VAE weights from: {vae_file}")
|
||||
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
||||
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import datetime
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
import time
|
||||
|
||||
import gradio as gr
|
||||
|
|
@ -12,17 +11,18 @@ import tqdm
|
|||
import modules.artists
|
||||
import modules.interrogate
|
||||
import modules.memmon
|
||||
import modules.sd_models
|
||||
import modules.styles
|
||||
import modules.devices as devices
|
||||
from modules import sd_samplers, sd_models, localization, sd_vae
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules import localization, sd_vae, extensions, script_loading
|
||||
from modules.paths import models_path, script_path, sd_path
|
||||
|
||||
|
||||
demo = None
|
||||
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
|
||||
parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
|
|
@ -44,6 +44,7 @@ parser.add_argument("--precision", type=str, help="evaluate at this precision",
|
|||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
||||
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
||||
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
|
||||
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
|
||||
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
||||
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
||||
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
||||
|
|
@ -55,7 +56,7 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with
|
|||
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
|
||||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
|
||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
|
|
@ -71,6 +72,7 @@ parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui
|
|||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it")
|
||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
|
||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||
|
|
@ -80,13 +82,22 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help=
|
|||
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
||||
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
|
||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
|
||||
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
|
||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
||||
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
|
||||
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
|
||||
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
|
||||
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
|
||||
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
|
||||
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
|
||||
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
|
||||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||
|
||||
script_loading.preload_extensions(extensions.extensions_dir, parser)
|
||||
|
||||
cmd_opts = parser.parse_args()
|
||||
|
||||
restricted_opts = {
|
||||
"samples_filename_pattern",
|
||||
"directories_filename_pattern",
|
||||
|
|
@ -99,7 +110,7 @@ restricted_opts = {
|
|||
"outdir_save",
|
||||
}
|
||||
|
||||
cmd_opts.disable_extension_access = cmd_opts.share or cmd_opts.listen
|
||||
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
|
||||
|
||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
|
||||
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
|
||||
|
|
@ -113,10 +124,12 @@ xformers_available = False
|
|||
config_filename = cmd_opts.ui_settings_file
|
||||
|
||||
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||
hypernetworks = {}
|
||||
loaded_hypernetwork = None
|
||||
|
||||
|
||||
def reload_hypernetworks():
|
||||
from modules.hypernetworks import hypernetwork
|
||||
global hypernetworks
|
||||
|
||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||
|
|
@ -146,6 +159,9 @@ class State:
|
|||
self.interrupted = True
|
||||
|
||||
def nextjob(self):
|
||||
if opts.show_progress_every_n_steps == -1:
|
||||
self.do_set_current_image()
|
||||
|
||||
self.job_no += 1
|
||||
self.sampling_step = 0
|
||||
self.current_image_sampling_step = 0
|
||||
|
|
@ -186,17 +202,22 @@ class State:
|
|||
|
||||
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
|
||||
def set_current_image(self):
|
||||
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
|
||||
self.do_set_current_image()
|
||||
|
||||
def do_set_current_image(self):
|
||||
if not parallel_processing_allowed:
|
||||
return
|
||||
if self.current_latent is None:
|
||||
return
|
||||
|
||||
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and self.current_latent is not None:
|
||||
if opts.show_progress_grid:
|
||||
self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
|
||||
else:
|
||||
self.current_image = sd_samplers.sample_to_image(self.current_latent)
|
||||
|
||||
self.current_image_sampling_step = self.sampling_step
|
||||
import modules.sd_samplers
|
||||
if opts.show_progress_grid:
|
||||
self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
|
||||
else:
|
||||
self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
|
||||
|
||||
self.current_image_sampling_step = self.sampling_step
|
||||
|
||||
state = State()
|
||||
|
||||
|
|
@ -209,8 +230,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
|
|||
|
||||
face_restorers = []
|
||||
|
||||
localization.list_localizations(cmd_opts.localizations_dir)
|
||||
|
||||
|
||||
def realesrgan_models_names():
|
||||
import modules.realesrgan_model
|
||||
|
|
@ -235,6 +254,21 @@ def options_section(section_identifier, options_dict):
|
|||
return options_dict
|
||||
|
||||
|
||||
def list_checkpoint_tiles():
|
||||
import modules.sd_models
|
||||
return modules.sd_models.checkpoint_tiles()
|
||||
|
||||
|
||||
def refresh_checkpoints():
|
||||
import modules.sd_models
|
||||
return modules.sd_models.list_models()
|
||||
|
||||
|
||||
def list_samplers():
|
||||
import modules.sd_samplers
|
||||
return modules.sd_samplers.all_samplers
|
||||
|
||||
|
||||
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
||||
|
||||
options_templates = {}
|
||||
|
|
@ -263,6 +297,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
|||
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
|
||||
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
||||
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
|
||||
|
||||
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
|
||||
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
||||
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
||||
|
|
@ -287,7 +325,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
|
|||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
|
||||
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
|
||||
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
|
||||
|
|
@ -309,6 +347,8 @@ options_templates.update(options_section(('system', "System"), {
|
|||
|
||||
options_templates.update(options_section(('training', "Training"), {
|
||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
||||
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
||||
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
|
||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||
|
|
@ -317,9 +357,10 @@ options_templates.update(options_section(('training', "Training"), {
|
|||
}))
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list),
|
||||
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
|
||||
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
|
|
@ -331,7 +372,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
||||
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
||||
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
||||
}))
|
||||
|
||||
|
|
@ -351,7 +392,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
|||
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
|
||||
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
|
||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
|
|
@ -368,7 +409,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||
}))
|
||||
|
||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
|
||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
|
||||
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||
|
|
@ -398,7 +439,8 @@ class Options:
|
|||
if key in self.data or key in self.data_labels:
|
||||
assert not cmd_opts.freeze_settings, "changing settings is disabled"
|
||||
|
||||
comp_args = opts.data_labels[key].component_args
|
||||
info = opts.data_labels.get(key, None)
|
||||
comp_args = info.component_args if info else None
|
||||
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
|
||||
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
||||
|
||||
|
|
@ -420,6 +462,23 @@ class Options:
|
|||
|
||||
return super(Options, self).__getattribute__(item)
|
||||
|
||||
def set(self, key, value):
|
||||
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
|
||||
|
||||
oldval = self.data.get(key, None)
|
||||
if oldval == value:
|
||||
return False
|
||||
|
||||
try:
|
||||
setattr(self, key, value)
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
if self.data_labels[key].onchange is not None:
|
||||
self.data_labels[key].onchange()
|
||||
|
||||
return True
|
||||
|
||||
def save(self, filename):
|
||||
assert not cmd_opts.freeze_settings, "saving settings is disabled"
|
||||
|
||||
|
|
|
|||
|
|
@ -65,17 +65,6 @@ class StyleDatabase:
|
|||
def apply_negative_styles_to_prompt(self, prompt, styles):
|
||||
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
|
||||
|
||||
def apply_styles(self, p: StableDiffusionProcessing) -> None:
|
||||
if isinstance(p.prompt, list):
|
||||
p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt]
|
||||
else:
|
||||
p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles)
|
||||
|
||||
if isinstance(p.negative_prompt, list):
|
||||
p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt]
|
||||
else:
|
||||
p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
|
||||
|
||||
def save_styles(self, path: str) -> None:
|
||||
# Write to temporary file first, so we don't nuke the file if something goes wrong
|
||||
fd, temp_path = tempfile.mkstemp(".csv")
|
||||
|
|
|
|||
|
|
@ -13,10 +13,6 @@ from modules.swinir_model_arch import SwinIR as net
|
|||
from modules.swinir_model_arch_v2 import Swin2SR as net2
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
|
||||
precision_scope = (
|
||||
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||
)
|
||||
|
||||
|
||||
class UpscalerSwinIR(Upscaler):
|
||||
def __init__(self, dirname):
|
||||
|
|
@ -111,8 +107,8 @@ def upscale(
|
|||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir)
|
||||
with torch.no_grad(), precision_scope("cuda"):
|
||||
img = img.unsqueeze(0).to(devices.device_swinir)
|
||||
with torch.no_grad(), devices.autocast():
|
||||
_, _, h_old, w_old = img.size()
|
||||
h_pad = (h_old // window_size + 1) * window_size - h_old
|
||||
w_pad = (w_old // window_size + 1) * window_size - w_old
|
||||
|
|
|
|||
|
|
@ -276,8 +276,8 @@ def poi_average(pois, settings):
|
|||
weight += poi.weight
|
||||
x += poi.x * poi.weight
|
||||
y += poi.y * poi.weight
|
||||
avg_x = round(x / weight)
|
||||
avg_y = round(y / weight)
|
||||
avg_x = round(weight and x / weight)
|
||||
avg_y = round(weight and y / weight)
|
||||
|
||||
return PointOfInterest(avg_x, avg_y)
|
||||
|
||||
|
|
@ -338,4 +338,4 @@ class Settings:
|
|||
self.face_points_weight = face_points_weight
|
||||
self.annotate_image = annotate_image
|
||||
self.destop_view_image = False
|
||||
self.dnn_model_path = dnn_model_path
|
||||
self.dnn_model_path = dnn_model_path
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import numpy as np
|
|||
import PIL
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms
|
||||
|
||||
import random
|
||||
|
|
@ -11,25 +11,28 @@ import tqdm
|
|||
from modules import devices, shared
|
||||
import re
|
||||
|
||||
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
||||
|
||||
|
||||
class DatasetEntry:
|
||||
def __init__(self, filename=None, latent=None, filename_text=None):
|
||||
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
|
||||
self.filename = filename
|
||||
self.latent = latent
|
||||
self.filename_text = filename_text
|
||||
self.cond = None
|
||||
self.cond_text = None
|
||||
self.latent_dist = latent_dist
|
||||
self.latent_sample = latent_sample
|
||||
self.cond = cond
|
||||
self.cond_text = cond_text
|
||||
self.pixel_values = pixel_values
|
||||
|
||||
|
||||
class PersonalizedBase(Dataset):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
|
||||
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
|
||||
|
||||
|
||||
self.placeholder_token = placeholder_token
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||
|
|
@ -45,11 +48,16 @@ class PersonalizedBase(Dataset):
|
|||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||
assert os.listdir(data_root), "Dataset directory is empty"
|
||||
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
|
||||
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||
|
||||
|
||||
self.shuffle_tags = shuffle_tags
|
||||
self.tag_drop_out = tag_drop_out
|
||||
|
||||
print("Preparing dataset...")
|
||||
for path in tqdm.tqdm(self.image_paths):
|
||||
if shared.state.interrupted:
|
||||
raise Exception("inturrupted")
|
||||
try:
|
||||
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
|
||||
except Exception:
|
||||
|
|
@ -71,53 +79,94 @@ class PersonalizedBase(Dataset):
|
|||
npimage = np.array(image).astype(np.uint8)
|
||||
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
|
||||
torchdata = torch.moveaxis(torchdata, 2, 0)
|
||||
torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
|
||||
latent_sample = None
|
||||
|
||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
|
||||
init_latent = init_latent.to(devices.cpu)
|
||||
with devices.autocast():
|
||||
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
|
||||
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
|
||||
if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
|
||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||
latent_sampling_method = "once"
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
||||
elif latent_sampling_method == "deterministic":
|
||||
# Works only for DiagonalGaussianDistribution
|
||||
latent_dist.std = 0
|
||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
||||
elif latent_sampling_method == "random":
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
|
||||
|
||||
if include_cond:
|
||||
if not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||
entry.cond_text = self.create_text(filename_text)
|
||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||
|
||||
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||
with devices.autocast():
|
||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||
|
||||
self.dataset.append(entry)
|
||||
del torchdata
|
||||
del latent_dist
|
||||
del latent_sample
|
||||
|
||||
assert len(self.dataset) > 0, "No images have been found in the dataset."
|
||||
self.length = len(self.dataset) * repeats // batch_size
|
||||
|
||||
self.dataset_length = len(self.dataset)
|
||||
self.indexes = None
|
||||
self.shuffle()
|
||||
|
||||
def shuffle(self):
|
||||
self.indexes = np.random.permutation(self.dataset_length)
|
||||
self.length = len(self.dataset)
|
||||
assert self.length > 0, "No images have been found in the dataset."
|
||||
self.batch_size = min(batch_size, self.length)
|
||||
self.gradient_step = min(gradient_step, self.length // self.batch_size)
|
||||
self.latent_sampling_method = latent_sampling_method
|
||||
|
||||
def create_text(self, filename_text):
|
||||
text = random.choice(self.lines)
|
||||
tags = filename_text.split(',')
|
||||
if self.tag_drop_out != 0:
|
||||
tags = [t for t in tags if random.random() > self.tag_drop_out]
|
||||
if self.shuffle_tags:
|
||||
random.shuffle(tags)
|
||||
text = text.replace("[filewords]", ','.join(tags))
|
||||
text = text.replace("[name]", self.placeholder_token)
|
||||
text = text.replace("[filewords]", filename_text)
|
||||
return text
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, i):
|
||||
res = []
|
||||
entry = self.dataset[i]
|
||||
if self.tag_drop_out != 0 or self.shuffle_tags:
|
||||
entry.cond_text = self.create_text(entry.filename_text)
|
||||
if self.latent_sampling_method == "random":
|
||||
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
|
||||
return entry
|
||||
|
||||
for j in range(self.batch_size):
|
||||
position = i * self.batch_size + j
|
||||
if position % len(self.indexes) == 0:
|
||||
self.shuffle()
|
||||
class PersonalizedDataLoader(DataLoader):
|
||||
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
|
||||
super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
|
||||
if latent_sampling_method == "random":
|
||||
self.collate_fn = collate_wrapper_random
|
||||
else:
|
||||
self.collate_fn = collate_wrapper
|
||||
|
||||
|
||||
index = self.indexes[position % len(self.indexes)]
|
||||
entry = self.dataset[index]
|
||||
class BatchLoader:
|
||||
def __init__(self, data):
|
||||
self.cond_text = [entry.cond_text for entry in data]
|
||||
self.cond = [entry.cond for entry in data]
|
||||
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
|
||||
#self.emb_index = [entry.emb_index for entry in data]
|
||||
#print(self.latent_sample.device)
|
||||
|
||||
if entry.cond is None:
|
||||
entry.cond_text = self.create_text(entry.filename_text)
|
||||
def pin_memory(self):
|
||||
self.latent_sample = self.latent_sample.pin_memory()
|
||||
return self
|
||||
|
||||
res.append(entry)
|
||||
def collate_wrapper(batch):
|
||||
return BatchLoader(batch)
|
||||
|
||||
return res
|
||||
class BatchLoaderRandom(BatchLoader):
|
||||
def __init__(self, data):
|
||||
super().__init__(data)
|
||||
|
||||
def pin_memory(self):
|
||||
return self
|
||||
|
||||
def collate_wrapper_random(batch):
|
||||
return BatchLoaderRandom(batch)
|
||||
|
|
@ -6,12 +6,10 @@ import sys
|
|||
import tqdm
|
||||
import time
|
||||
|
||||
from modules import shared, images
|
||||
from modules import shared, images, deepbooru
|
||||
from modules.paths import models_path
|
||||
from modules.shared import opts, cmd_opts
|
||||
from modules.textual_inversion import autocrop
|
||||
if cmd_opts.deepdanbooru:
|
||||
import modules.deepbooru as deepbooru
|
||||
|
||||
|
||||
def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
|
||||
|
|
@ -20,9 +18,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
|
|||
shared.interrogator.load()
|
||||
|
||||
if process_caption_deepbooru:
|
||||
db_opts = deepbooru.create_deepbooru_opts()
|
||||
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
|
||||
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
|
||||
deepbooru.model.start()
|
||||
|
||||
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
|
||||
|
||||
|
|
@ -32,9 +28,87 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
|
|||
shared.interrogator.send_blip_to_ram()
|
||||
|
||||
if process_caption_deepbooru:
|
||||
deepbooru.release_process()
|
||||
deepbooru.model.stop()
|
||||
|
||||
|
||||
def listfiles(dirname):
|
||||
return os.listdir(dirname)
|
||||
|
||||
|
||||
class PreprocessParams:
|
||||
src = None
|
||||
dstdir = None
|
||||
subindex = 0
|
||||
flip = False
|
||||
process_caption = False
|
||||
process_caption_deepbooru = False
|
||||
preprocess_txt_action = None
|
||||
|
||||
|
||||
def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None):
|
||||
caption = ""
|
||||
|
||||
if params.process_caption:
|
||||
caption += shared.interrogator.generate_caption(image)
|
||||
|
||||
if params.process_caption_deepbooru:
|
||||
if len(caption) > 0:
|
||||
caption += ", "
|
||||
caption += deepbooru.model.tag_multi(image)
|
||||
|
||||
filename_part = params.src
|
||||
filename_part = os.path.splitext(filename_part)[0]
|
||||
filename_part = os.path.basename(filename_part)
|
||||
|
||||
basename = f"{index:05}-{params.subindex}-{filename_part}"
|
||||
image.save(os.path.join(params.dstdir, f"{basename}.png"))
|
||||
|
||||
if params.preprocess_txt_action == 'prepend' and existing_caption:
|
||||
caption = existing_caption + ' ' + caption
|
||||
elif params.preprocess_txt_action == 'append' and existing_caption:
|
||||
caption = caption + ' ' + existing_caption
|
||||
elif params.preprocess_txt_action == 'copy' and existing_caption:
|
||||
caption = existing_caption
|
||||
|
||||
caption = caption.strip()
|
||||
|
||||
if len(caption) > 0:
|
||||
with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
|
||||
file.write(caption)
|
||||
|
||||
params.subindex += 1
|
||||
|
||||
|
||||
def save_pic(image, index, params, existing_caption=None):
|
||||
save_pic_with_caption(image, index, params, existing_caption=existing_caption)
|
||||
|
||||
if params.flip:
|
||||
save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption)
|
||||
|
||||
|
||||
def split_pic(image, inverse_xy, width, height, overlap_ratio):
|
||||
if inverse_xy:
|
||||
from_w, from_h = image.height, image.width
|
||||
to_w, to_h = height, width
|
||||
else:
|
||||
from_w, from_h = image.width, image.height
|
||||
to_w, to_h = width, height
|
||||
h = from_h * to_w // from_w
|
||||
if inverse_xy:
|
||||
image = image.resize((h, to_w))
|
||||
else:
|
||||
image = image.resize((to_w, h))
|
||||
|
||||
split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
|
||||
y_step = (h - to_h) / (split_count - 1)
|
||||
for i in range(split_count):
|
||||
y = int(y_step * i)
|
||||
if inverse_xy:
|
||||
splitted = image.crop((y, 0, y + to_h, to_w))
|
||||
else:
|
||||
splitted = image.crop((0, y, to_w, y + to_h))
|
||||
yield splitted
|
||||
|
||||
|
||||
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
|
||||
width = process_width
|
||||
|
|
@ -48,82 +122,28 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
|||
|
||||
os.makedirs(dst, exist_ok=True)
|
||||
|
||||
files = os.listdir(src)
|
||||
files = listfiles(src)
|
||||
|
||||
shared.state.textinfo = "Preprocessing..."
|
||||
shared.state.job_count = len(files)
|
||||
|
||||
def save_pic_with_caption(image, index, existing_caption=None):
|
||||
caption = ""
|
||||
|
||||
if process_caption:
|
||||
caption += shared.interrogator.generate_caption(image)
|
||||
|
||||
if process_caption_deepbooru:
|
||||
if len(caption) > 0:
|
||||
caption += ", "
|
||||
caption += deepbooru.get_tags_from_process(image)
|
||||
|
||||
filename_part = filename
|
||||
filename_part = os.path.splitext(filename_part)[0]
|
||||
filename_part = os.path.basename(filename_part)
|
||||
|
||||
basename = f"{index:05}-{subindex[0]}-{filename_part}"
|
||||
image.save(os.path.join(dst, f"{basename}.png"))
|
||||
|
||||
if preprocess_txt_action == 'prepend' and existing_caption:
|
||||
caption = existing_caption + ' ' + caption
|
||||
elif preprocess_txt_action == 'append' and existing_caption:
|
||||
caption = caption + ' ' + existing_caption
|
||||
elif preprocess_txt_action == 'copy' and existing_caption:
|
||||
caption = existing_caption
|
||||
|
||||
caption = caption.strip()
|
||||
|
||||
if len(caption) > 0:
|
||||
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
|
||||
file.write(caption)
|
||||
|
||||
subindex[0] += 1
|
||||
|
||||
def save_pic(image, index, existing_caption=None):
|
||||
save_pic_with_caption(image, index, existing_caption=existing_caption)
|
||||
|
||||
if process_flip:
|
||||
save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
|
||||
|
||||
def split_pic(image, inverse_xy):
|
||||
if inverse_xy:
|
||||
from_w, from_h = image.height, image.width
|
||||
to_w, to_h = height, width
|
||||
else:
|
||||
from_w, from_h = image.width, image.height
|
||||
to_w, to_h = width, height
|
||||
h = from_h * to_w // from_w
|
||||
if inverse_xy:
|
||||
image = image.resize((h, to_w))
|
||||
else:
|
||||
image = image.resize((to_w, h))
|
||||
|
||||
split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
|
||||
y_step = (h - to_h) / (split_count - 1)
|
||||
for i in range(split_count):
|
||||
y = int(y_step * i)
|
||||
if inverse_xy:
|
||||
splitted = image.crop((y, 0, y + to_h, to_w))
|
||||
else:
|
||||
splitted = image.crop((0, y, to_w, y + to_h))
|
||||
yield splitted
|
||||
|
||||
params = PreprocessParams()
|
||||
params.dstdir = dst
|
||||
params.flip = process_flip
|
||||
params.process_caption = process_caption
|
||||
params.process_caption_deepbooru = process_caption_deepbooru
|
||||
params.preprocess_txt_action = preprocess_txt_action
|
||||
|
||||
for index, imagefile in enumerate(tqdm.tqdm(files)):
|
||||
subindex = [0]
|
||||
params.subindex = 0
|
||||
filename = os.path.join(src, imagefile)
|
||||
try:
|
||||
img = Image.open(filename).convert("RGB")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
params.src = filename
|
||||
|
||||
existing_caption = None
|
||||
existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
|
||||
if os.path.exists(existing_caption_filename):
|
||||
|
|
@ -143,8 +163,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
|||
process_default_resize = True
|
||||
|
||||
if process_split and ratio < 1.0 and ratio <= split_threshold:
|
||||
for splitted in split_pic(img, inverse_xy):
|
||||
save_pic(splitted, index, existing_caption=existing_caption)
|
||||
for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
|
||||
save_pic(splitted, index, params, existing_caption=existing_caption)
|
||||
process_default_resize = False
|
||||
|
||||
if process_focal_crop and img.height != img.width:
|
||||
|
|
@ -165,11 +185,11 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
|||
dnn_model_path = dnn_model_path,
|
||||
)
|
||||
for focal in autocrop.crop_image(img, autocrop_settings):
|
||||
save_pic(focal, index, existing_caption=existing_caption)
|
||||
save_pic(focal, index, params, existing_caption=existing_caption)
|
||||
process_default_resize = False
|
||||
|
||||
if process_default_resize:
|
||||
img = images.resize_image(1, img, width, height)
|
||||
save_pic(img, index, existing_caption=existing_caption)
|
||||
save_pic(img, index, params, existing_caption=existing_caption)
|
||||
|
||||
shared.state.nextjob()
|
||||
shared.state.nextjob()
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import csv
|
|||
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models, images
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
|
||||
import modules.textual_inversion.dataset
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
|
|
@ -64,7 +64,8 @@ class EmbeddingDatabase:
|
|||
|
||||
self.word_embeddings[embedding.name] = embedding
|
||||
|
||||
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
|
||||
# TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
|
||||
ids = model.cond_stage_model.tokenize([embedding.name])[0]
|
||||
|
||||
first_id = ids[0]
|
||||
if first_id not in self.ids_lookup:
|
||||
|
|
@ -155,13 +156,11 @@ class EmbeddingDatabase:
|
|||
|
||||
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
|
||||
|
||||
with devices.autocast():
|
||||
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
|
||||
|
||||
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
|
||||
embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
|
||||
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
|
||||
|
||||
for i in range(num_vectors_per_token):
|
||||
|
|
@ -184,7 +183,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
|||
if shared.opts.training_write_csv_every == 0:
|
||||
return
|
||||
|
||||
if (step + 1) % shared.opts.training_write_csv_every != 0:
|
||||
if step % shared.opts.training_write_csv_every != 0:
|
||||
return
|
||||
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
|
||||
|
||||
|
|
@ -194,21 +193,23 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
|||
if write_csv_header:
|
||||
csv_writer.writeheader()
|
||||
|
||||
epoch = step // epoch_len
|
||||
epoch_step = step % epoch_len
|
||||
epoch = (step - 1) // epoch_len
|
||||
epoch_step = (step - 1) % epoch_len
|
||||
|
||||
csv_writer.writerow({
|
||||
"step": step + 1,
|
||||
"step": step,
|
||||
"epoch": epoch,
|
||||
"epoch_step": epoch_step + 1,
|
||||
"epoch_step": epoch_step,
|
||||
**values,
|
||||
})
|
||||
|
||||
def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
||||
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
||||
assert model_name, f"{name} not selected"
|
||||
assert learn_rate, "Learning rate is empty or 0"
|
||||
assert isinstance(batch_size, int), "Batch size must be integer"
|
||||
assert batch_size > 0, "Batch size must be positive"
|
||||
assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
|
||||
assert gradient_step > 0, "Gradient accumulation step must be positive"
|
||||
assert data_root, "Dataset directory is empty"
|
||||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||
assert os.listdir(data_root), "Dataset directory is empty"
|
||||
|
|
@ -224,10 +225,10 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat
|
|||
if save_model_every or create_image_every:
|
||||
assert log_directory, "Log directory is empty"
|
||||
|
||||
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
save_embedding_every = save_embedding_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
|
||||
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
|
||||
|
||||
shared.state.textinfo = "Initializing textual inversion training..."
|
||||
shared.state.job_count = steps
|
||||
|
|
@ -255,166 +256,203 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||
else:
|
||||
images_embeds_dir = None
|
||||
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
|
||||
hijack = sd_hijack.model_hijack
|
||||
|
||||
embedding = hijack.embedding_db.word_embeddings[embedding_name]
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
ititial_step = embedding.step or 0
|
||||
if ititial_step >= steps:
|
||||
initial_step = embedding.step or 0
|
||||
if initial_step >= steps:
|
||||
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
|
||||
return embedding, filename
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
|
||||
|
||||
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
||||
|
||||
pin_memory = shared.opts.pin_memory
|
||||
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
|
||||
|
||||
latent_sampling_method = ds.latent_sampling_method
|
||||
|
||||
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
|
||||
|
||||
if unload:
|
||||
shared.parallel_processing_allowed = False
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
embedding.vec.requires_grad = True
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
|
||||
scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
losses = torch.zeros((32,))
|
||||
batch_size = ds.batch_size
|
||||
gradient_step = ds.gradient_step
|
||||
# n steps = batch_size * gradient_step * n image processed
|
||||
steps_per_epoch = len(ds) // batch_size // gradient_step
|
||||
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
|
||||
loss_step = 0
|
||||
_loss_step = 0 #internal
|
||||
|
||||
|
||||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
embedding_yet_to_be_embedded = False
|
||||
|
||||
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||
try:
|
||||
for i in range((steps-initial_step) * gradient_step):
|
||||
if scheduler.finished:
|
||||
break
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
for j, batch in enumerate(dl):
|
||||
# works as a drop_last=True for gradient accumulation
|
||||
if j == max_steps_per_epoch:
|
||||
break
|
||||
scheduler.apply(optimizer, embedding.step)
|
||||
if scheduler.finished:
|
||||
break
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||
for i, entries in pbar:
|
||||
embedding.step = i + ititial_step
|
||||
with devices.autocast():
|
||||
# c = stack_conds(batch.cond).to(devices.device)
|
||||
# mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
|
||||
# print(mask)
|
||||
# c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
|
||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||
c = shared.sd_model.cond_stage_model(batch.cond_text)
|
||||
loss = shared.sd_model(x, c)[0] / gradient_step
|
||||
del x
|
||||
|
||||
_loss_step += loss.item()
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# go back until we reach gradient accumulation steps
|
||||
if (j + 1) % gradient_step != 0:
|
||||
continue
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
embedding.step += 1
|
||||
pbar.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
loss_step = _loss_step
|
||||
_loss_step = 0
|
||||
|
||||
scheduler.apply(optimizer, embedding.step)
|
||||
if scheduler.finished:
|
||||
break
|
||||
steps_done = embedding.step + 1
|
||||
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
epoch_num = embedding.step // steps_per_epoch
|
||||
epoch_step = embedding.step % steps_per_epoch
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
c = cond_model([entry.cond_text for entry in entries])
|
||||
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
||||
loss = shared.sd_model(x, c)[0]
|
||||
del x
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
|
||||
if embedding_dir is not None and steps_done % save_embedding_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
embedding_name_every = f'{embedding_name}-{steps_done}'
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
||||
#if shared.opts.save_optimizer_state:
|
||||
#embedding.optimizer_state_dict = optimizer.state_dict()
|
||||
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
||||
embedding_yet_to_be_embedded = True
|
||||
|
||||
losses[embedding.step % losses.shape[0]] = loss.item()
|
||||
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
|
||||
"loss": f"{loss_step:.7f}",
|
||||
"learn_rate": scheduler.learn_rate
|
||||
})
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
if images_dir is not None and steps_done % create_image_every == 0:
|
||||
forced_filename = f'{embedding_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
|
||||
steps_done = embedding.step + 1
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
|
||||
epoch_num = embedding.step // len(ds)
|
||||
epoch_step = embedding.step % len(ds)
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
do_not_save_grid=True,
|
||||
do_not_save_samples=True,
|
||||
do_not_reload_embeddings=True,
|
||||
)
|
||||
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
|
||||
if preview_from_txt2img:
|
||||
p.prompt = preview_prompt
|
||||
p.negative_prompt = preview_negative_prompt
|
||||
p.steps = preview_steps
|
||||
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
||||
p.cfg_scale = preview_cfg_scale
|
||||
p.seed = preview_seed
|
||||
p.width = preview_width
|
||||
p.height = preview_height
|
||||
else:
|
||||
p.prompt = batch.cond_text[0]
|
||||
p.steps = 20
|
||||
p.width = training_width
|
||||
p.height = training_height
|
||||
|
||||
if embedding_dir is not None and steps_done % save_embedding_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
embedding_name_every = f'{embedding_name}-{steps_done}'
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
||||
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
||||
embedding_yet_to_be_embedded = True
|
||||
preview_text = p.prompt
|
||||
|
||||
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
|
||||
"loss": f"{losses.mean():.7f}",
|
||||
"learn_rate": scheduler.learn_rate
|
||||
})
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0] if len(processed.images) > 0 else None
|
||||
|
||||
if images_dir is not None and steps_done % create_image_every == 0:
|
||||
forced_filename = f'{embedding_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
if unload:
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
if image is not None:
|
||||
shared.state.current_image = image
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
do_not_save_grid=True,
|
||||
do_not_save_samples=True,
|
||||
do_not_reload_embeddings=True,
|
||||
)
|
||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
||||
|
||||
if preview_from_txt2img:
|
||||
p.prompt = preview_prompt
|
||||
p.negative_prompt = preview_negative_prompt
|
||||
p.steps = preview_steps
|
||||
p.sampler_index = preview_sampler_index
|
||||
p.cfg_scale = preview_cfg_scale
|
||||
p.seed = preview_seed
|
||||
p.width = preview_width
|
||||
p.height = preview_height
|
||||
else:
|
||||
p.prompt = entries[0].cond_text
|
||||
p.steps = 20
|
||||
p.width = training_width
|
||||
p.height = training_height
|
||||
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
|
||||
|
||||
preview_text = p.prompt
|
||||
info = PngImagePlugin.PngInfo()
|
||||
data = torch.load(last_saved_file)
|
||||
info.add_text("sd-ti-embedding", embedding_to_b64(data))
|
||||
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0]
|
||||
title = "<{}>".format(data.get('name', '???'))
|
||||
|
||||
if unload:
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
try:
|
||||
vectorSize = list(data['string_to_param'].values())[0].shape[0]
|
||||
except Exception as e:
|
||||
vectorSize = '?'
|
||||
|
||||
shared.state.current_image = image
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
footer_left = checkpoint.model_name
|
||||
footer_mid = '[{}]'.format(checkpoint.hash)
|
||||
footer_right = '{}v {}s'.format(vectorSize, steps_done)
|
||||
|
||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
||||
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
||||
captioned_image = insert_image_data_embed(captioned_image, data)
|
||||
|
||||
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
|
||||
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
||||
embedding_yet_to_be_embedded = False
|
||||
|
||||
info = PngImagePlugin.PngInfo()
|
||||
data = torch.load(last_saved_file)
|
||||
info.add_text("sd-ti-embedding", embedding_to_b64(data))
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
title = "<{}>".format(data.get('name', '???'))
|
||||
shared.state.job_no = embedding.step
|
||||
|
||||
try:
|
||||
vectorSize = list(data['string_to_param'].values())[0].shape[0]
|
||||
except Exception as e:
|
||||
vectorSize = '?'
|
||||
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
footer_left = checkpoint.model_name
|
||||
footer_mid = '[{}]'.format(checkpoint.hash)
|
||||
footer_right = '{}v {}s'.format(vectorSize, steps_done)
|
||||
|
||||
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
||||
captioned_image = insert_image_data_embed(captioned_image, data)
|
||||
|
||||
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
||||
embedding_yet_to_be_embedded = False
|
||||
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
shared.state.job_no = embedding.step
|
||||
|
||||
shared.state.textinfo = f"""
|
||||
shared.state.textinfo = f"""
|
||||
<p>
|
||||
Loss: {losses.mean():.7f}<br/>
|
||||
Step: {embedding.step}<br/>
|
||||
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
||||
Loss: {loss_step:.7f}<br/>
|
||||
Step: {steps_done}<br/>
|
||||
Last prompt: {html.escape(batch.cond_text[0])}<br/>
|
||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
"""
|
||||
|
||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||
except Exception:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
pass
|
||||
finally:
|
||||
pbar.leave = False
|
||||
pbar.close()
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||
|
||||
return embedding, filename
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ def create_embedding(name, initialization_text, nvpt, overwrite_old):
|
|||
def preprocess(*args):
|
||||
modules.textual_inversion.preprocess.preprocess(*args)
|
||||
|
||||
return "Preprocessing finished.", ""
|
||||
return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""
|
||||
|
||||
|
||||
def train_embedding(*args):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import modules.scripts
|
||||
from modules import sd_samplers
|
||||
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
|
||||
StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, cmd_opts
|
||||
|
|
@ -21,7 +22,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
|
|||
seed_resize_from_h=seed_resize_from_h,
|
||||
seed_resize_from_w=seed_resize_from_w,
|
||||
seed_enable_extras=seed_enable_extras,
|
||||
sampler_index=sampler_index,
|
||||
sampler_name=sd_samplers.samplers[sampler_index].name,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
steps=steps,
|
||||
|
|
|
|||
265
modules/ui.py
265
modules/ui.py
|
|
@ -17,16 +17,13 @@ import gradio.routes
|
|||
import gradio.utils
|
||||
import numpy as np
|
||||
from PIL import Image, PngImagePlugin
|
||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||
|
||||
|
||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions
|
||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
|
||||
from modules.paths import script_path
|
||||
|
||||
from modules.shared import opts, cmd_opts, restricted_opts
|
||||
|
||||
if cmd_opts.deepdanbooru:
|
||||
from modules.deepbooru import get_deepbooru_tags
|
||||
|
||||
import modules.codeformer_model
|
||||
import modules.generation_parameters_copypaste as parameters_copypaste
|
||||
import modules.gfpgan_model
|
||||
|
|
@ -69,8 +66,11 @@ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
|
|||
css_hide_progressbar = """
|
||||
.wrap .m-12 svg { display:none!important; }
|
||||
.wrap .m-12::before { content:"Loading..." }
|
||||
.wrap .z-20 svg { display:none!important; }
|
||||
.wrap .z-20::before { content:"Loading..." }
|
||||
.progress-bar { display:none!important; }
|
||||
.meta-text { display:none!important; }
|
||||
.meta-text-center { display:none!important; }
|
||||
"""
|
||||
|
||||
# Using constants for these since the variation selector isn't visible.
|
||||
|
|
@ -142,7 +142,7 @@ def save_files(js_data, images, do_make_zip, index):
|
|||
filenames.append(os.path.basename(txt_fullfn))
|
||||
fullfns.append(txt_fullfn)
|
||||
|
||||
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
|
||||
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
|
||||
|
||||
# Make Zip
|
||||
if do_make_zip:
|
||||
|
|
@ -157,81 +157,7 @@ def save_files(js_data, images, do_make_zip, index):
|
|||
|
||||
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
||||
|
||||
def save_pil_to_file(pil_image, dir=None):
|
||||
use_metadata = False
|
||||
metadata = PngImagePlugin.PngInfo()
|
||||
for key, value in pil_image.info.items():
|
||||
if isinstance(key, str) and isinstance(value, str):
|
||||
metadata.add_text(key, value)
|
||||
use_metadata = True
|
||||
|
||||
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
|
||||
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
|
||||
return file_obj
|
||||
|
||||
|
||||
# override save to file function so that it also writes PNG info
|
||||
gr.processing_utils.save_pil_to_file = save_pil_to_file
|
||||
|
||||
|
||||
def wrap_gradio_call(func, extra_outputs=None):
|
||||
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
||||
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
|
||||
if run_memmon:
|
||||
shared.mem_mon.monitor()
|
||||
t = time.perf_counter()
|
||||
|
||||
try:
|
||||
res = list(func(*args, **kwargs))
|
||||
except Exception as e:
|
||||
# When printing out our debug argument list, do not print out more than a MB of text
|
||||
max_debug_str_len = 131072 # (1024*1024)/8
|
||||
|
||||
print("Error completing request", file=sys.stderr)
|
||||
argStr = f"Arguments: {str(args)} {str(kwargs)}"
|
||||
print(argStr[:max_debug_str_len], file=sys.stderr)
|
||||
if len(argStr) > max_debug_str_len:
|
||||
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
|
||||
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
shared.state.job = ""
|
||||
shared.state.job_count = 0
|
||||
|
||||
if extra_outputs_array is None:
|
||||
extra_outputs_array = [None, '']
|
||||
|
||||
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
|
||||
|
||||
elapsed = time.perf_counter() - t
|
||||
elapsed_m = int(elapsed // 60)
|
||||
elapsed_s = elapsed % 60
|
||||
elapsed_text = f"{elapsed_s:.2f}s"
|
||||
if (elapsed_m > 0):
|
||||
elapsed_text = f"{elapsed_m}m "+elapsed_text
|
||||
|
||||
if run_memmon:
|
||||
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
|
||||
active_peak = mem_stats['active_peak']
|
||||
reserved_peak = mem_stats['reserved_peak']
|
||||
sys_peak = mem_stats['system_peak']
|
||||
sys_total = mem_stats['total']
|
||||
sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
|
||||
|
||||
vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
|
||||
else:
|
||||
vram_html = ''
|
||||
|
||||
# last item is always HTML
|
||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
|
||||
|
||||
shared.state.skipped = False
|
||||
shared.state.interrupted = False
|
||||
shared.state.job_count = 0
|
||||
|
||||
return tuple(res)
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def calc_time_left(progress, threshold, label, force_display):
|
||||
|
|
@ -276,7 +202,7 @@ def check_progress_call(id_part):
|
|||
image = gr_show(False)
|
||||
preview_visibility = gr_show(False)
|
||||
|
||||
if opts.show_progress_every_n_steps > 0:
|
||||
if opts.show_progress_every_n_steps != 0:
|
||||
shared.state.set_current_image()
|
||||
image = shared.state.current_image
|
||||
|
||||
|
|
@ -346,7 +272,7 @@ def interrogate(image):
|
|||
|
||||
|
||||
def interrogate_deepbooru(image):
|
||||
prompt = get_deepbooru_tags(image)
|
||||
prompt = deepbooru.model.tag(image)
|
||||
return gr_show(True) if prompt is None else prompt
|
||||
|
||||
|
||||
|
|
@ -475,9 +401,7 @@ def create_toprow(is_img2img):
|
|||
if is_img2img:
|
||||
with gr.Column(scale=1, elem_id="interrogate_col"):
|
||||
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
|
||||
|
||||
if cmd_opts.deepdanbooru:
|
||||
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
||||
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
||||
|
||||
with gr.Column(scale=1):
|
||||
with gr.Row():
|
||||
|
|
@ -563,6 +487,19 @@ def apply_setting(key, value):
|
|||
return value
|
||||
|
||||
|
||||
def update_generation_info(args):
|
||||
generation_info, html_info, img_index = args
|
||||
try:
|
||||
generation_info = json.loads(generation_info)
|
||||
if img_index < 0 or img_index >= len(generation_info["infotexts"]):
|
||||
return html_info
|
||||
return plaintext_to_html(generation_info["infotexts"][img_index])
|
||||
except Exception:
|
||||
pass
|
||||
# if the json parse or anything else fails, just return the old html_info
|
||||
return html_info
|
||||
|
||||
|
||||
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
||||
def refresh():
|
||||
refresh_method()
|
||||
|
|
@ -635,6 +572,15 @@ Requested path was: {f}
|
|||
with gr.Group():
|
||||
html_info = gr.HTML()
|
||||
generation_info = gr.Textbox(visible=False)
|
||||
if tabname == 'txt2img' or tabname == 'img2img':
|
||||
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
|
||||
generation_info_button.click(
|
||||
fn=update_generation_info,
|
||||
_js="(x, y) => [x, y, selected_gallery_index()]",
|
||||
inputs=[generation_info, html_info],
|
||||
outputs=[html_info],
|
||||
preprocess=False
|
||||
)
|
||||
|
||||
save.click(
|
||||
fn=wrap_gradio_call(save_files),
|
||||
|
|
@ -659,7 +605,7 @@ Requested path was: {f}
|
|||
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info
|
||||
|
||||
|
||||
def create_ui(wrap_gradio_gpu_call):
|
||||
def create_ui():
|
||||
import modules.img2img
|
||||
import modules.txt2img
|
||||
|
||||
|
|
@ -667,6 +613,9 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
|
||||
parameters_copypaste.reset()
|
||||
|
||||
modules.scripts.scripts_current = modules.scripts.scripts_txt2img
|
||||
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
|
||||
dummy_component = gr.Label(visible=False)
|
||||
|
|
@ -709,7 +658,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
|
||||
|
||||
with gr.Group():
|
||||
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
|
||||
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
|
||||
|
||||
txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples)
|
||||
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
|
||||
|
|
@ -816,7 +765,10 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
height,
|
||||
]
|
||||
|
||||
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
|
||||
token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
|
||||
|
||||
modules.scripts.scripts_current = modules.scripts.scripts_img2img
|
||||
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True)
|
||||
|
|
@ -840,11 +792,22 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480)
|
||||
|
||||
with gr.TabItem('Inpaint', id='inpaint'):
|
||||
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
|
||||
init_img_with_mask_orig = gr.State(None)
|
||||
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
|
||||
|
||||
def update_orig(image, state):
|
||||
if image is not None:
|
||||
same_size = state is not None and state.size == image.size
|
||||
has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
|
||||
edited = same_size and has_exact_match
|
||||
return image if not edited or state is None else state
|
||||
|
||||
init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
|
||||
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
|
||||
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
|
||||
|
||||
show_mask_alpha = cmd_opts.gradio_inpaint_tool == "color-sketch"
|
||||
mask_alpha = gr.Slider(label="Mask transparency", interactive=show_mask_alpha, visible=show_mask_alpha)
|
||||
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4)
|
||||
|
||||
with gr.Row():
|
||||
|
|
@ -888,7 +851,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
|
||||
|
||||
with gr.Group():
|
||||
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
|
||||
custom_inputs = modules.scripts.scripts_img2img.setup_ui()
|
||||
|
||||
img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples)
|
||||
parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
|
||||
|
|
@ -932,12 +895,14 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
img2img_prompt_style2,
|
||||
init_img,
|
||||
init_img_with_mask,
|
||||
init_img_with_mask_orig,
|
||||
init_img_inpaint,
|
||||
init_mask_inpaint,
|
||||
mask_mode,
|
||||
steps,
|
||||
sampler_index,
|
||||
mask_blur,
|
||||
mask_alpha,
|
||||
inpainting_fill,
|
||||
restore_faces,
|
||||
tiling,
|
||||
|
|
@ -973,11 +938,10 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
outputs=[img2img_prompt],
|
||||
)
|
||||
|
||||
if cmd_opts.deepdanbooru:
|
||||
img2img_deepbooru.click(
|
||||
fn=interrogate_deepbooru,
|
||||
inputs=[init_img],
|
||||
outputs=[img2img_prompt],
|
||||
img2img_deepbooru.click(
|
||||
fn=interrogate_deepbooru,
|
||||
inputs=[init_img],
|
||||
outputs=[img2img_prompt],
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1032,11 +996,14 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
(seed_resize_from_w, "Seed resize from-1"),
|
||||
(seed_resize_from_h, "Seed resize from-2"),
|
||||
(denoising_strength, "Denoising strength"),
|
||||
(mask_blur, "Mask blur"),
|
||||
*modules.scripts.scripts_img2img.infotext_fields
|
||||
]
|
||||
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
|
||||
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
|
||||
|
||||
modules.scripts.scripts_current = None
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as extras_interface:
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Column(variant='panel'):
|
||||
|
|
@ -1138,7 +1105,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
outputs=[html, generation_info, html2],
|
||||
)
|
||||
|
||||
with gr.Blocks() as modelmerger_interface:
|
||||
with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Column(variant='panel'):
|
||||
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
|
||||
|
|
@ -1150,7 +1117,11 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
custom_name = gr.Textbox(label="Custom Name (Optional)")
|
||||
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
|
||||
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
|
||||
save_as_half = gr.Checkbox(value=False, label="Save as float16")
|
||||
|
||||
with gr.Row():
|
||||
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format")
|
||||
save_as_half = gr.Checkbox(value=False, label="Save as float16")
|
||||
|
||||
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
|
||||
|
||||
with gr.Column(variant='panel'):
|
||||
|
|
@ -1158,7 +1129,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
with gr.Blocks() as train_interface:
|
||||
with gr.Blocks(analytics_enabled=False) as train_interface:
|
||||
with gr.Row().style(equal_height=False):
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
|
||||
|
||||
|
|
@ -1180,7 +1151,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
|
||||
with gr.Tab(label="Create hypernetwork"):
|
||||
new_hypernetwork_name = gr.Textbox(label="Name")
|
||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"])
|
||||
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
|
||||
new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys)
|
||||
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
|
||||
|
|
@ -1207,7 +1178,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
process_split = gr.Checkbox(label='Split oversized images')
|
||||
process_focal_crop = gr.Checkbox(label='Auto focal point crop')
|
||||
process_caption = gr.Checkbox(label='Use BLIP for caption')
|
||||
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
|
||||
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True)
|
||||
|
||||
with gr.Row(visible=False) as process_split_extra_row:
|
||||
process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
|
||||
|
|
@ -1224,6 +1195,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
gr.HTML(value="")
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
interrupt_preprocessing = gr.Button("Interrupt")
|
||||
run_preprocess = gr.Button(value="Preprocess", variant='primary')
|
||||
|
||||
process_split.change(
|
||||
|
|
@ -1251,6 +1224,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
|
||||
|
||||
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
||||
gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0)
|
||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
||||
|
|
@ -1261,12 +1235,21 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
|
||||
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
|
||||
with gr.Row():
|
||||
shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False)
|
||||
tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0)
|
||||
with gr.Row():
|
||||
latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'])
|
||||
|
||||
with gr.Row():
|
||||
interrupt_training = gr.Button(value="Interrupt")
|
||||
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
|
||||
train_embedding = gr.Button(value="Train Embedding", variant='primary')
|
||||
|
||||
params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
|
||||
|
||||
script_callbacks.ui_train_tabs_callback(params)
|
||||
|
||||
with gr.Column():
|
||||
progressbar = gr.HTML(elem_id="ti_progressbar")
|
||||
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
||||
|
|
@ -1345,11 +1328,15 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
train_embedding_name,
|
||||
embedding_learn_rate,
|
||||
batch_size,
|
||||
gradient_step,
|
||||
dataset_directory,
|
||||
log_directory,
|
||||
training_width,
|
||||
training_height,
|
||||
steps,
|
||||
shuffle_tags,
|
||||
tag_drop_out,
|
||||
latent_sampling_method,
|
||||
create_image_every,
|
||||
save_embedding_every,
|
||||
template_file,
|
||||
|
|
@ -1370,11 +1357,15 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
train_hypernetwork_name,
|
||||
hypernetwork_learn_rate,
|
||||
batch_size,
|
||||
gradient_step,
|
||||
dataset_directory,
|
||||
log_directory,
|
||||
training_width,
|
||||
training_height,
|
||||
steps,
|
||||
shuffle_tags,
|
||||
tag_drop_out,
|
||||
latent_sampling_method,
|
||||
create_image_every,
|
||||
save_embedding_every,
|
||||
template_file,
|
||||
|
|
@ -1393,6 +1384,12 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
outputs=[],
|
||||
)
|
||||
|
||||
interrupt_preprocessing.click(
|
||||
fn=lambda: shared.state.interrupt(),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
def create_setting_component(key, is_quicksettings=False):
|
||||
def fun():
|
||||
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
||||
|
|
@ -1417,15 +1414,14 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
|
||||
if info.refresh is not None:
|
||||
if is_quicksettings:
|
||||
res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
|
||||
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
||||
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
|
||||
else:
|
||||
with gr.Row(variant="compact"):
|
||||
res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
|
||||
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
||||
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
|
||||
else:
|
||||
res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
|
||||
|
||||
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
||||
|
||||
return res
|
||||
|
||||
|
|
@ -1436,43 +1432,30 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
opts.reorder()
|
||||
|
||||
def run_settings(*args):
|
||||
changed = 0
|
||||
changed = []
|
||||
|
||||
for key, value, comp in zip(opts.data_labels.keys(), args, components):
|
||||
if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
|
||||
return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
|
||||
assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
|
||||
|
||||
for key, value, comp in zip(opts.data_labels.keys(), args, components):
|
||||
if comp == dummy_component:
|
||||
continue
|
||||
|
||||
oldval = opts.data.get(key, None)
|
||||
if opts.set(key, value):
|
||||
changed.append(key)
|
||||
|
||||
setattr(opts, key, value)
|
||||
|
||||
if oldval != value:
|
||||
if opts.data_labels[key].onchange is not None:
|
||||
opts.data_labels[key].onchange()
|
||||
|
||||
changed += 1
|
||||
|
||||
opts.save(shared.config_filename)
|
||||
|
||||
return f'{changed} settings changed.', opts.dumpjson()
|
||||
try:
|
||||
opts.save(shared.config_filename)
|
||||
except RuntimeError:
|
||||
return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
|
||||
return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.'
|
||||
|
||||
def run_settings_single(value, key):
|
||||
if not opts.same_type(value, opts.data_labels[key].default):
|
||||
return gr.update(visible=True), opts.dumpjson()
|
||||
|
||||
oldval = opts.data.get(key, None)
|
||||
try:
|
||||
setattr(opts, key, value)
|
||||
except Exception:
|
||||
return gr.update(value=oldval), opts.dumpjson()
|
||||
|
||||
if oldval != value:
|
||||
if opts.data_labels[key].onchange is not None:
|
||||
opts.data_labels[key].onchange()
|
||||
if not opts.set(key, value):
|
||||
return gr.update(value=getattr(opts, key)), opts.dumpjson()
|
||||
|
||||
opts.save(shared.config_filename)
|
||||
|
||||
|
|
@ -1562,11 +1545,10 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
shared.state.need_restart = True
|
||||
|
||||
restart_gradio.click(
|
||||
|
||||
fn=request_restart,
|
||||
_js='restart_reload',
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
_js='restart_reload'
|
||||
)
|
||||
|
||||
if column is not None:
|
||||
|
|
@ -1622,9 +1604,9 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
|
||||
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
|
||||
settings_submit.click(
|
||||
fn=run_settings,
|
||||
fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
|
||||
inputs=components,
|
||||
outputs=[result, text_settings],
|
||||
outputs=[text_settings, result],
|
||||
)
|
||||
|
||||
for i, k, item in quicksettings_list:
|
||||
|
|
@ -1636,6 +1618,17 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
outputs=[component, text_settings],
|
||||
)
|
||||
|
||||
component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
|
||||
|
||||
def get_settings_values():
|
||||
return [getattr(opts, key) for key in component_keys]
|
||||
|
||||
demo.load(
|
||||
fn=get_settings_values,
|
||||
inputs=[],
|
||||
outputs=[component_dict[k] for k in component_keys],
|
||||
)
|
||||
|
||||
def modelmerger(*args):
|
||||
try:
|
||||
results = modules.extras.run_modelmerger(*args)
|
||||
|
|
@ -1656,6 +1649,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
interp_amount,
|
||||
save_as_half,
|
||||
custom_name,
|
||||
checkpoint_format,
|
||||
],
|
||||
outputs=[
|
||||
submit_result,
|
||||
|
|
@ -1739,7 +1733,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
return demo
|
||||
|
||||
|
||||
def load_javascript(raw_response):
|
||||
def reload_javascript():
|
||||
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
|
||||
javascript = f'<script>{jsfile.read()}</script>'
|
||||
|
||||
|
|
@ -1755,7 +1749,7 @@ def load_javascript(raw_response):
|
|||
javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
|
||||
|
||||
def template_response(*args, **kwargs):
|
||||
res = raw_response(*args, **kwargs)
|
||||
res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
|
||||
res.body = res.body.replace(
|
||||
b'</head>', f'{javascript}</head>'.encode("utf8"))
|
||||
res.init_headers()
|
||||
|
|
@ -1764,4 +1758,5 @@ def load_javascript(raw_response):
|
|||
gradio.routes.templates.TemplateResponse = template_response
|
||||
|
||||
|
||||
reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
|
||||
if not hasattr(shared, 'GradioTemplateResponseOriginal'):
|
||||
shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ available_extensions = {"extensions": []}
|
|||
|
||||
|
||||
def check_access():
|
||||
assert not shared.cmd_opts.disable_extension_access, "extension access disabed because of commandline flags"
|
||||
assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"
|
||||
|
||||
|
||||
def apply_and_restart(disable_list, update_list):
|
||||
|
|
@ -36,9 +36,9 @@ def apply_and_restart(disable_list, update_list):
|
|||
continue
|
||||
|
||||
try:
|
||||
ext.pull()
|
||||
ext.fetch_and_reset_hard()
|
||||
except Exception:
|
||||
print(f"Error pulling updates for {ext.name}:", file=sys.stderr)
|
||||
print(f"Error getting updates for {ext.name}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
shared.opts.disabled_extensions = disabled
|
||||
|
|
@ -86,7 +86,7 @@ def extension_table():
|
|||
code += f"""
|
||||
<tr>
|
||||
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
|
||||
<td><a href="{html.escape(ext.remote or '')}">{html.escape(ext.remote or '')}</a></td>
|
||||
<td><a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape(ext.remote or '')}</a></td>
|
||||
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
|
||||
</tr>
|
||||
"""
|
||||
|
|
@ -134,19 +134,24 @@ def install_extension_from_url(dirname, url):
|
|||
|
||||
os.rename(tmpdir, target_dir)
|
||||
|
||||
import launch
|
||||
launch.run_extension_installer(target_dir)
|
||||
|
||||
extensions.list_extensions()
|
||||
return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
|
||||
finally:
|
||||
shutil.rmtree(tmpdir, True)
|
||||
|
||||
|
||||
def install_extension_from_index(url):
|
||||
def install_extension_from_index(url, hide_tags):
|
||||
ext_table, message = install_extension_from_url(None, url)
|
||||
|
||||
return refresh_available_extensions_from_data(), ext_table, message
|
||||
code, _ = refresh_available_extensions_from_data(hide_tags)
|
||||
|
||||
return code, ext_table, message
|
||||
|
||||
|
||||
def refresh_available_extensions(url):
|
||||
def refresh_available_extensions(url, hide_tags):
|
||||
global available_extensions
|
||||
|
||||
import urllib.request
|
||||
|
|
@ -155,13 +160,25 @@ def refresh_available_extensions(url):
|
|||
|
||||
available_extensions = json.loads(text)
|
||||
|
||||
return url, refresh_available_extensions_from_data(), ''
|
||||
code, tags = refresh_available_extensions_from_data(hide_tags)
|
||||
|
||||
return url, code, gr.CheckboxGroup.update(choices=tags), ''
|
||||
|
||||
|
||||
def refresh_available_extensions_from_data():
|
||||
def refresh_available_extensions_for_tags(hide_tags):
|
||||
code, _ = refresh_available_extensions_from_data(hide_tags)
|
||||
|
||||
return code, ''
|
||||
|
||||
|
||||
def refresh_available_extensions_from_data(hide_tags):
|
||||
extlist = available_extensions["extensions"]
|
||||
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
|
||||
|
||||
tags = available_extensions.get("tags", {})
|
||||
tags_to_hide = set(hide_tags)
|
||||
hidden = 0
|
||||
|
||||
code = f"""<!-- {time.time()} -->
|
||||
<table id="available_extensions">
|
||||
<thead>
|
||||
|
|
@ -178,17 +195,24 @@ def refresh_available_extensions_from_data():
|
|||
name = ext.get("name", "noname")
|
||||
url = ext.get("url", None)
|
||||
description = ext.get("description", "")
|
||||
extension_tags = ext.get("tags", [])
|
||||
|
||||
if url is None:
|
||||
continue
|
||||
|
||||
if len([x for x in extension_tags if x in tags_to_hide]) > 0:
|
||||
hidden += 1
|
||||
continue
|
||||
|
||||
existing = installed_extension_urls.get(normalize_git_url(url), None)
|
||||
|
||||
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
|
||||
|
||||
tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
|
||||
|
||||
code += f"""
|
||||
<tr>
|
||||
<td><a href="{html.escape(url)}">{html.escape(name)}</a></td>
|
||||
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
|
||||
<td>{html.escape(description)}</td>
|
||||
<td>{install_code}</td>
|
||||
</tr>
|
||||
|
|
@ -199,7 +223,10 @@ def refresh_available_extensions_from_data():
|
|||
</table>
|
||||
"""
|
||||
|
||||
return code
|
||||
if hidden > 0:
|
||||
code += f"<p>Extension hidden: {hidden}</p>"
|
||||
|
||||
return code, list(tags)
|
||||
|
||||
|
||||
def create_ui():
|
||||
|
|
@ -238,21 +265,30 @@ def create_ui():
|
|||
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
|
||||
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
|
||||
|
||||
with gr.Row():
|
||||
hide_tags = gr.CheckboxGroup(value=["ads", "localization"], label="Hide extensions with tags", choices=["script", "ads", "localization"])
|
||||
|
||||
install_result = gr.HTML()
|
||||
available_extensions_table = gr.HTML()
|
||||
|
||||
refresh_available_extensions_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update()]),
|
||||
inputs=[available_extensions_index],
|
||||
outputs=[available_extensions_index, available_extensions_table, install_result],
|
||||
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
|
||||
inputs=[available_extensions_index, hide_tags],
|
||||
outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result],
|
||||
)
|
||||
|
||||
install_extension_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
|
||||
inputs=[extension_to_install],
|
||||
inputs=[extension_to_install, hide_tags],
|
||||
outputs=[available_extensions_table, extensions_table, install_result],
|
||||
)
|
||||
|
||||
hide_tags.change(
|
||||
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
|
||||
inputs=[hide_tags],
|
||||
outputs=[available_extensions_table, install_result]
|
||||
)
|
||||
|
||||
with gr.TabItem("Install from URL"):
|
||||
install_url = gr.Text(label="URL for extension's git repository")
|
||||
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
|
||||
|
|
|
|||
62
modules/ui_tempdir.py
Normal file
62
modules/ui_tempdir.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
import os
|
||||
import tempfile
|
||||
from collections import namedtuple
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from PIL import PngImagePlugin
|
||||
|
||||
from modules import shared
|
||||
|
||||
|
||||
Savedfile = namedtuple("Savedfile", ["name"])
|
||||
|
||||
|
||||
def save_pil_to_file(pil_image, dir=None):
|
||||
already_saved_as = getattr(pil_image, 'already_saved_as', None)
|
||||
if already_saved_as and os.path.isfile(already_saved_as):
|
||||
shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(os.path.dirname(already_saved_as))}
|
||||
file_obj = Savedfile(already_saved_as)
|
||||
return file_obj
|
||||
|
||||
if shared.opts.temp_dir != "":
|
||||
dir = shared.opts.temp_dir
|
||||
|
||||
use_metadata = False
|
||||
metadata = PngImagePlugin.PngInfo()
|
||||
for key, value in pil_image.info.items():
|
||||
if isinstance(key, str) and isinstance(value, str):
|
||||
metadata.add_text(key, value)
|
||||
use_metadata = True
|
||||
|
||||
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
|
||||
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
|
||||
return file_obj
|
||||
|
||||
|
||||
# override save to file function so that it also writes PNG info
|
||||
gr.processing_utils.save_pil_to_file = save_pil_to_file
|
||||
|
||||
|
||||
def on_tmpdir_changed():
|
||||
if shared.opts.temp_dir == "" or shared.demo is None:
|
||||
return
|
||||
|
||||
os.makedirs(shared.opts.temp_dir, exist_ok=True)
|
||||
|
||||
shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(shared.opts.temp_dir)}
|
||||
|
||||
|
||||
def cleanup_tmpdr():
|
||||
temp_dir = shared.opts.temp_dir
|
||||
if temp_dir == "" or not os.path.isdir(temp_dir):
|
||||
return
|
||||
|
||||
for root, dirs, files in os.walk(temp_dir, topdown=False):
|
||||
for name in files:
|
||||
_, extension = os.path.splitext(name)
|
||||
if extension != ".png":
|
||||
continue
|
||||
|
||||
filename = os.path.join(root, name)
|
||||
os.remove(filename)
|
||||
|
|
@ -57,10 +57,18 @@ class Upscaler:
|
|||
self.scale = scale
|
||||
dest_w = img.width * scale
|
||||
dest_h = img.height * scale
|
||||
|
||||
for i in range(3):
|
||||
if img.width > dest_w and img.height > dest_h:
|
||||
break
|
||||
shape = (img.width, img.height)
|
||||
|
||||
img = self.do_upscale(img, selected_model)
|
||||
|
||||
if shape == (img.width, img.height):
|
||||
break
|
||||
|
||||
if img.width >= dest_w and img.height >= dest_h:
|
||||
break
|
||||
|
||||
if img.width != dest_w or img.height != dest_h:
|
||||
img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue