mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2026-01-07 01:32:49 -08:00
Merge branch 'master' into hn-activation
This commit is contained in:
commit
4918eb6ce4
70 changed files with 7041 additions and 1309 deletions
|
|
@ -1,29 +1,39 @@
|
|||
from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.sd_samplers import all_samplers
|
||||
from modules.extras import run_pnginfo
|
||||
import modules.shared as shared
|
||||
import uvicorn
|
||||
from fastapi import Body, APIRouter, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field, Json
|
||||
from typing import List
|
||||
import json
|
||||
import io
|
||||
import base64
|
||||
from PIL import Image
|
||||
import io
|
||||
import time
|
||||
import uvicorn
|
||||
from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
import modules.shared as shared
|
||||
from modules.api.models import *
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid
|
||||
from modules.extras import run_extras, run_pnginfo
|
||||
|
||||
|
||||
def upscaler_to_index(name: str):
|
||||
try:
|
||||
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
|
||||
|
||||
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
|
||||
|
||||
class TextToImageResponse(BaseModel):
|
||||
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: Json
|
||||
info: Json
|
||||
|
||||
class ImageToImageResponse(BaseModel):
|
||||
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: Json
|
||||
info: Json
|
||||
def setUpscalers(req: dict):
|
||||
reqDict = vars(req)
|
||||
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
|
||||
reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
|
||||
reqDict.pop('upscaler_1')
|
||||
reqDict.pop('upscaler_2')
|
||||
return reqDict
|
||||
|
||||
|
||||
def encode_pil_to_base64(image):
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format="png")
|
||||
return base64.b64encode(buffer.getvalue())
|
||||
|
||||
|
||||
class Api:
|
||||
|
|
@ -31,25 +41,22 @@ class Api:
|
|||
self.router = APIRouter()
|
||||
self.app = app
|
||||
self.queue_lock = queue_lock
|
||||
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
|
||||
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
|
||||
|
||||
def __base64_to_image(self, base64_string):
|
||||
# if has a comma, deal with prefix
|
||||
if "," in base64_string:
|
||||
base64_string = base64_string.split(",")[1]
|
||||
imgdata = base64.b64decode(base64_string)
|
||||
# convert base64 to PIL image
|
||||
return Image.open(io.BytesIO(imgdata))
|
||||
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
|
||||
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
|
||||
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
|
||||
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
|
||||
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
||||
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
||||
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
||||
|
||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
||||
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
|
||||
|
||||
|
||||
if sampler_index is None:
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
|
||||
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||
"sd_model": shared.sd_model,
|
||||
"sd_model": shared.sd_model,
|
||||
"sampler_index": sampler_index[0],
|
||||
"do_not_save_samples": True,
|
||||
"do_not_save_grid": True
|
||||
|
|
@ -57,40 +64,39 @@ class Api:
|
|||
)
|
||||
p = StableDiffusionProcessingTxt2Img(**vars(populate))
|
||||
# Override object param
|
||||
|
||||
shared.state.begin()
|
||||
|
||||
with self.queue_lock:
|
||||
processed = process_images(p)
|
||||
|
||||
b64images = []
|
||||
for i in processed.images:
|
||||
buffer = io.BytesIO()
|
||||
i.save(buffer, format="png")
|
||||
b64images.append(base64.b64encode(buffer.getvalue()))
|
||||
|
||||
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
|
||||
|
||||
|
||||
shared.state.end()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images))
|
||||
|
||||
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||
|
||||
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
|
||||
sampler_index = sampler_to_index(img2imgreq.sampler_index)
|
||||
|
||||
|
||||
if sampler_index is None:
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
|
||||
|
||||
init_images = img2imgreq.init_images
|
||||
if init_images is None:
|
||||
raise HTTPException(status_code=404, detail="Init image not found")
|
||||
raise HTTPException(status_code=404, detail="Init image not found")
|
||||
|
||||
mask = img2imgreq.mask
|
||||
if mask:
|
||||
mask = self.__base64_to_image(mask)
|
||||
mask = decode_base64_to_image(mask)
|
||||
|
||||
|
||||
|
||||
populate = img2imgreq.copy(update={ # Override __init__ params
|
||||
"sd_model": shared.sd_model,
|
||||
"sd_model": shared.sd_model,
|
||||
"sampler_index": sampler_index[0],
|
||||
"do_not_save_samples": True,
|
||||
"do_not_save_grid": True,
|
||||
"do_not_save_grid": True,
|
||||
"mask": mask
|
||||
}
|
||||
)
|
||||
|
|
@ -98,31 +104,92 @@ class Api:
|
|||
|
||||
imgs = []
|
||||
for img in init_images:
|
||||
img = self.__base64_to_image(img)
|
||||
img = decode_base64_to_image(img)
|
||||
imgs = [img] * p.batch_size
|
||||
|
||||
p.init_images = imgs
|
||||
# Override object param
|
||||
|
||||
shared.state.begin()
|
||||
|
||||
with self.queue_lock:
|
||||
processed = process_images(p)
|
||||
|
||||
b64images = []
|
||||
for i in processed.images:
|
||||
buffer = io.BytesIO()
|
||||
i.save(buffer, format="png")
|
||||
b64images.append(base64.b64encode(buffer.getvalue()))
|
||||
|
||||
shared.state.end()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images))
|
||||
|
||||
if (not img2imgreq.include_init_images):
|
||||
img2imgreq.init_images = None
|
||||
img2imgreq.mask = None
|
||||
|
||||
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
|
||||
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|
||||
|
||||
def extrasapi(self):
|
||||
raise NotImplementedError
|
||||
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
|
||||
reqDict = setUpscalers(req)
|
||||
|
||||
def pnginfoapi(self):
|
||||
raise NotImplementedError
|
||||
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
||||
|
||||
with self.queue_lock:
|
||||
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict)
|
||||
|
||||
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
||||
|
||||
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
|
||||
reqDict = setUpscalers(req)
|
||||
|
||||
def prepareFiles(file):
|
||||
file = decode_base64_to_file(file.data, file_path=file.name)
|
||||
file.orig_name = file.name
|
||||
return file
|
||||
|
||||
reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
|
||||
reqDict.pop('imageList')
|
||||
|
||||
with self.queue_lock:
|
||||
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)
|
||||
|
||||
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||
|
||||
def pnginfoapi(self, req: PNGInfoRequest):
|
||||
if(not req.image.strip()):
|
||||
return PNGInfoResponse(info="")
|
||||
|
||||
result = run_pnginfo(decode_base64_to_image(req.image.strip()))
|
||||
|
||||
return PNGInfoResponse(info=result[1])
|
||||
|
||||
def progressapi(self, req: ProgressRequest = Depends()):
|
||||
# copy from check_progress_call of ui.py
|
||||
|
||||
if shared.state.job_count == 0:
|
||||
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
|
||||
|
||||
# avoid dividing zero
|
||||
progress = 0.01
|
||||
|
||||
if shared.state.job_count > 0:
|
||||
progress += shared.state.job_no / shared.state.job_count
|
||||
if shared.state.sampling_steps > 0:
|
||||
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
|
||||
|
||||
time_since_start = time.time() - shared.state.time_start
|
||||
eta = (time_since_start/progress)
|
||||
eta_relative = eta-time_since_start
|
||||
|
||||
progress = min(progress, 1)
|
||||
|
||||
shared.state.set_current_image()
|
||||
|
||||
current_image = None
|
||||
if shared.state.current_image and not req.skip_current_image:
|
||||
current_image = encode_pil_to_base64(shared.state.current_image)
|
||||
|
||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
|
||||
|
||||
def interruptapi(self):
|
||||
shared.state.interrupt()
|
||||
|
||||
return {}
|
||||
|
||||
def launch(self, server_name, port):
|
||||
self.app.include_router(self.router)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from array import array
|
||||
from inflection import underscore
|
||||
from typing import Any, Dict, Optional
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
||||
import inspect
|
||||
|
||||
from click import prompt
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from typing import Any, Optional
|
||||
from typing_extensions import Literal
|
||||
from inflection import underscore
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
||||
from modules.shared import sd_upscalers
|
||||
|
||||
API_NOT_ALLOWED = [
|
||||
"self",
|
||||
|
|
@ -51,17 +52,17 @@ class PydanticModelGenerator:
|
|||
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
||||
# print(k, v.annotation, v.default)
|
||||
field_type = v.annotation
|
||||
|
||||
|
||||
return Optional[field_type]
|
||||
|
||||
|
||||
def merge_class_params(class_):
|
||||
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
|
||||
parameters = {}
|
||||
for classes in all_classes:
|
||||
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
||||
return parameters
|
||||
|
||||
|
||||
|
||||
|
||||
self._model_name = model_name
|
||||
self._class_data = merge_class_params(class_instance)
|
||||
self._model_def = [
|
||||
|
|
@ -73,11 +74,11 @@ class PydanticModelGenerator:
|
|||
)
|
||||
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
||||
]
|
||||
|
||||
|
||||
for fields in additional_fields:
|
||||
self._model_def.append(ModelDef(
|
||||
field=underscore(fields["key"]),
|
||||
field_alias=fields["key"],
|
||||
field=underscore(fields["key"]),
|
||||
field_alias=fields["key"],
|
||||
field_type=fields["type"],
|
||||
field_value=fields["default"],
|
||||
field_exclude=fields["exclude"] if "exclude" in fields else False))
|
||||
|
|
@ -94,15 +95,74 @@ class PydanticModelGenerator:
|
|||
DynamicModel.__config__.allow_population_by_field_name = True
|
||||
DynamicModel.__config__.allow_mutation = True
|
||||
return DynamicModel
|
||||
|
||||
|
||||
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingTxt2Img",
|
||||
"StableDiffusionProcessingTxt2Img",
|
||||
StableDiffusionProcessingTxt2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}]
|
||||
).generate_model()
|
||||
|
||||
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingImg2Img",
|
||||
"StableDiffusionProcessingImg2Img",
|
||||
StableDiffusionProcessingImg2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
|
||||
).generate_model()
|
||||
).generate_model()
|
||||
|
||||
class TextToImageResponse(BaseModel):
|
||||
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
class ImageToImageResponse(BaseModel):
|
||||
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
class ExtrasBaseRequest(BaseModel):
|
||||
resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
|
||||
show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
|
||||
gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
|
||||
codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
|
||||
codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
|
||||
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
|
||||
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
|
||||
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
|
||||
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?")
|
||||
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
|
||||
upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
|
||||
|
||||
class ExtraBaseResponse(BaseModel):
|
||||
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
|
||||
|
||||
class ExtrasSingleImageRequest(ExtrasBaseRequest):
|
||||
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
|
||||
class ExtrasSingleImageResponse(ExtraBaseResponse):
|
||||
image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
|
||||
class FileData(BaseModel):
|
||||
data: str = Field(title="File data", description="Base64 representation of the file")
|
||||
name: str = Field(title="File name")
|
||||
|
||||
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
|
||||
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
|
||||
|
||||
class ExtrasBatchImagesResponse(ExtraBaseResponse):
|
||||
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
|
||||
|
||||
class PNGInfoRequest(BaseModel):
|
||||
image: str = Field(title="Image", description="The base64 encoded PNG image")
|
||||
|
||||
class PNGInfoResponse(BaseModel):
|
||||
info: str = Field(title="Image info", description="A string with all the info the image had")
|
||||
|
||||
class ProgressRequest(BaseModel):
|
||||
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
|
||||
|
||||
class ProgressResponse(BaseModel):
|
||||
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
|
||||
eta_relative: float = Field(title="ETA in secs")
|
||||
state: dict = Field(title="State", description="The current state snapshot")
|
||||
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ def mod2normal(state_dict):
|
|||
def resrgan2normal(state_dict, nb=23):
|
||||
# this code is copied from https://github.com/victorca25/iNNfer
|
||||
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
|
||||
re8x = 0
|
||||
crt_net = {}
|
||||
items = []
|
||||
for k, v in state_dict.items():
|
||||
|
|
@ -75,10 +76,18 @@ def resrgan2normal(state_dict, nb=23):
|
|||
crt_net['model.3.bias'] = state_dict['conv_up1.bias']
|
||||
crt_net['model.6.weight'] = state_dict['conv_up2.weight']
|
||||
crt_net['model.6.bias'] = state_dict['conv_up2.bias']
|
||||
crt_net['model.8.weight'] = state_dict['conv_hr.weight']
|
||||
crt_net['model.8.bias'] = state_dict['conv_hr.bias']
|
||||
crt_net['model.10.weight'] = state_dict['conv_last.weight']
|
||||
crt_net['model.10.bias'] = state_dict['conv_last.bias']
|
||||
|
||||
if 'conv_up3.weight' in state_dict:
|
||||
# modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
|
||||
re8x = 3
|
||||
crt_net['model.9.weight'] = state_dict['conv_up3.weight']
|
||||
crt_net['model.9.bias'] = state_dict['conv_up3.bias']
|
||||
|
||||
crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
|
||||
crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
|
||||
crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
|
||||
crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
|
||||
|
||||
state_dict = crt_net
|
||||
return state_dict
|
||||
|
||||
|
|
|
|||
83
modules/extensions.py
Normal file
83
modules/extensions.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import git
|
||||
|
||||
from modules import paths, shared
|
||||
|
||||
|
||||
extensions = []
|
||||
extensions_dir = os.path.join(paths.script_path, "extensions")
|
||||
|
||||
|
||||
def active():
|
||||
return [x for x in extensions if x.enabled]
|
||||
|
||||
|
||||
class Extension:
|
||||
def __init__(self, name, path, enabled=True):
|
||||
self.name = name
|
||||
self.path = path
|
||||
self.enabled = enabled
|
||||
self.status = ''
|
||||
self.can_update = False
|
||||
|
||||
repo = None
|
||||
try:
|
||||
if os.path.exists(os.path.join(path, ".git")):
|
||||
repo = git.Repo(path)
|
||||
except Exception:
|
||||
print(f"Error reading github repository info from {path}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
if repo is None or repo.bare:
|
||||
self.remote = None
|
||||
else:
|
||||
self.remote = next(repo.remote().urls, None)
|
||||
self.status = 'unknown'
|
||||
|
||||
def list_files(self, subdir, extension):
|
||||
from modules import scripts
|
||||
|
||||
dirpath = os.path.join(self.path, subdir)
|
||||
if not os.path.isdir(dirpath):
|
||||
return []
|
||||
|
||||
res = []
|
||||
for filename in sorted(os.listdir(dirpath)):
|
||||
res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
|
||||
|
||||
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||
|
||||
return res
|
||||
|
||||
def check_updates(self):
|
||||
repo = git.Repo(self.path)
|
||||
for fetch in repo.remote().fetch("--dry-run"):
|
||||
if fetch.flags != fetch.HEAD_UPTODATE:
|
||||
self.can_update = True
|
||||
self.status = "behind"
|
||||
return
|
||||
|
||||
self.can_update = False
|
||||
self.status = "latest"
|
||||
|
||||
def pull(self):
|
||||
repo = git.Repo(self.path)
|
||||
repo.remotes.origin.pull()
|
||||
|
||||
|
||||
def list_extensions():
|
||||
extensions.clear()
|
||||
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
for dirname in sorted(os.listdir(extensions_dir)):
|
||||
path = os.path.join(extensions_dir, dirname)
|
||||
if not os.path.isdir(path):
|
||||
continue
|
||||
|
||||
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
|
||||
extensions.append(extension)
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
from __future__ import annotations
|
||||
import math
|
||||
import os
|
||||
|
||||
|
|
@ -7,6 +8,10 @@ from PIL import Image
|
|||
import torch
|
||||
import tqdm
|
||||
|
||||
from typing import Callable, List, OrderedDict, Tuple
|
||||
from functools import partial
|
||||
from dataclasses import dataclass
|
||||
|
||||
from modules import processing, shared, images, devices, sd_models
|
||||
from modules.shared import opts
|
||||
import modules.gfpgan_model
|
||||
|
|
@ -17,10 +22,38 @@ import piexif.helper
|
|||
import gradio as gr
|
||||
|
||||
|
||||
cached_images = {}
|
||||
class LruCache(OrderedDict):
|
||||
@dataclass(frozen=True)
|
||||
class Key:
|
||||
image_hash: int
|
||||
info_hash: int
|
||||
args_hash: int
|
||||
|
||||
@dataclass
|
||||
class Value:
|
||||
image: Image.Image
|
||||
info: str
|
||||
|
||||
def __init__(self, max_size: int = 5, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._max_size = max_size
|
||||
|
||||
def get(self, key: LruCache.Key) -> LruCache.Value:
|
||||
ret = super().get(key)
|
||||
if ret is not None:
|
||||
self.move_to_end(key) # Move to end of eviction list
|
||||
return ret
|
||||
|
||||
def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
|
||||
self[key] = value
|
||||
while len(self) > self._max_size:
|
||||
self.popitem(last=False)
|
||||
|
||||
|
||||
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
|
||||
cached_images: LruCache = LruCache(max_size=5)
|
||||
|
||||
|
||||
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool):
|
||||
devices.torch_gc()
|
||||
|
||||
imageArr = []
|
||||
|
|
@ -39,7 +72,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
|||
|
||||
if input_dir == '':
|
||||
return outputs, "Please select an input directory.", ''
|
||||
image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)]
|
||||
image_list = shared.listfiles(input_dir)
|
||||
for img in image_list:
|
||||
try:
|
||||
image = Image.open(img)
|
||||
|
|
@ -56,7 +89,91 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
|||
else:
|
||||
outpath = opts.outdir_samples or opts.outdir_extras_samples
|
||||
|
||||
|
||||
# Extra operation definitions
|
||||
|
||||
def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
if gfpgan_visibility < 1.0:
|
||||
res = Image.blend(image, res, gfpgan_visibility)
|
||||
|
||||
info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
|
||||
return (res, info)
|
||||
|
||||
def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
if codeformer_visibility < 1.0:
|
||||
res = Image.blend(image, res, codeformer_visibility)
|
||||
|
||||
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
|
||||
return (res, info)
|
||||
|
||||
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
|
||||
upscaler = shared.sd_upscalers[scaler_index]
|
||||
res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
|
||||
if mode == 1 and crop:
|
||||
cropped = Image.new("RGB", (resize_w, resize_h))
|
||||
cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
|
||||
res = cropped
|
||||
return res
|
||||
|
||||
def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
# Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
|
||||
nonlocal upscaling_resize
|
||||
if resize_mode == 1:
|
||||
upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
|
||||
crop_info = " (crop)" if upscaling_crop else ""
|
||||
info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
|
||||
return (image, info)
|
||||
|
||||
@dataclass
|
||||
class UpscaleParams:
|
||||
upscaler_idx: int
|
||||
blend_alpha: float
|
||||
|
||||
def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
blended_result: Image.Image = None
|
||||
for upscaler in params:
|
||||
upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
|
||||
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
|
||||
cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
|
||||
info_hash=hash(info),
|
||||
args_hash=hash((upscale_args, upscale_first)))
|
||||
cached_entry = cached_images.get(cache_key)
|
||||
if cached_entry is None:
|
||||
res = upscale(image, *upscale_args)
|
||||
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
|
||||
cached_images.put(cache_key, LruCache.Value(image=res, info=info))
|
||||
else:
|
||||
res, info = cached_entry.image, cached_entry.info
|
||||
|
||||
if blended_result is None:
|
||||
blended_result = res
|
||||
else:
|
||||
blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
|
||||
return (blended_result, info)
|
||||
|
||||
# Build a list of operations to run
|
||||
facefix_ops: List[Callable] = []
|
||||
facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
|
||||
facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []
|
||||
|
||||
upscale_ops: List[Callable] = []
|
||||
upscale_ops += [run_prepare_crop] if resize_mode == 1 else []
|
||||
|
||||
if upscaling_resize != 0:
|
||||
step_params: List[UpscaleParams] = []
|
||||
step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0))
|
||||
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
|
||||
step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility))
|
||||
|
||||
upscale_ops.append(partial(run_upscalers_blend, step_params))
|
||||
|
||||
extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)
|
||||
|
||||
for image, image_name in zip(imageArr, imageNameArr):
|
||||
if image is None:
|
||||
return outputs, "Please select an input image.", ''
|
||||
|
|
@ -64,64 +181,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
|||
|
||||
image = image.convert("RGB")
|
||||
info = ""
|
||||
# Run each operation on each image
|
||||
for op in extras_ops:
|
||||
image, info = op(image, info)
|
||||
|
||||
if gfpgan_visibility > 0:
|
||||
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
if gfpgan_visibility < 1.0:
|
||||
res = Image.blend(image, res, gfpgan_visibility)
|
||||
|
||||
info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
|
||||
image = res
|
||||
|
||||
if codeformer_visibility > 0:
|
||||
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
if codeformer_visibility < 1.0:
|
||||
res = Image.blend(image, res, codeformer_visibility)
|
||||
|
||||
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
|
||||
image = res
|
||||
|
||||
if resize_mode == 1:
|
||||
upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
|
||||
crop_info = " (crop)" if upscaling_crop else ""
|
||||
info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
|
||||
|
||||
if upscaling_resize != 1.0:
|
||||
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
|
||||
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
|
||||
pixels = tuple(np.array(small).flatten().tolist())
|
||||
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight,
|
||||
resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels
|
||||
|
||||
c = cached_images.get(key)
|
||||
if c is None:
|
||||
upscaler = shared.sd_upscalers[scaler_index]
|
||||
c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
|
||||
if mode == 1 and crop:
|
||||
cropped = Image.new("RGB", (resize_w, resize_h))
|
||||
cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
|
||||
c = cropped
|
||||
cached_images[key] = c
|
||||
|
||||
return c
|
||||
|
||||
info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
|
||||
res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
|
||||
|
||||
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
|
||||
res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
|
||||
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
|
||||
res = Image.blend(res, res2, extras_upscaler_2_visibility)
|
||||
|
||||
image = res
|
||||
|
||||
while len(cached_images) > 2:
|
||||
del cached_images[next(iter(cached_images.keys()))]
|
||||
|
||||
if opts.use_original_name_batch and image_name != None:
|
||||
basename = os.path.splitext(os.path.basename(image_name))[0]
|
||||
else:
|
||||
|
|
@ -141,6 +204,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
|||
|
||||
return outputs, plaintext_to_html(info), ''
|
||||
|
||||
def clear_cache():
|
||||
cached_images.clear()
|
||||
|
||||
|
||||
def run_pnginfo(image):
|
||||
if image is None:
|
||||
|
|
|
|||
|
|
@ -1,14 +1,25 @@
|
|||
import base64
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import gradio as gr
|
||||
from modules.shared import script_path
|
||||
from modules import shared
|
||||
import tempfile
|
||||
from PIL import Image
|
||||
|
||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
|
||||
re_param = re.compile(re_param_code)
|
||||
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
|
||||
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
||||
type_of_gr_update = type(gr.update())
|
||||
paste_fields = {}
|
||||
bind_list = []
|
||||
|
||||
|
||||
def reset():
|
||||
paste_fields.clear()
|
||||
bind_list.clear()
|
||||
|
||||
|
||||
def quote(text):
|
||||
|
|
@ -20,6 +31,111 @@ def quote(text):
|
|||
text = text.replace('"', '\\"')
|
||||
return f'"{text}"'
|
||||
|
||||
|
||||
def image_from_url_text(filedata):
|
||||
if type(filedata) == dict and filedata["is_file"]:
|
||||
filename = filedata["name"]
|
||||
tempdir = os.path.normpath(tempfile.gettempdir())
|
||||
normfn = os.path.normpath(filename)
|
||||
assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
|
||||
|
||||
return Image.open(filename)
|
||||
|
||||
if type(filedata) == list:
|
||||
if len(filedata) == 0:
|
||||
return None
|
||||
|
||||
filedata = filedata[0]
|
||||
|
||||
if filedata.startswith("data:image/png;base64,"):
|
||||
filedata = filedata[len("data:image/png;base64,"):]
|
||||
|
||||
filedata = base64.decodebytes(filedata.encode('utf-8'))
|
||||
image = Image.open(io.BytesIO(filedata))
|
||||
return image
|
||||
|
||||
|
||||
def add_paste_fields(tabname, init_img, fields):
|
||||
paste_fields[tabname] = {"init_img": init_img, "fields": fields}
|
||||
|
||||
# backwards compatibility for existing extensions
|
||||
import modules.ui
|
||||
if tabname == 'txt2img':
|
||||
modules.ui.txt2img_paste_fields = fields
|
||||
elif tabname == 'img2img':
|
||||
modules.ui.img2img_paste_fields = fields
|
||||
|
||||
|
||||
def integrate_settings_paste_fields(component_dict):
|
||||
from modules import ui
|
||||
|
||||
settings_map = {
|
||||
'sd_hypernetwork': 'Hypernet',
|
||||
'sd_hypernetwork_strength': 'Hypernet strength',
|
||||
'CLIP_stop_at_last_layers': 'Clip skip',
|
||||
'sd_model_checkpoint': 'Model hash',
|
||||
}
|
||||
settings_paste_fields = [
|
||||
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
|
||||
for k, v in settings_map.items()
|
||||
]
|
||||
|
||||
for tabname, info in paste_fields.items():
|
||||
if info["fields"] is not None:
|
||||
info["fields"] += settings_paste_fields
|
||||
|
||||
|
||||
def create_buttons(tabs_list):
|
||||
buttons = {}
|
||||
for tab in tabs_list:
|
||||
buttons[tab] = gr.Button(f"Send to {tab}")
|
||||
return buttons
|
||||
|
||||
|
||||
#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab
|
||||
def bind_buttons(buttons, send_image, send_generate_info):
|
||||
bind_list.append([buttons, send_image, send_generate_info])
|
||||
|
||||
|
||||
def run_bind():
|
||||
for buttons, send_image, send_generate_info in bind_list:
|
||||
for tab in buttons:
|
||||
button = buttons[tab]
|
||||
if send_image and paste_fields[tab]["init_img"]:
|
||||
if type(send_image) == gr.Gallery:
|
||||
button.click(
|
||||
fn=lambda x: image_from_url_text(x),
|
||||
_js="extract_image_from_gallery",
|
||||
inputs=[send_image],
|
||||
outputs=[paste_fields[tab]["init_img"]],
|
||||
)
|
||||
else:
|
||||
button.click(
|
||||
fn=lambda x: x,
|
||||
inputs=[send_image],
|
||||
outputs=[paste_fields[tab]["init_img"]],
|
||||
)
|
||||
|
||||
if send_generate_info and paste_fields[tab]["fields"] is not None:
|
||||
if send_generate_info in paste_fields:
|
||||
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2'] + (["Seed"] if shared.opts.send_seed else [])
|
||||
|
||||
button.click(
|
||||
fn=lambda *x: x,
|
||||
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
|
||||
outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names],
|
||||
)
|
||||
else:
|
||||
connect_paste(button, paste_fields[tab]["fields"], send_generate_info)
|
||||
|
||||
button.click(
|
||||
fn=None,
|
||||
_js=f"switch_to_{tab}",
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
)
|
||||
|
||||
|
||||
def parse_generation_parameters(x: str):
|
||||
"""parses generation parameters string, the one you see in text field under the picture in UI:
|
||||
```
|
||||
|
|
@ -68,7 +184,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||
return res
|
||||
|
||||
|
||||
def connect_paste(button, paste_fields, input_comp, js=None):
|
||||
def connect_paste(button, paste_fields, input_comp, jsfunc=None):
|
||||
def paste_func(prompt):
|
||||
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
|
||||
filename = os.path.join(script_path, "params.txt")
|
||||
|
|
@ -106,7 +222,9 @@ def connect_paste(button, paste_fields, input_comp, js=None):
|
|||
|
||||
button.click(
|
||||
fn=paste_func,
|
||||
_js=js,
|
||||
_js=jsfunc,
|
||||
inputs=[input_comp],
|
||||
outputs=[x[0] for x in paste_fields],
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from statistics import stdev, mean
|
|||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 1.0
|
||||
activation_dict = {
|
||||
"linear": torch.nn.Identity,
|
||||
"relu": torch.nn.ReLU,
|
||||
"leakyrelu": torch.nn.LeakyReLU,
|
||||
"elu": torch.nn.ELU,
|
||||
|
|
@ -220,13 +221,16 @@ def list_hypernetworks(path):
|
|||
res = {}
|
||||
for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
|
||||
name = os.path.splitext(os.path.basename(filename))[0]
|
||||
res[name] = filename
|
||||
# Prevent a hypothetical "None.pt" from being listed.
|
||||
if name != "None":
|
||||
res[name] = filename
|
||||
return res
|
||||
|
||||
|
||||
def load_hypernetwork(filename):
|
||||
path = shared.hypernetworks.get(filename, None)
|
||||
if path is not None:
|
||||
# Prevent any file named "None.pt" from being loaded.
|
||||
if path is not None and filename != "None":
|
||||
print(f"Loading hypernetwork {filename}")
|
||||
try:
|
||||
shared.loaded_hypernetwork = Hypernetwork()
|
||||
|
|
@ -343,7 +347,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||
from modules import images
|
||||
|
||||
assert hypernetwork_name, 'hypernetwork not selected'
|
||||
save_hypernetwork_every = save_hypernetwork_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
|
||||
|
||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||
shared.loaded_hypernetwork = Hypernetwork()
|
||||
|
|
@ -369,39 +375,44 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||
else:
|
||||
images_dir = None
|
||||
|
||||
hypernetwork = shared.loaded_hypernetwork
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
ititial_step = hypernetwork.step or 0
|
||||
if ititial_step >= steps:
|
||||
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
|
||||
return hypernetwork, filename
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
|
||||
|
||||
if unload:
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
hypernetwork = shared.loaded_hypernetwork
|
||||
weights = hypernetwork.weights()
|
||||
for weight in weights:
|
||||
weight.requires_grad = True
|
||||
|
||||
size = len(ds.indexes)
|
||||
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
||||
losses = torch.zeros((size,))
|
||||
previous_mean_losses = [0]
|
||||
previous_mean_loss = 0
|
||||
print("Mean loss of {} elements".format(size))
|
||||
|
||||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
|
||||
ititial_step = hypernetwork.step or 0
|
||||
if ititial_step > steps:
|
||||
return hypernetwork, filename
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
|
||||
weights = hypernetwork.weights()
|
||||
for weight in weights:
|
||||
weight.requires_grad = True
|
||||
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
|
||||
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
||||
|
||||
steps_without_grad = 0
|
||||
|
||||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||
for i, entries in pbar:
|
||||
hypernetwork.step = i + ititial_step
|
||||
|
|
@ -440,7 +451,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||
|
||||
optimizer.step()
|
||||
|
||||
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
|
||||
steps_done = hypernetwork.step + 1
|
||||
|
||||
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
|
||||
raise RuntimeError("Loss diverged.")
|
||||
|
||||
if len(previous_mean_losses) > 1:
|
||||
|
|
@ -450,19 +463,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
|
||||
pbar.set_description(dataset_loss_info)
|
||||
|
||||
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
||||
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
|
||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
|
||||
hypernetwork.save(last_saved_file)
|
||||
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
|
||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
||||
|
||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
||||
"loss": f"{previous_mean_loss:.7f}",
|
||||
"learn_rate": scheduler.learn_rate
|
||||
})
|
||||
|
||||
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
||||
forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
|
||||
if images_dir is not None and steps_done % create_image_every == 0:
|
||||
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
|
@ -499,7 +512,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||
|
||||
if image is not None:
|
||||
shared.state.current_image = image
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
shared.state.job_no = hypernetwork.step
|
||||
|
|
@ -515,13 +528,23 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
"""
|
||||
|
||||
report_statistics(loss_dict)
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
hypernetwork.sd_checkpoint = checkpoint.hash
|
||||
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
||||
# Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
|
||||
hypernetwork.name = hypernetwork_name
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
|
||||
hypernetwork.save(filename)
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
|
||||
|
||||
return hypernetwork, filename
|
||||
|
||||
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
|
||||
old_hypernetwork_name = hypernetwork.name
|
||||
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
|
||||
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
|
||||
try:
|
||||
hypernetwork.sd_checkpoint = checkpoint.hash
|
||||
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
||||
hypernetwork.name = hypernetwork_name
|
||||
hypernetwork.save(filename)
|
||||
except:
|
||||
hypernetwork.sd_checkpoint = old_sd_checkpoint
|
||||
hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
|
||||
hypernetwork.name = old_hypernetwork_name
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ import modules.textual_inversion.textual_inversion
|
|||
from modules import devices, sd_hijack, shared
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
keys = list(hypernetwork.HypernetworkModule.activation_dict.keys())
|
||||
not_available = ["hardswish", "multiheadattention"]
|
||||
keys = ["linear"] + list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
|
||||
# Remove illegal characters from name.
|
||||
|
|
|
|||
|
|
@ -300,8 +300,8 @@ class FilenameGenerator:
|
|||
'seed': lambda self: self.seed if self.seed is not None else '',
|
||||
'steps': lambda self: self.p and self.p.steps,
|
||||
'cfg': lambda self: self.p and self.p.cfg_scale,
|
||||
'width': lambda self: self.p and self.p.width,
|
||||
'height': lambda self: self.p and self.p.height,
|
||||
'width': lambda self: self.image.width,
|
||||
'height': lambda self: self.image.height,
|
||||
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
|
||||
'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
|
||||
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
|
||||
|
|
@ -315,10 +315,11 @@ class FilenameGenerator:
|
|||
}
|
||||
default_time_format = '%Y%m%d%H%M%S'
|
||||
|
||||
def __init__(self, p, seed, prompt):
|
||||
def __init__(self, p, seed, prompt, image):
|
||||
self.p = p
|
||||
self.seed = seed
|
||||
self.prompt = prompt
|
||||
self.image = image
|
||||
|
||||
def prompt_no_style(self):
|
||||
if self.p is None or self.prompt is None:
|
||||
|
|
@ -449,18 +450,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||
txt_fullfn (`str` or None):
|
||||
If a text file is saved for this image, this will be its full path. Otherwise None.
|
||||
"""
|
||||
namegen = FilenameGenerator(p, seed, prompt)
|
||||
|
||||
if extension == 'png' and opts.enable_pnginfo and info is not None:
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if existing_info is not None:
|
||||
for k, v in existing_info.items():
|
||||
pnginfo.add_text(k, str(v))
|
||||
|
||||
pnginfo.add_text(pnginfo_section_name, info)
|
||||
else:
|
||||
pnginfo = None
|
||||
namegen = FilenameGenerator(p, seed, prompt, image)
|
||||
|
||||
if save_to_dirs is None:
|
||||
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
|
||||
|
|
@ -489,19 +479,27 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||
if add_number:
|
||||
basecount = get_next_sequence_number(path, basename)
|
||||
fullfn = None
|
||||
fullfn_without_extension = None
|
||||
for i in range(500):
|
||||
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
|
||||
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
|
||||
fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
|
||||
if not os.path.exists(fullfn):
|
||||
break
|
||||
else:
|
||||
fullfn = os.path.join(path, f"{file_decoration}.{extension}")
|
||||
fullfn_without_extension = os.path.join(path, file_decoration)
|
||||
else:
|
||||
fullfn = os.path.join(path, f"{forced_filename}.{extension}")
|
||||
fullfn_without_extension = os.path.join(path, forced_filename)
|
||||
|
||||
pnginfo = existing_info or {}
|
||||
if info is not None:
|
||||
pnginfo[pnginfo_section_name] = info
|
||||
|
||||
params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
|
||||
script_callbacks.before_image_saved_callback(params)
|
||||
|
||||
image = params.image
|
||||
fullfn = params.filename
|
||||
info = params.pnginfo.get(pnginfo_section_name, None)
|
||||
fullfn_without_extension, extension = os.path.splitext(params.filename)
|
||||
|
||||
def exif_bytes():
|
||||
return piexif.dump({
|
||||
|
|
@ -510,12 +508,21 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||
},
|
||||
})
|
||||
|
||||
if extension.lower() in ("jpg", "jpeg", "webp"):
|
||||
if extension.lower() == '.png':
|
||||
pnginfo_data = PngImagePlugin.PngInfo()
|
||||
if opts.enable_pnginfo:
|
||||
for k, v in params.pnginfo.items():
|
||||
pnginfo_data.add_text(k, str(v))
|
||||
|
||||
image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
|
||||
|
||||
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
|
||||
image.save(fullfn, quality=opts.jpeg_quality)
|
||||
|
||||
if opts.enable_pnginfo and info is not None:
|
||||
piexif.insert(exif_bytes(), fullfn)
|
||||
else:
|
||||
image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)
|
||||
image.save(fullfn, quality=opts.jpeg_quality)
|
||||
|
||||
target_side_length = 4000
|
||||
oversize = image.width > target_side_length or image.height > target_side_length
|
||||
|
|
@ -538,7 +545,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||
else:
|
||||
txt_fullfn = None
|
||||
|
||||
script_callbacks.image_saved_callback(image, p, fullfn, txt_fullfn)
|
||||
script_callbacks.image_saved_callback(params)
|
||||
|
||||
return fullfn, txt_fullfn
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ import modules.scripts
|
|||
def process_batch(p, input_dir, output_dir, args):
|
||||
processing.fix_seed(p)
|
||||
|
||||
images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
|
||||
images = shared.listfiles(input_dir)
|
||||
|
||||
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
||||
|
||||
|
|
@ -39,6 +39,8 @@ def process_batch(p, input_dir, output_dir, args):
|
|||
break
|
||||
|
||||
img = Image.open(image)
|
||||
# Use the EXIF orientation of photos taken by smartphones.
|
||||
img = ImageOps.exif_transpose(img)
|
||||
p.init_images = [img] * p.batch_size
|
||||
|
||||
proc = modules.scripts.scripts_img2img.run(p, *args)
|
||||
|
|
@ -53,6 +55,7 @@ def process_batch(p, input_dir, output_dir, args):
|
|||
filename = f"{left}-{n}{right}"
|
||||
|
||||
if not save_normally:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
processed_image.save(os.path.join(output_dir, filename))
|
||||
|
||||
|
||||
|
|
@ -61,19 +64,26 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
|||
is_batch = mode == 2
|
||||
|
||||
if is_inpaint:
|
||||
# Drawn mask
|
||||
if mask_mode == 0:
|
||||
image = init_img_with_mask['image']
|
||||
mask = init_img_with_mask['mask']
|
||||
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
||||
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
|
||||
image = image.convert('RGB')
|
||||
# Uploaded mask
|
||||
else:
|
||||
image = init_img_inpaint
|
||||
mask = init_mask_inpaint
|
||||
# No mask
|
||||
else:
|
||||
image = init_img
|
||||
mask = None
|
||||
|
||||
# Use the EXIF orientation of photos taken by smartphones.
|
||||
if image is not None:
|
||||
image = ImageOps.exif_transpose(image)
|
||||
|
||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
|
||||
p = StableDiffusionProcessingImg2Img(
|
||||
|
|
@ -128,6 +138,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
|||
if processed is None:
|
||||
processed = process_images(p)
|
||||
|
||||
p.close()
|
||||
|
||||
shared.total_tqdm.clear()
|
||||
|
||||
generation_info_js = processed.js()
|
||||
|
|
|
|||
|
|
@ -56,9 +56,9 @@ class InterrogateModels:
|
|||
import clip
|
||||
|
||||
if self.running_on_cpu:
|
||||
model, preprocess = clip.load(clip_model_name, device="cpu")
|
||||
model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
|
||||
else:
|
||||
model, preprocess = clip.load(clip_model_name)
|
||||
model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
|
||||
|
||||
model.eval()
|
||||
model = model.to(devices.device_interrogate)
|
||||
|
|
|
|||
|
|
@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||
# see below for register_forward_pre_hook;
|
||||
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
|
||||
# useless here, and we just replace those methods
|
||||
def first_stage_model_encode_wrap(self, encoder, x):
|
||||
send_me_to_gpu(self, None)
|
||||
return encoder(x)
|
||||
|
||||
def first_stage_model_decode_wrap(self, decoder, z):
|
||||
send_me_to_gpu(self, None)
|
||||
return decoder(z)
|
||||
first_stage_model = sd_model.first_stage_model
|
||||
first_stage_model_encode = sd_model.first_stage_model.encode
|
||||
first_stage_model_decode = sd_model.first_stage_model.decode
|
||||
|
||||
def first_stage_model_encode_wrap(x):
|
||||
send_me_to_gpu(first_stage_model, None)
|
||||
return first_stage_model_encode(x)
|
||||
|
||||
def first_stage_model_decode_wrap(z):
|
||||
send_me_to_gpu(first_stage_model, None)
|
||||
return first_stage_model_decode(z)
|
||||
|
||||
# remove three big modules, cond, first_stage, and unet from the model and then
|
||||
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
||||
|
|
@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||
# register hooks for those the first two models
|
||||
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
||||
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
||||
sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
|
||||
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
|
||||
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
||||
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
||||
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
||||
|
||||
if use_medvram:
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ def expand_crop_region(crop_region, processing_width, processing_height, image_w
|
|||
ratio_processing = processing_width / processing_height
|
||||
|
||||
if ratio_crop_region > ratio_processing:
|
||||
desired_height = (x2 - x1) * ratio_processing
|
||||
desired_height = (x2 - x1) / ratio_processing
|
||||
desired_height_diff = int(desired_height - (y2-y1))
|
||||
y1 -= desired_height_diff//2
|
||||
y2 += desired_height_diff - desired_height_diff//2
|
||||
|
|
|
|||
|
|
@ -85,6 +85,9 @@ def cleanup_models():
|
|||
src_path = os.path.join(root_path, "ESRGAN")
|
||||
dest_path = os.path.join(models_path, "ESRGAN")
|
||||
move_files(src_path, dest_path)
|
||||
src_path = os.path.join(models_path, "BSRGAN")
|
||||
dest_path = os.path.join(models_path, "ESRGAN")
|
||||
move_files(src_path, dest_path, ".pth")
|
||||
src_path = os.path.join(root_path, "gfpgan")
|
||||
dest_path = os.path.join(models_path, "GFPGAN")
|
||||
move_files(src_path, dest_path)
|
||||
|
|
|
|||
|
|
@ -77,9 +77,8 @@ def get_correct_sampler(p):
|
|||
class StableDiffusionProcessing():
|
||||
"""
|
||||
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
||||
|
||||
"""
|
||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0):
|
||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None):
|
||||
self.sd_model = sd_model
|
||||
self.outpath_samples: str = outpath_samples
|
||||
self.outpath_grids: str = outpath_grids
|
||||
|
|
@ -109,13 +108,14 @@ class StableDiffusionProcessing():
|
|||
self.do_not_reload_embeddings = do_not_reload_embeddings
|
||||
self.paste_to = None
|
||||
self.color_corrections = None
|
||||
self.denoising_strength: float = 0
|
||||
self.denoising_strength: float = denoising_strength
|
||||
self.sampler_noise_scheduler_override = None
|
||||
self.ddim_discretize = opts.ddim_discretize
|
||||
self.ddim_discretize = ddim_discretize or opts.ddim_discretize
|
||||
self.s_churn = s_churn or opts.s_churn
|
||||
self.s_tmin = s_tmin or opts.s_tmin
|
||||
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
|
||||
self.s_noise = s_noise or opts.s_noise
|
||||
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
|
||||
|
||||
if not seed_enable_extras:
|
||||
self.subseed = -1
|
||||
|
|
@ -129,13 +129,75 @@ class StableDiffusionProcessing():
|
|||
self.all_seeds = None
|
||||
self.all_subseeds = None
|
||||
|
||||
def txt2img_image_conditioning(self, x, width=None, height=None):
|
||||
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
|
||||
# Dummy zero conditioning if we're not using inpainting model.
|
||||
# Still takes up a bit of memory, but no encoder call.
|
||||
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
||||
return x.new_zeros(x.shape[0], 5, 1, 1)
|
||||
|
||||
height = height or self.height
|
||||
width = width or self.width
|
||||
|
||||
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
||||
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
||||
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
|
||||
|
||||
# Add the fake full 1s mask to the first dimension.
|
||||
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
||||
image_conditioning = image_conditioning.to(x.dtype)
|
||||
|
||||
return image_conditioning
|
||||
|
||||
def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
|
||||
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
|
||||
# Dummy zero conditioning if we're not using inpainting model.
|
||||
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
|
||||
|
||||
# Handle the different mask inputs
|
||||
if image_mask is not None:
|
||||
if torch.is_tensor(image_mask):
|
||||
conditioning_mask = image_mask
|
||||
else:
|
||||
conditioning_mask = np.array(image_mask.convert("L"))
|
||||
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
|
||||
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
|
||||
|
||||
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
|
||||
conditioning_mask = torch.round(conditioning_mask)
|
||||
else:
|
||||
conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
|
||||
|
||||
# Create another latent image, this time with a masked version of the original input.
|
||||
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
|
||||
conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
|
||||
conditioning_image = torch.lerp(
|
||||
source_image,
|
||||
source_image * (1.0 - conditioning_mask),
|
||||
getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
|
||||
)
|
||||
|
||||
# Encode the new masked image using first stage of network.
|
||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
|
||||
|
||||
# Create the concatenated conditioning tensor to be fed to `c_concat`
|
||||
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
|
||||
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
|
||||
image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
|
||||
image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
|
||||
|
||||
return image_conditioning
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
pass
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||
raise NotImplementedError()
|
||||
|
||||
def close(self):
|
||||
self.sd_model = None
|
||||
self.sampler = None
|
||||
|
||||
|
||||
class Processed:
|
||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
|
||||
|
|
@ -330,6 +392,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
|||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
||||
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
|
||||
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
|
||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||
|
|
@ -351,6 +414,22 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
|||
|
||||
|
||||
def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
|
||||
|
||||
try:
|
||||
for k, v in p.override_settings.items():
|
||||
opts.data[k] = v # we don't call onchange for simplicity which makes changing model, hypernet impossible
|
||||
|
||||
res = process_images_inner(p)
|
||||
|
||||
finally:
|
||||
for k, v in stored_opts.items():
|
||||
opts.data[k] = v
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||
|
||||
if type(p.prompt) == list:
|
||||
|
|
@ -396,7 +475,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.run_alwayson_scripts(p)
|
||||
p.scripts.process(p)
|
||||
|
||||
infotexts = []
|
||||
output_images = []
|
||||
|
|
@ -419,7 +498,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
|
||||
if (len(prompts) == 0):
|
||||
if len(prompts) == 0:
|
||||
break
|
||||
|
||||
with devices.autocast():
|
||||
|
|
@ -434,7 +513,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
|
||||
with devices.autocast():
|
||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
|
||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
|
||||
|
||||
samples_ddim = samples_ddim.to(devices.dtype_vae)
|
||||
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
|
||||
|
|
@ -508,7 +587,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||
|
||||
devices.torch_gc()
|
||||
return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||
|
||||
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.postprocess(p, res)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
|
@ -556,43 +641,41 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
|
||||
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
|
||||
|
||||
def create_dummy_mask(self, x, width=None, height=None):
|
||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||
height = height or self.height
|
||||
width = width or self.width
|
||||
|
||||
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
||||
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
||||
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
|
||||
|
||||
# Add the fake full 1s mask to the first dimension.
|
||||
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
||||
image_conditioning = image_conditioning.to(x.dtype)
|
||||
|
||||
else:
|
||||
# Dummy zero conditioning if we're not using inpainting model.
|
||||
# Still takes up a bit of memory, but no encoder call.
|
||||
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
||||
image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
||||
|
||||
return image_conditioning
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
||||
|
||||
if not self.enable_hr:
|
||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
||||
return samples
|
||||
|
||||
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
|
||||
|
||||
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
|
||||
|
||||
"""saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
|
||||
def save_intermediate(image, index):
|
||||
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
|
||||
return
|
||||
|
||||
if not isinstance(image, Image.Image):
|
||||
image = sd_samplers.sample_to_image(image, index)
|
||||
|
||||
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
|
||||
|
||||
if opts.use_scale_latent_for_hires_fix:
|
||||
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
||||
|
||||
# Avoid making the inpainting conditioning unless necessary as
|
||||
# this does need some extra compute to decode / encode the image again.
|
||||
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
|
||||
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
|
||||
else:
|
||||
image_conditioning = self.txt2img_image_conditioning(samples)
|
||||
|
||||
for i in range(samples.shape[0]):
|
||||
save_intermediate(samples, i)
|
||||
else:
|
||||
decoded_samples = decode_first_stage(self.sd_model, samples)
|
||||
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
|
@ -602,6 +685,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
image = Image.fromarray(x_sample)
|
||||
|
||||
save_intermediate(image, i)
|
||||
|
||||
image = images.resize_image(0, image, self.width, self.height)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = np.moveaxis(image, 2, 0)
|
||||
|
|
@ -613,6 +699,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
|
||||
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
|
||||
|
||||
image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
|
||||
|
||||
shared.state.nextjob()
|
||||
|
||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
||||
|
|
@ -623,7 +711,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
x = None
|
||||
devices.torch_gc()
|
||||
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
|
||||
|
||||
return samples
|
||||
|
||||
|
|
@ -755,36 +843,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||
elif self.inpainting_fill == 3:
|
||||
self.init_latent = self.init_latent * self.mask
|
||||
|
||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||
if self.image_mask is not None:
|
||||
conditioning_mask = np.array(self.image_mask.convert("L"))
|
||||
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
|
||||
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
|
||||
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
|
||||
|
||||
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
|
||||
conditioning_mask = torch.round(conditioning_mask)
|
||||
else:
|
||||
conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
|
||||
|
||||
# Create another latent image, this time with a masked version of the original input.
|
||||
conditioning_mask = conditioning_mask.to(image.device)
|
||||
conditioning_image = image * (1.0 - conditioning_mask)
|
||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
|
||||
|
||||
# Create the concatenated conditioning tensor to be fed to `c_concat`
|
||||
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
|
||||
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
|
||||
self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
|
||||
self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
|
||||
else:
|
||||
self.image_conditioning = torch.zeros(
|
||||
self.init_latent.shape[0], 5, 1, 1,
|
||||
dtype=self.init_latent.dtype,
|
||||
device=self.init_latent.device
|
||||
)
|
||||
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
|
||||
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
||||
|
|
@ -795,4 +856,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||
del x
|
||||
devices.torch_gc()
|
||||
|
||||
return samples
|
||||
return samples
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class RestrictedUnpickler(pickle.Unpickler):
|
|||
return getattr(collections, name)
|
||||
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
|
||||
return getattr(torch._utils, name)
|
||||
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
|
||||
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']:
|
||||
return getattr(torch, name)
|
||||
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
|
||||
return getattr(torch.nn.modules.container, name)
|
||||
|
|
|
|||
|
|
@ -2,23 +2,73 @@ import sys
|
|||
import traceback
|
||||
from collections import namedtuple
|
||||
import inspect
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI
|
||||
from gradio import Blocks
|
||||
|
||||
def report_exception(c, job):
|
||||
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
class ImageSaveParams:
|
||||
def __init__(self, image, p, filename, pnginfo):
|
||||
self.image = image
|
||||
"""the PIL image itself"""
|
||||
|
||||
self.p = p
|
||||
"""p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""
|
||||
|
||||
self.filename = filename
|
||||
"""name of file that the image would be saved to"""
|
||||
|
||||
self.pnginfo = pnginfo
|
||||
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
|
||||
|
||||
|
||||
class CFGDenoiserParams:
|
||||
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
|
||||
self.x = x
|
||||
"""Latent image representation in the process of being denoised"""
|
||||
|
||||
self.image_cond = image_cond
|
||||
"""Conditioning image"""
|
||||
|
||||
self.sigma = sigma
|
||||
"""Current sigma noise step value"""
|
||||
|
||||
self.sampling_step = sampling_step
|
||||
"""Current Sampling step number"""
|
||||
|
||||
self.total_sampling_steps = total_sampling_steps
|
||||
"""Total number of sampling steps planned"""
|
||||
|
||||
|
||||
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
|
||||
callbacks_app_started = []
|
||||
callbacks_model_loaded = []
|
||||
callbacks_ui_tabs = []
|
||||
callbacks_ui_settings = []
|
||||
callbacks_before_image_saved = []
|
||||
callbacks_image_saved = []
|
||||
callbacks_cfg_denoiser = []
|
||||
|
||||
|
||||
def clear_callbacks():
|
||||
callbacks_model_loaded.clear()
|
||||
callbacks_ui_tabs.clear()
|
||||
callbacks_ui_settings.clear()
|
||||
callbacks_before_image_saved.clear()
|
||||
callbacks_image_saved.clear()
|
||||
callbacks_cfg_denoiser.clear()
|
||||
|
||||
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
|
||||
for c in callbacks_app_started:
|
||||
try:
|
||||
c.callback(demo, app)
|
||||
except Exception:
|
||||
report_exception(c, 'app_started_callback')
|
||||
|
||||
|
||||
def model_loaded_callback(sd_model):
|
||||
|
|
@ -49,14 +99,30 @@ def ui_settings_callback():
|
|||
report_exception(c, 'ui_settings_callback')
|
||||
|
||||
|
||||
def image_saved_callback(image, p, fullfn, txt_fullfn):
|
||||
def before_image_saved_callback(params: ImageSaveParams):
|
||||
for c in callbacks_before_image_saved:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'before_image_saved_callback')
|
||||
|
||||
|
||||
def image_saved_callback(params: ImageSaveParams):
|
||||
for c in callbacks_image_saved:
|
||||
try:
|
||||
c.callback(image, p, fullfn, txt_fullfn)
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'image_saved_callback')
|
||||
|
||||
|
||||
def cfg_denoiser_callback(params: CFGDenoiserParams):
|
||||
for c in callbacks_cfg_denoiser:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'cfg_denoiser_callback')
|
||||
|
||||
|
||||
def add_callback(callbacks, fun):
|
||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
||||
|
|
@ -64,6 +130,11 @@ def add_callback(callbacks, fun):
|
|||
callbacks.append(ScriptCallback(filename, fun))
|
||||
|
||||
|
||||
def on_app_started(callback):
|
||||
"""register a function to be called when the webui started, the gradio `Block` component and
|
||||
fastapi `FastAPI` object are passed as the arguments"""
|
||||
add_callback(callbacks_app_started, callback)
|
||||
|
||||
|
||||
def on_model_loaded(callback):
|
||||
"""register a function to be called when the stable diffusion model is created; the model is
|
||||
|
|
@ -90,11 +161,26 @@ def on_ui_settings(callback):
|
|||
add_callback(callbacks_ui_settings, callback)
|
||||
|
||||
|
||||
def on_save_imaged(callback):
|
||||
"""register a function to be called after modules.images.save_image is called.
|
||||
The callback is called with three arguments:
|
||||
- p - procesing object (or a dummy object with same fields if the image is saved using save button)
|
||||
- fullfn - image filename
|
||||
- txt_fullfn - text file with parameters; may be None
|
||||
def on_before_image_saved(callback):
|
||||
"""register a function to be called before an image is saved to a file.
|
||||
The callback is called with one argument:
|
||||
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
|
||||
"""
|
||||
add_callback(callbacks_before_image_saved, callback)
|
||||
|
||||
|
||||
def on_image_saved(callback):
|
||||
"""register a function to be called after an image is saved to a file.
|
||||
The callback is called with one argument:
|
||||
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
|
||||
"""
|
||||
add_callback(callbacks_image_saved, callback)
|
||||
|
||||
|
||||
def on_cfg_denoiser(callback):
|
||||
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
|
||||
The callback is called with one argument:
|
||||
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
|
||||
"""
|
||||
add_callback(callbacks_cfg_denoiser, callback)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import modules.ui as ui
|
|||
import gradio as gr
|
||||
|
||||
from modules.processing import StableDiffusionProcessing
|
||||
from modules import shared, paths, script_callbacks
|
||||
from modules import shared, paths, script_callbacks, extensions
|
||||
|
||||
AlwaysVisible = object()
|
||||
|
||||
|
|
@ -18,6 +18,9 @@ class Script:
|
|||
args_to = None
|
||||
alwayson = False
|
||||
|
||||
"""A gr.Group component that has all script's UI inside it"""
|
||||
group = None
|
||||
|
||||
infotext_fields = None
|
||||
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
|
||||
parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
|
||||
|
|
@ -64,7 +67,16 @@ class Script:
|
|||
def process(self, p, *args):
|
||||
"""
|
||||
This function is called before processing begins for AlwaysVisible scripts.
|
||||
scripts. You can modify the processing object (p) here, inject hooks, etc.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
"""
|
||||
This function is called after processing ends for AlwaysVisible scripts.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
|
||||
pass
|
||||
|
|
@ -98,17 +110,8 @@ def list_scripts(scriptdirname, extension):
|
|||
for filename in sorted(os.listdir(basedir)):
|
||||
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
|
||||
|
||||
extdir = os.path.join(paths.script_path, "extensions")
|
||||
if os.path.exists(extdir):
|
||||
for dirname in sorted(os.listdir(extdir)):
|
||||
dirpath = os.path.join(extdir, dirname)
|
||||
scriptdirpath = os.path.join(dirpath, scriptdirname)
|
||||
|
||||
if not os.path.isdir(scriptdirpath):
|
||||
continue
|
||||
|
||||
for filename in sorted(os.listdir(scriptdirpath)):
|
||||
scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename)))
|
||||
for ext in extensions.active():
|
||||
scripts_list += ext.list_files(scriptdirname, extension)
|
||||
|
||||
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||
|
||||
|
|
@ -118,11 +121,7 @@ def list_scripts(scriptdirname, extension):
|
|||
def list_files_with_name(filename):
|
||||
res = []
|
||||
|
||||
dirs = [paths.script_path]
|
||||
|
||||
extdir = os.path.join(paths.script_path, "extensions")
|
||||
if os.path.exists(extdir):
|
||||
dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))]
|
||||
dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
|
||||
|
||||
for dirpath in dirs:
|
||||
if not os.path.isdir(dirpath):
|
||||
|
|
@ -222,8 +221,6 @@ class ScriptRunner:
|
|||
|
||||
for control in controls:
|
||||
control.custom_script_source = os.path.basename(script.filename)
|
||||
if not script.alwayson:
|
||||
control.visible = False
|
||||
|
||||
if script.infotext_fields is not None:
|
||||
self.infotext_fields += script.infotext_fields
|
||||
|
|
@ -233,40 +230,41 @@ class ScriptRunner:
|
|||
script.args_to = len(inputs)
|
||||
|
||||
for script in self.alwayson_scripts:
|
||||
with gr.Group():
|
||||
with gr.Group() as group:
|
||||
create_script_ui(script, inputs, inputs_alwayson)
|
||||
|
||||
dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
|
||||
script.group = group
|
||||
|
||||
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
|
||||
dropdown.save_to_config = True
|
||||
inputs[0] = dropdown
|
||||
|
||||
for script in self.selectable_scripts:
|
||||
create_script_ui(script, inputs, inputs_alwayson)
|
||||
with gr.Group(visible=False) as group:
|
||||
create_script_ui(script, inputs, inputs_alwayson)
|
||||
|
||||
script.group = group
|
||||
|
||||
def select_script(script_index):
|
||||
if 0 < script_index <= len(self.selectable_scripts):
|
||||
script = self.selectable_scripts[script_index-1]
|
||||
args_from = script.args_from
|
||||
args_to = script.args_to
|
||||
else:
|
||||
args_from = 0
|
||||
args_to = 0
|
||||
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
|
||||
|
||||
return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)]
|
||||
return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
|
||||
|
||||
def init_field(title):
|
||||
"""called when an initial value is set from ui-config.json to show script's UI components"""
|
||||
|
||||
if title == 'None':
|
||||
return
|
||||
|
||||
script_index = self.titles.index(title)
|
||||
script = self.selectable_scripts[script_index]
|
||||
for i in range(script.args_from, script.args_to):
|
||||
inputs[i].visible = True
|
||||
self.selectable_scripts[script_index].group.visible = True
|
||||
|
||||
dropdown.init_field = init_field
|
||||
|
||||
dropdown.change(
|
||||
fn=select_script,
|
||||
inputs=[dropdown],
|
||||
outputs=inputs
|
||||
outputs=[script.group for script in self.selectable_scripts]
|
||||
)
|
||||
|
||||
return inputs
|
||||
|
|
@ -289,13 +287,22 @@ class ScriptRunner:
|
|||
|
||||
return processed
|
||||
|
||||
def run_alwayson_scripts(self, p):
|
||||
def process(self, p):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.process(p, *script_args)
|
||||
except Exception:
|
||||
print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
|
||||
print(f"Error running process: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def postprocess(self, p, processed):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.postprocess(p, processed, *script_args)
|
||||
except Exception:
|
||||
print(f"Error running postprocess: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def reload_sources(self, cache):
|
||||
|
|
|
|||
|
|
@ -94,6 +94,10 @@ class StableDiffusionModelHijack:
|
|||
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
|
||||
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
|
||||
|
||||
self.layers = None
|
||||
self.circular_enabled = False
|
||||
self.clip = None
|
||||
|
||||
def apply_circular(self, enable):
|
||||
if self.circular_enabled == enable:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -1,13 +1,15 @@
|
|||
import collections
|
||||
import os.path
|
||||
import sys
|
||||
import gc
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
import re
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import shared, modelloader, devices, script_callbacks
|
||||
from modules import shared, modelloader, devices, script_callbacks, sd_vae
|
||||
from modules.paths import models_path
|
||||
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
|
||||
|
||||
|
|
@ -35,8 +37,10 @@ def setup_model():
|
|||
list_models()
|
||||
|
||||
|
||||
def checkpoint_tiles():
|
||||
return sorted([x.title for x in checkpoints_list.values()])
|
||||
def checkpoint_tiles():
|
||||
convert = lambda name: int(name) if name.isdigit() else name.lower()
|
||||
alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
|
||||
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
|
||||
|
||||
|
||||
def list_models():
|
||||
|
|
@ -155,14 +159,15 @@ def get_state_dict_from_checkpoint(pl_sd):
|
|||
return pl_sd
|
||||
|
||||
|
||||
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info):
|
||||
def load_model_weights(model, checkpoint_info, vae_file="auto"):
|
||||
checkpoint_file = checkpoint_info.filename
|
||||
sd_model_hash = checkpoint_info.hash
|
||||
|
||||
if checkpoint_info not in checkpoints_loaded:
|
||||
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
|
||||
|
||||
checkpoint_key = checkpoint_info
|
||||
|
||||
if checkpoint_key not in checkpoints_loaded:
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
|
||||
|
||||
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
|
||||
|
|
@ -170,42 +175,47 @@ def load_model_weights(model, checkpoint_info):
|
|||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
|
||||
sd = get_state_dict_from_checkpoint(pl_sd)
|
||||
missing, extra = model.load_state_dict(sd, strict=False)
|
||||
del pl_sd
|
||||
model.load_state_dict(sd, strict=False)
|
||||
del sd
|
||||
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
model.to(memory_format=torch.channels_last)
|
||||
|
||||
if not shared.cmd_opts.no_half:
|
||||
vae = model.first_stage_model
|
||||
|
||||
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
|
||||
if shared.cmd_opts.no_half_vae:
|
||||
model.first_stage_model = None
|
||||
|
||||
model.half()
|
||||
model.first_stage_model = vae
|
||||
|
||||
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
||||
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
|
||||
|
||||
vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
|
||||
|
||||
if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
|
||||
vae_file = shared.cmd_opts.vae_path
|
||||
|
||||
if os.path.exists(vae_file):
|
||||
print(f"Loading VAE weights from: {vae_file}")
|
||||
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
model.first_stage_model.load_state_dict(vae_dict)
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||
checkpoints_loaded.popitem(last=False) # LRU
|
||||
if shared.opts.sd_checkpoint_cache > 0:
|
||||
# if PR #4035 were to get merged, restore base VAE first before caching
|
||||
checkpoints_loaded[checkpoint_key] = model.state_dict().copy()
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||
checkpoints_loaded.popitem(last=False) # LRU
|
||||
|
||||
else:
|
||||
print(f"Loading weights [{sd_model_hash}] from cache")
|
||||
checkpoints_loaded.move_to_end(checkpoint_info)
|
||||
model.load_state_dict(checkpoints_loaded[checkpoint_info])
|
||||
vae_name = sd_vae.get_filename(vae_file) if vae_file else None
|
||||
vae_message = f" with {vae_name} VAE" if vae_name else ""
|
||||
print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
|
||||
checkpoints_loaded.move_to_end(checkpoint_key)
|
||||
model.load_state_dict(checkpoints_loaded[checkpoint_key])
|
||||
|
||||
model.sd_model_hash = sd_model_hash
|
||||
model.sd_model_checkpoint = checkpoint_file
|
||||
model.sd_checkpoint_info = checkpoint_info
|
||||
|
||||
sd_vae.load_vae(model, vae_file)
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None):
|
||||
from modules import lowvram, sd_hijack
|
||||
|
|
@ -214,6 +224,12 @@ def load_model(checkpoint_info=None):
|
|||
if checkpoint_info.config != shared.cmd_opts.config:
|
||||
print(f"Loading config from: {checkpoint_info.config}")
|
||||
|
||||
if shared.sd_model:
|
||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||
shared.sd_model = None
|
||||
gc.collect()
|
||||
devices.torch_gc()
|
||||
|
||||
sd_config = OmegaConf.load(checkpoint_info.config)
|
||||
|
||||
if should_hijack_inpainting(checkpoint_info):
|
||||
|
|
@ -227,6 +243,7 @@ def load_model(checkpoint_info=None):
|
|||
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
|
||||
|
||||
do_inpainting_hijack()
|
||||
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
|
||||
|
|
@ -246,14 +263,18 @@ def load_model(checkpoint_info=None):
|
|||
return sd_model
|
||||
|
||||
|
||||
def reload_model_weights(sd_model, info=None):
|
||||
def reload_model_weights(sd_model=None, info=None):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
checkpoint_info = info or select_checkpoint()
|
||||
|
||||
if not sd_model:
|
||||
sd_model = shared.sd_model
|
||||
|
||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
return
|
||||
|
||||
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
||||
del sd_model
|
||||
checkpoints_loaded.clear()
|
||||
load_model(checkpoint_info)
|
||||
return shared.sd_model
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from collections import namedtuple
|
||||
import numpy as np
|
||||
from math import floor
|
||||
import torch
|
||||
import tqdm
|
||||
from PIL import Image
|
||||
|
|
@ -11,6 +12,7 @@ from modules import prompt_parser, devices, processing, images
|
|||
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||
|
||||
|
||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
|
@ -91,8 +93,8 @@ def single_sample_to_image(sample):
|
|||
return Image.fromarray(x_sample)
|
||||
|
||||
|
||||
def sample_to_image(samples):
|
||||
return single_sample_to_image(samples[0])
|
||||
def sample_to_image(samples, index=0):
|
||||
return single_sample_to_image(samples[index])
|
||||
|
||||
|
||||
def samples_to_image_grid(samples):
|
||||
|
|
@ -205,17 +207,22 @@ class VanillaStableDiffusionSampler:
|
|||
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
|
||||
|
||||
def adjust_steps_if_invalid(self, p, num_steps):
|
||||
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
||||
valid_step = 999 / (1000 // num_steps)
|
||||
if valid_step == floor(valid_step):
|
||||
return int(valid_step) + 1
|
||||
|
||||
return num_steps
|
||||
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = setup_img2img_steps(p, steps)
|
||||
|
||||
steps = self.adjust_steps_if_invalid(p, steps)
|
||||
self.initialize(p)
|
||||
|
||||
# existing code fails with certain step counts, like 9
|
||||
try:
|
||||
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||
except Exception:
|
||||
self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||
|
||||
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
||||
|
||||
self.init_latent = x
|
||||
|
|
@ -239,18 +246,14 @@ class VanillaStableDiffusionSampler:
|
|||
self.last_latent = x
|
||||
self.step = 0
|
||||
|
||||
steps = steps or p.steps
|
||||
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
if image_conditioning is not None:
|
||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
# existing code fails with certain step counts, like 9
|
||||
try:
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
except Exception:
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
|
||||
return samples_ddim
|
||||
|
||||
|
|
@ -278,6 +281,12 @@ class CFGDenoiser(torch.nn.Module):
|
|||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||
|
||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
|
||||
cfg_denoiser_callback(denoiser_params)
|
||||
x_in = denoiser_params.x
|
||||
image_cond_in = denoiser_params.image_cond
|
||||
sigma_in = denoiser_params.sigma
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1]:
|
||||
cond_in = torch.cat([tensor, uncond])
|
||||
|
||||
|
|
|
|||
207
modules/sd_vae.py
Normal file
207
modules/sd_vae.py
Normal file
|
|
@ -0,0 +1,207 @@
|
|||
import torch
|
||||
import os
|
||||
from collections import namedtuple
|
||||
from modules import shared, devices, script_callbacks
|
||||
from modules.paths import models_path
|
||||
import glob
|
||||
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||
vae_dir = "VAE"
|
||||
vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
|
||||
|
||||
|
||||
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
|
||||
|
||||
|
||||
default_vae_dict = {"auto": "auto", "None": "None"}
|
||||
default_vae_list = ["auto", "None"]
|
||||
|
||||
|
||||
default_vae_values = [default_vae_dict[x] for x in default_vae_list]
|
||||
vae_dict = dict(default_vae_dict)
|
||||
vae_list = list(default_vae_list)
|
||||
first_load = True
|
||||
|
||||
|
||||
base_vae = None
|
||||
loaded_vae_file = None
|
||||
checkpoint_info = None
|
||||
|
||||
|
||||
def get_base_vae(model):
|
||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
||||
return base_vae
|
||||
return None
|
||||
|
||||
|
||||
def store_base_vae(model):
|
||||
global base_vae, checkpoint_info
|
||||
if checkpoint_info != model.sd_checkpoint_info:
|
||||
base_vae = model.first_stage_model.state_dict().copy()
|
||||
checkpoint_info = model.sd_checkpoint_info
|
||||
|
||||
|
||||
def delete_base_vae():
|
||||
global base_vae, checkpoint_info
|
||||
base_vae = None
|
||||
checkpoint_info = None
|
||||
|
||||
|
||||
def restore_base_vae(model):
|
||||
global base_vae, checkpoint_info
|
||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
|
||||
load_vae_dict(model, base_vae)
|
||||
delete_base_vae()
|
||||
|
||||
|
||||
def get_filename(filepath):
|
||||
return os.path.splitext(os.path.basename(filepath))[0]
|
||||
|
||||
|
||||
def refresh_vae_list(vae_path=vae_path, model_path=model_path):
|
||||
global vae_dict, vae_list
|
||||
res = {}
|
||||
candidates = [
|
||||
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
|
||||
*glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
|
||||
*glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
|
||||
*glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True)
|
||||
]
|
||||
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
|
||||
candidates.append(shared.cmd_opts.vae_path)
|
||||
for filepath in candidates:
|
||||
name = get_filename(filepath)
|
||||
res[name] = filepath
|
||||
vae_list.clear()
|
||||
vae_list.extend(default_vae_list)
|
||||
vae_list.extend(list(res.keys()))
|
||||
vae_dict.clear()
|
||||
vae_dict.update(res)
|
||||
vae_dict.update(default_vae_dict)
|
||||
return vae_list
|
||||
|
||||
|
||||
def resolve_vae(checkpoint_file, vae_file="auto"):
|
||||
global first_load, vae_dict, vae_list
|
||||
|
||||
# if vae_file argument is provided, it takes priority, but not saved
|
||||
if vae_file and vae_file not in default_vae_list:
|
||||
if not os.path.isfile(vae_file):
|
||||
vae_file = "auto"
|
||||
print("VAE provided as function argument doesn't exist")
|
||||
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
|
||||
if first_load and shared.cmd_opts.vae_path is not None:
|
||||
if os.path.isfile(shared.cmd_opts.vae_path):
|
||||
vae_file = shared.cmd_opts.vae_path
|
||||
shared.opts.data['sd_vae'] = get_filename(vae_file)
|
||||
else:
|
||||
print("VAE provided as command line argument doesn't exist")
|
||||
# else, we load from settings
|
||||
if vae_file == "auto" and shared.opts.sd_vae is not None:
|
||||
# if saved VAE settings isn't recognized, fallback to auto
|
||||
vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
|
||||
# if VAE selected but not found, fallback to auto
|
||||
if vae_file not in default_vae_values and not os.path.isfile(vae_file):
|
||||
vae_file = "auto"
|
||||
print("Selected VAE doesn't exist")
|
||||
# vae-path cmd arg takes priority for auto
|
||||
if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
|
||||
if os.path.isfile(shared.cmd_opts.vae_path):
|
||||
vae_file = shared.cmd_opts.vae_path
|
||||
print("Using VAE provided as command line argument")
|
||||
# if still not found, try look for ".vae.pt" beside model
|
||||
model_path = os.path.splitext(checkpoint_file)[0]
|
||||
if vae_file == "auto":
|
||||
vae_file_try = model_path + ".vae.pt"
|
||||
if os.path.isfile(vae_file_try):
|
||||
vae_file = vae_file_try
|
||||
print("Using VAE found beside selected model")
|
||||
# if still not found, try look for ".vae.ckpt" beside model
|
||||
if vae_file == "auto":
|
||||
vae_file_try = model_path + ".vae.ckpt"
|
||||
if os.path.isfile(vae_file_try):
|
||||
vae_file = vae_file_try
|
||||
print("Using VAE found beside selected model")
|
||||
# No more fallbacks for auto
|
||||
if vae_file == "auto":
|
||||
vae_file = None
|
||||
# Last check, just because
|
||||
if vae_file and not os.path.exists(vae_file):
|
||||
vae_file = None
|
||||
|
||||
return vae_file
|
||||
|
||||
|
||||
def load_vae(model, vae_file=None):
|
||||
global first_load, vae_dict, vae_list, loaded_vae_file
|
||||
# save_settings = False
|
||||
|
||||
if vae_file:
|
||||
print(f"Loading VAE weights from: {vae_file}")
|
||||
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
||||
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
load_vae_dict(model, vae_dict_1)
|
||||
|
||||
# If vae used is not in dict, update it
|
||||
# It will be removed on refresh though
|
||||
vae_opt = get_filename(vae_file)
|
||||
if vae_opt not in vae_dict:
|
||||
vae_dict[vae_opt] = vae_file
|
||||
vae_list.append(vae_opt)
|
||||
|
||||
loaded_vae_file = vae_file
|
||||
|
||||
"""
|
||||
# Save current VAE to VAE settings, maybe? will it work?
|
||||
if save_settings:
|
||||
if vae_file is None:
|
||||
vae_opt = "None"
|
||||
|
||||
# shared.opts.sd_vae = vae_opt
|
||||
"""
|
||||
|
||||
first_load = False
|
||||
|
||||
|
||||
# don't call this from outside
|
||||
def load_vae_dict(model, vae_dict_1=None):
|
||||
if vae_dict_1:
|
||||
store_base_vae(model)
|
||||
model.first_stage_model.load_state_dict(vae_dict_1)
|
||||
else:
|
||||
restore_base_vae()
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
||||
|
||||
def reload_vae_weights(sd_model=None, vae_file="auto"):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
|
||||
if not sd_model:
|
||||
sd_model = shared.sd_model
|
||||
|
||||
checkpoint_info = sd_model.sd_checkpoint_info
|
||||
checkpoint_file = checkpoint_info.filename
|
||||
vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
|
||||
|
||||
if loaded_vae_file == vae_file:
|
||||
return
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
sd_model.to(devices.cpu)
|
||||
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
load_vae(sd_model, vae_file)
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
sd_model.to(devices.device)
|
||||
|
||||
print(f"VAE Weights loaded.")
|
||||
return sd_model
|
||||
|
|
@ -4,6 +4,7 @@ import json
|
|||
import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
import time
|
||||
|
||||
import gradio as gr
|
||||
import tqdm
|
||||
|
|
@ -14,7 +15,7 @@ import modules.memmon
|
|||
import modules.sd_models
|
||||
import modules.styles
|
||||
import modules.devices as devices
|
||||
from modules import sd_samplers, sd_models, localization
|
||||
from modules import sd_samplers, sd_models, localization, sd_vae
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules.paths import models_path, script_path, sd_path
|
||||
|
||||
|
|
@ -40,7 +41,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion
|
|||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
|
||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
||||
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
||||
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
|
||||
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
||||
|
|
@ -51,6 +52,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director
|
|||
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
|
||||
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
|
||||
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
|
||||
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
|
||||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
|
||||
|
|
@ -82,9 +84,10 @@ parser.add_argument("--api", action='store_true', help="use api=True to launch t
|
|||
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
|
||||
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
|
||||
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
|
||||
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
|
||||
|
||||
cmd_opts = parser.parse_args()
|
||||
restricted_opts = [
|
||||
restricted_opts = {
|
||||
"samples_filename_pattern",
|
||||
"directories_filename_pattern",
|
||||
"outdir_samples",
|
||||
|
|
@ -94,7 +97,9 @@ restricted_opts = [
|
|||
"outdir_grids",
|
||||
"outdir_txt2img_grids",
|
||||
"outdir_save",
|
||||
]
|
||||
}
|
||||
|
||||
cmd_opts.disable_extension_access = cmd_opts.share or cmd_opts.listen
|
||||
|
||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
|
||||
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
|
||||
|
|
@ -131,6 +136,8 @@ class State:
|
|||
current_image = None
|
||||
current_image_sampling_step = 0
|
||||
textinfo = None
|
||||
time_start = None
|
||||
need_restart = False
|
||||
|
||||
def skip(self):
|
||||
self.skipped = True
|
||||
|
|
@ -143,8 +150,52 @@ class State:
|
|||
self.sampling_step = 0
|
||||
self.current_image_sampling_step = 0
|
||||
|
||||
def get_job_timestamp(self):
|
||||
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
|
||||
def dict(self):
|
||||
obj = {
|
||||
"skipped": self.skipped,
|
||||
"interrupted": self.skipped,
|
||||
"job": self.job,
|
||||
"job_count": self.job_count,
|
||||
"job_no": self.job_no,
|
||||
"sampling_step": self.sampling_step,
|
||||
"sampling_steps": self.sampling_steps,
|
||||
}
|
||||
|
||||
return obj
|
||||
|
||||
def begin(self):
|
||||
self.sampling_step = 0
|
||||
self.job_count = -1
|
||||
self.job_no = 0
|
||||
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
self.current_latent = None
|
||||
self.current_image = None
|
||||
self.current_image_sampling_step = 0
|
||||
self.skipped = False
|
||||
self.interrupted = False
|
||||
self.textinfo = None
|
||||
self.time_start = time.time()
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
def end(self):
|
||||
self.job = ""
|
||||
self.job_count = 0
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
|
||||
def set_current_image(self):
|
||||
if not parallel_processing_allowed:
|
||||
return
|
||||
|
||||
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and self.current_latent is not None:
|
||||
if opts.show_progress_grid:
|
||||
self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
|
||||
else:
|
||||
self.current_image = sd_samplers.sample_to_image(self.current_latent)
|
||||
|
||||
self.current_image_sampling_step = self.sampling_step
|
||||
|
||||
|
||||
state = State()
|
||||
|
|
@ -204,6 +255,8 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
|||
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
|
||||
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
|
||||
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
||||
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
|
||||
|
||||
|
|
@ -255,20 +308,22 @@ options_templates.update(options_section(('system', "System"), {
|
|||
}))
|
||||
|
||||
options_templates.update(options_section(('training', "Training"), {
|
||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
|
||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
||||
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list),
|
||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
||||
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||
|
|
@ -303,6 +358,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
|
||||
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||
"font": OptionInfo("", "Font for image grids that have text"),
|
||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||
|
|
@ -322,6 +378,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section((None, "Hidden options"), {
|
||||
"disabled_extensions": OptionInfo([], "Disable those extensions"),
|
||||
}))
|
||||
|
||||
options_templates.update()
|
||||
|
||||
|
||||
class Options:
|
||||
data = None
|
||||
|
|
@ -333,8 +395,9 @@ class Options:
|
|||
|
||||
def __setattr__(self, key, value):
|
||||
if self.data is not None:
|
||||
if key in self.data:
|
||||
if key in self.data or key in self.data_labels:
|
||||
self.data[key] = value
|
||||
return
|
||||
|
||||
return super(Options, self).__setattr__(key, value)
|
||||
|
||||
|
|
@ -375,11 +438,12 @@ class Options:
|
|||
if bad_settings > 0:
|
||||
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
|
||||
|
||||
def onchange(self, key, func):
|
||||
def onchange(self, key, func, call=True):
|
||||
item = self.data_labels.get(key)
|
||||
item.onchange = func
|
||||
|
||||
func()
|
||||
if call:
|
||||
func()
|
||||
|
||||
def dumpjson(self):
|
||||
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
|
||||
|
|
@ -449,3 +513,8 @@ total_tqdm = TotalTQDM()
|
|||
|
||||
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
|
||||
mem_mon.start()
|
||||
|
||||
|
||||
def listfiles(dirname):
|
||||
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
|
||||
return [file for file in filenames if os.path.isfile(file)]
|
||||
|
|
|
|||
|
|
@ -42,6 +42,8 @@ class PersonalizedBase(Dataset):
|
|||
self.lines = lines
|
||||
|
||||
assert data_root, 'dataset directory not specified'
|
||||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||
assert os.listdir(data_root), "Dataset directory is empty"
|
||||
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
|
||||
|
|
@ -86,12 +88,12 @@ class PersonalizedBase(Dataset):
|
|||
assert len(self.dataset) > 0, "No images have been found in the dataset."
|
||||
self.length = len(self.dataset) * repeats // batch_size
|
||||
|
||||
self.initial_indexes = np.arange(len(self.dataset))
|
||||
self.dataset_length = len(self.dataset)
|
||||
self.indexes = None
|
||||
self.shuffle()
|
||||
|
||||
def shuffle(self):
|
||||
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
|
||||
self.indexes = np.random.permutation(self.dataset_length)
|
||||
|
||||
def create_text(self, filename_text):
|
||||
text = random.choice(self.lines)
|
||||
|
|
|
|||
|
|
@ -4,30 +4,37 @@ import tqdm
|
|||
class LearnScheduleIterator:
|
||||
def __init__(self, learn_rate, max_steps, cur_step=0):
|
||||
"""
|
||||
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
|
||||
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
|
||||
"""
|
||||
|
||||
pairs = learn_rate.split(',')
|
||||
self.rates = []
|
||||
self.it = 0
|
||||
self.maxit = 0
|
||||
for i, pair in enumerate(pairs):
|
||||
tmp = pair.split(':')
|
||||
if len(tmp) == 2:
|
||||
step = int(tmp[1])
|
||||
if step > cur_step:
|
||||
self.rates.append((float(tmp[0]), min(step, max_steps)))
|
||||
self.maxit += 1
|
||||
if step > max_steps:
|
||||
try:
|
||||
for i, pair in enumerate(pairs):
|
||||
if not pair.strip():
|
||||
continue
|
||||
tmp = pair.split(':')
|
||||
if len(tmp) == 2:
|
||||
step = int(tmp[1])
|
||||
if step > cur_step:
|
||||
self.rates.append((float(tmp[0]), min(step, max_steps)))
|
||||
self.maxit += 1
|
||||
if step > max_steps:
|
||||
return
|
||||
elif step == -1:
|
||||
self.rates.append((float(tmp[0]), max_steps))
|
||||
self.maxit += 1
|
||||
return
|
||||
elif step == -1:
|
||||
else:
|
||||
self.rates.append((float(tmp[0]), max_steps))
|
||||
self.maxit += 1
|
||||
return
|
||||
else:
|
||||
self.rates.append((float(tmp[0]), max_steps))
|
||||
self.maxit += 1
|
||||
return
|
||||
assert self.rates
|
||||
except (ValueError, AssertionError):
|
||||
raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
|
@ -52,7 +59,7 @@ class LearnRateScheduler:
|
|||
self.finished = False
|
||||
|
||||
def apply(self, optimizer, step_number):
|
||||
if step_number <= self.end_step:
|
||||
if step_number < self.end_step:
|
||||
return
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import csv
|
|||
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models, images
|
||||
import modules.textual_inversion.dataset
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
|
|
@ -119,7 +119,7 @@ class EmbeddingDatabase:
|
|||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||
embedding = Embedding(vec, name)
|
||||
embedding.step = data.get('step', None)
|
||||
embedding.sd_checkpoint = data.get('hash', None)
|
||||
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
||||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||
self.register_embedding(embedding, shared.sd_model)
|
||||
|
||||
|
|
@ -167,6 +167,8 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
|||
for i in range(num_vectors_per_token):
|
||||
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
||||
|
||||
# Remove illegal characters from name.
|
||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
|
||||
if not overwrite_old:
|
||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||
|
|
@ -182,9 +184,8 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
|||
if shared.opts.training_write_csv_every == 0:
|
||||
return
|
||||
|
||||
if step % shared.opts.training_write_csv_every != 0:
|
||||
if (step + 1) % shared.opts.training_write_csv_every != 0:
|
||||
return
|
||||
|
||||
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
|
||||
|
||||
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
|
||||
|
|
@ -194,18 +195,39 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
|||
csv_writer.writeheader()
|
||||
|
||||
epoch = step // epoch_len
|
||||
epoch_step = step - epoch * epoch_len
|
||||
epoch_step = step % epoch_len
|
||||
|
||||
csv_writer.writerow({
|
||||
"step": step + 1,
|
||||
"epoch": epoch + 1,
|
||||
"epoch": epoch,
|
||||
"epoch_step": epoch_step + 1,
|
||||
**values,
|
||||
})
|
||||
|
||||
def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
||||
assert model_name, f"{name} not selected"
|
||||
assert learn_rate, "Learning rate is empty or 0"
|
||||
assert isinstance(batch_size, int), "Batch size must be integer"
|
||||
assert batch_size > 0, "Batch size must be positive"
|
||||
assert data_root, "Dataset directory is empty"
|
||||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||
assert os.listdir(data_root), "Dataset directory is empty"
|
||||
assert template_file, "Prompt template file is empty"
|
||||
assert os.path.isfile(template_file), "Prompt template file doesn't exist"
|
||||
assert steps, "Max steps is empty or 0"
|
||||
assert isinstance(steps, int), "Max steps must be integer"
|
||||
assert steps > 0 , "Max steps must be positive"
|
||||
assert isinstance(save_model_every, int), "Save {name} must be integer"
|
||||
assert save_model_every >= 0 , "Save {name} must be positive or 0"
|
||||
assert isinstance(create_image_every, int), "Create image must be integer"
|
||||
assert create_image_every >= 0 , "Create image must be positive or 0"
|
||||
if save_model_every or create_image_every:
|
||||
assert log_directory, "Log directory is empty"
|
||||
|
||||
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
assert embedding_name, 'embedding not selected'
|
||||
save_embedding_every = save_embedding_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
|
||||
|
||||
shared.state.textinfo = "Initializing textual inversion training..."
|
||||
shared.state.job_count = steps
|
||||
|
|
@ -213,6 +235,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||
|
||||
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
|
||||
unload = shared.opts.unload_models_when_training
|
||||
|
||||
if save_embedding_every > 0:
|
||||
embedding_dir = os.path.join(log_directory, "embeddings")
|
||||
|
|
@ -231,31 +254,38 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||
os.makedirs(images_embeds_dir, exist_ok=True)
|
||||
else:
|
||||
images_embeds_dir = None
|
||||
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
|
||||
hijack = sd_hijack.model_hijack
|
||||
|
||||
embedding = hijack.embedding_db.word_embeddings[embedding_name]
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
ititial_step = embedding.step or 0
|
||||
if ititial_step >= steps:
|
||||
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
|
||||
return embedding, filename
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
|
||||
if unload:
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
embedding.vec.requires_grad = True
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
||||
|
||||
losses = torch.zeros((32,))
|
||||
|
||||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
embedding_yet_to_be_embedded = False
|
||||
|
||||
ititial_step = embedding.step or 0
|
||||
if ititial_step > steps:
|
||||
return embedding, filename
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||
for i, entries in pbar:
|
||||
embedding.step = i + ititial_step
|
||||
|
|
@ -279,15 +309,18 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
steps_done = embedding.step + 1
|
||||
|
||||
epoch_num = embedding.step // len(ds)
|
||||
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
|
||||
epoch_step = embedding.step % len(ds)
|
||||
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
|
||||
|
||||
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
||||
embedding.save(last_saved_file)
|
||||
if embedding_dir is not None and steps_done % save_embedding_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
embedding_name_every = f'{embedding_name}-{steps_done}'
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
||||
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
||||
embedding_yet_to_be_embedded = True
|
||||
|
||||
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
|
||||
|
|
@ -295,8 +328,11 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||
"learn_rate": scheduler.learn_rate
|
||||
})
|
||||
|
||||
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
||||
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
||||
if images_dir is not None and steps_done % create_image_every == 0:
|
||||
forced_filename = f'{embedding_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
|
|
@ -325,11 +361,14 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||
processed = processing.process_images(p)
|
||||
image = processed.images[0]
|
||||
|
||||
if unload:
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
shared.state.current_image = image
|
||||
|
||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
||||
|
||||
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
|
||||
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
|
||||
|
||||
info = PngImagePlugin.PngInfo()
|
||||
data = torch.load(last_saved_file)
|
||||
|
|
@ -345,7 +384,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||
checkpoint = sd_models.select_checkpoint()
|
||||
footer_left = checkpoint.model_name
|
||||
footer_mid = '[{}]'.format(checkpoint.hash)
|
||||
footer_right = '{}v {}s'.format(vectorSize, embedding.step)
|
||||
footer_right = '{}v {}s'.format(vectorSize, steps_done)
|
||||
|
||||
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
||||
captioned_image = insert_image_data_embed(captioned_image, data)
|
||||
|
|
@ -353,8 +392,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
||||
embedding_yet_to_be_embedded = False
|
||||
|
||||
image.save(last_saved_image)
|
||||
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
shared.state.job_no = embedding.step
|
||||
|
|
@ -369,11 +407,27 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
</p>
|
||||
"""
|
||||
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
embedding.sd_checkpoint = checkpoint.hash
|
||||
embedding.sd_checkpoint_name = checkpoint.model_name
|
||||
embedding.cached_checksum = None
|
||||
embedding.save(filename)
|
||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
|
||||
return embedding, filename
|
||||
|
||||
def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
|
||||
old_embedding_name = embedding.name
|
||||
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
|
||||
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
|
||||
old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
|
||||
try:
|
||||
embedding.sd_checkpoint = checkpoint.hash
|
||||
embedding.sd_checkpoint_name = checkpoint.model_name
|
||||
if remove_cached_checksum:
|
||||
embedding.cached_checksum = None
|
||||
embedding.name = embedding_name
|
||||
embedding.save(filename)
|
||||
except:
|
||||
embedding.sd_checkpoint = old_sd_checkpoint
|
||||
embedding.sd_checkpoint_name = old_sd_checkpoint_name
|
||||
embedding.name = old_embedding_name
|
||||
embedding.cached_checksum = old_cached_checksum
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -25,8 +25,10 @@ def train_embedding(*args):
|
|||
|
||||
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
|
||||
|
||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||
try:
|
||||
sd_hijack.undo_optimizations()
|
||||
if not apply_optimizations:
|
||||
sd_hijack.undo_optimizations()
|
||||
|
||||
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
|
||||
|
||||
|
|
@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)}
|
|||
except Exception:
|
||||
raise
|
||||
finally:
|
||||
sd_hijack.apply_optimizations()
|
||||
if not apply_optimizations:
|
||||
sd_hijack.apply_optimizations()
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
|
|||
if processed is None:
|
||||
processed = process_images(p)
|
||||
|
||||
p.close()
|
||||
|
||||
shared.total_tqdm.clear()
|
||||
|
||||
generation_info_js = processed.js()
|
||||
|
|
|
|||
416
modules/ui.py
416
modules/ui.py
|
|
@ -1,6 +1,4 @@
|
|||
import base64
|
||||
import html
|
||||
import io
|
||||
import json
|
||||
import math
|
||||
import mimetypes
|
||||
|
|
@ -18,15 +16,10 @@ import gradio as gr
|
|||
import gradio.routes
|
||||
import gradio.utils
|
||||
import numpy as np
|
||||
import piexif
|
||||
import torch
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
import gradio as gr
|
||||
import gradio.utils
|
||||
import gradio.routes
|
||||
|
||||
from modules import sd_hijack, sd_models, localization, script_callbacks
|
||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions
|
||||
from modules.paths import script_path
|
||||
|
||||
from modules.shared import opts, cmd_opts, restricted_opts
|
||||
|
|
@ -35,7 +28,7 @@ if cmd_opts.deepdanbooru:
|
|||
from modules.deepbooru import get_deepbooru_tags
|
||||
|
||||
import modules.codeformer_model
|
||||
import modules.generation_parameters_copypaste
|
||||
import modules.generation_parameters_copypaste as parameters_copypaste
|
||||
import modules.gfpgan_model
|
||||
import modules.hypernetworks.ui
|
||||
import modules.ldsr_model
|
||||
|
|
@ -49,13 +42,11 @@ from modules.sd_hijack import model_hijack
|
|||
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||
import modules.textual_inversion.ui
|
||||
import modules.hypernetworks.ui
|
||||
from modules.generation_parameters_copypaste import image_from_url_text
|
||||
|
||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
||||
mimetypes.init()
|
||||
mimetypes.add_type('application/javascript', '.js')
|
||||
txt2img_paste_fields = []
|
||||
img2img_paste_fields = []
|
||||
|
||||
|
||||
if not cmd_opts.share and not cmd_opts.listen:
|
||||
# fix gradio phoning home
|
||||
|
|
@ -98,37 +89,11 @@ def plaintext_to_html(text):
|
|||
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
|
||||
return text
|
||||
|
||||
|
||||
def image_from_url_text(filedata):
|
||||
if type(filedata) == dict and filedata["is_file"]:
|
||||
filename = filedata["name"]
|
||||
tempdir = os.path.normpath(tempfile.gettempdir())
|
||||
normfn = os.path.normpath(filename)
|
||||
assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
|
||||
|
||||
return Image.open(filename)
|
||||
|
||||
if type(filedata) == list:
|
||||
if len(filedata) == 0:
|
||||
return None
|
||||
|
||||
filedata = filedata[0]
|
||||
|
||||
if filedata.startswith("data:image/png;base64,"):
|
||||
filedata = filedata[len("data:image/png;base64,"):]
|
||||
|
||||
filedata = base64.decodebytes(filedata.encode('utf-8'))
|
||||
image = Image.open(io.BytesIO(filedata))
|
||||
return image
|
||||
|
||||
|
||||
def send_gradio_gallery_to_image(x):
|
||||
if len(x) == 0:
|
||||
return None
|
||||
|
||||
return image_from_url_text(x[0])
|
||||
|
||||
|
||||
def save_files(js_data, images, do_make_zip, index):
|
||||
import csv
|
||||
filenames = []
|
||||
|
|
@ -192,7 +157,6 @@ def save_files(js_data, images, do_make_zip, index):
|
|||
|
||||
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
||||
|
||||
|
||||
def save_pil_to_file(pil_image, dir=None):
|
||||
use_metadata = False
|
||||
metadata = PngImagePlugin.PngInfo()
|
||||
|
|
@ -313,15 +277,7 @@ def check_progress_call(id_part):
|
|||
preview_visibility = gr_show(False)
|
||||
|
||||
if opts.show_progress_every_n_steps > 0:
|
||||
if shared.parallel_processing_allowed:
|
||||
|
||||
if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
|
||||
if opts.show_progress_grid:
|
||||
shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent)
|
||||
else:
|
||||
shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
|
||||
shared.state.current_image_sampling_step = shared.state.sampling_step
|
||||
|
||||
shared.state.set_current_image()
|
||||
image = shared.state.current_image
|
||||
|
||||
if image is None:
|
||||
|
|
@ -626,10 +582,90 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
|
|||
return refresh_button
|
||||
|
||||
|
||||
def create_output_panel(tabname, outdir):
|
||||
def open_folder(f):
|
||||
if not os.path.exists(f):
|
||||
print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
|
||||
return
|
||||
elif not os.path.isdir(f):
|
||||
print(f"""
|
||||
WARNING
|
||||
An open_folder request was made with an argument that is not a folder.
|
||||
This could be an error or a malicious attempt to run code on your computer.
|
||||
Requested path was: {f}
|
||||
""", file=sys.stderr)
|
||||
return
|
||||
|
||||
if not shared.cmd_opts.hide_ui_dir_config:
|
||||
path = os.path.normpath(f)
|
||||
if platform.system() == "Windows":
|
||||
os.startfile(path)
|
||||
elif platform.system() == "Darwin":
|
||||
sp.Popen(["open", path])
|
||||
else:
|
||||
sp.Popen(["xdg-open", path])
|
||||
|
||||
with gr.Column(variant='panel'):
|
||||
with gr.Group():
|
||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
|
||||
|
||||
generation_info = None
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
if tabname != "extras":
|
||||
save = gr.Button('Save', elem_id=f'save_{tabname}')
|
||||
|
||||
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
|
||||
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
|
||||
open_folder_button = gr.Button(folder_symbol, elem_id=button_id)
|
||||
|
||||
open_folder_button.click(
|
||||
fn=lambda: open_folder(opts.outdir_samples or outdir),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
if tabname != "extras":
|
||||
with gr.Row():
|
||||
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
|
||||
|
||||
with gr.Row():
|
||||
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
|
||||
|
||||
with gr.Group():
|
||||
html_info = gr.HTML()
|
||||
generation_info = gr.Textbox(visible=False)
|
||||
|
||||
save.click(
|
||||
fn=wrap_gradio_call(save_files),
|
||||
_js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
|
||||
inputs=[
|
||||
generation_info,
|
||||
result_gallery,
|
||||
do_make_zip,
|
||||
html_info,
|
||||
],
|
||||
outputs=[
|
||||
download_files,
|
||||
html_info,
|
||||
html_info,
|
||||
html_info,
|
||||
]
|
||||
)
|
||||
else:
|
||||
html_info_x = gr.HTML()
|
||||
html_info = gr.HTML()
|
||||
parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
|
||||
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info
|
||||
|
||||
|
||||
def create_ui(wrap_gradio_gpu_call):
|
||||
import modules.img2img
|
||||
import modules.txt2img
|
||||
|
||||
reload_javascript()
|
||||
|
||||
parameters_copypaste.reset()
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
|
||||
|
|
@ -675,30 +711,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
with gr.Group():
|
||||
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
|
||||
|
||||
with gr.Column(variant='panel'):
|
||||
|
||||
with gr.Group():
|
||||
txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
|
||||
txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4)
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
save = gr.Button('Save')
|
||||
send_to_img2img = gr.Button('Send to img2img')
|
||||
send_to_inpaint = gr.Button('Send to inpaint')
|
||||
send_to_extras = gr.Button('Send to extras')
|
||||
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
|
||||
open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
|
||||
|
||||
with gr.Row():
|
||||
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
|
||||
|
||||
with gr.Row():
|
||||
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
|
||||
|
||||
with gr.Group():
|
||||
html_info = gr.HTML()
|
||||
generation_info = gr.Textbox(visible=False)
|
||||
txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples)
|
||||
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
|
||||
|
||||
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
|
||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
||||
|
|
@ -756,23 +770,6 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
outputs=[hr_options],
|
||||
)
|
||||
|
||||
save.click(
|
||||
fn=wrap_gradio_call(save_files),
|
||||
_js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
|
||||
inputs=[
|
||||
generation_info,
|
||||
txt2img_gallery,
|
||||
do_make_zip,
|
||||
html_info,
|
||||
],
|
||||
outputs=[
|
||||
download_files,
|
||||
html_info,
|
||||
html_info,
|
||||
html_info,
|
||||
]
|
||||
)
|
||||
|
||||
roll.click(
|
||||
fn=roll_artist,
|
||||
_js="update_txt2img_tokens",
|
||||
|
|
@ -784,7 +781,6 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
]
|
||||
)
|
||||
|
||||
global txt2img_paste_fields
|
||||
txt2img_paste_fields = [
|
||||
(txt2img_prompt, "Prompt"),
|
||||
(txt2img_negative_prompt, "Negative prompt"),
|
||||
|
|
@ -807,6 +803,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
(firstphase_height, "First pass size-2"),
|
||||
*modules.scripts.scripts_txt2img.infotext_fields
|
||||
]
|
||||
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
|
||||
|
||||
txt2img_preview_params = [
|
||||
txt2img_prompt,
|
||||
|
|
@ -893,30 +890,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
with gr.Group():
|
||||
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
|
||||
|
||||
with gr.Column(variant='panel'):
|
||||
|
||||
with gr.Group():
|
||||
img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
|
||||
img2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='img2img_gallery').style(grid=4)
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
save = gr.Button('Save')
|
||||
img2img_send_to_img2img = gr.Button('Send to img2img')
|
||||
img2img_send_to_inpaint = gr.Button('Send to inpaint')
|
||||
img2img_send_to_extras = gr.Button('Send to extras')
|
||||
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
|
||||
open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
|
||||
|
||||
with gr.Row():
|
||||
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
|
||||
|
||||
with gr.Row():
|
||||
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
|
||||
|
||||
with gr.Group():
|
||||
html_info = gr.HTML()
|
||||
generation_info = gr.Textbox(visible=False)
|
||||
img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples)
|
||||
parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
|
||||
|
||||
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
|
||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
||||
|
|
@ -1003,25 +978,9 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
fn=interrogate_deepbooru,
|
||||
inputs=[init_img],
|
||||
outputs=[img2img_prompt],
|
||||
)
|
||||
|
||||
save.click(
|
||||
fn=wrap_gradio_call(save_files),
|
||||
_js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
|
||||
inputs=[
|
||||
generation_info,
|
||||
img2img_gallery,
|
||||
do_make_zip,
|
||||
html_info,
|
||||
],
|
||||
outputs=[
|
||||
download_files,
|
||||
html_info,
|
||||
html_info,
|
||||
html_info,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
roll.click(
|
||||
fn=roll_artist,
|
||||
_js="update_img2img_tokens",
|
||||
|
|
@ -1055,7 +1014,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
outputs=[prompt, negative_prompt, style1, style2],
|
||||
)
|
||||
|
||||
global img2img_paste_fields
|
||||
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
|
||||
|
||||
img2img_paste_fields = [
|
||||
(img2img_prompt, "Prompt"),
|
||||
(img2img_negative_prompt, "Negative prompt"),
|
||||
|
|
@ -1074,7 +1034,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
(denoising_strength, "Denoising strength"),
|
||||
*modules.scripts.scripts_img2img.infotext_fields
|
||||
]
|
||||
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
|
||||
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
|
||||
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as extras_interface:
|
||||
with gr.Row().style(equal_height=False):
|
||||
|
|
@ -1087,26 +1048,24 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
|
||||
|
||||
with gr.TabItem('Batch from Directory'):
|
||||
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs,
|
||||
placeholder="A directory on the same machine where the server is running."
|
||||
)
|
||||
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs,
|
||||
placeholder="Leave blank to save images to the default path."
|
||||
)
|
||||
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.")
|
||||
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.")
|
||||
show_extras_results = gr.Checkbox(label='Show result images', value=True)
|
||||
|
||||
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
|
||||
|
||||
with gr.Tabs(elem_id="extras_resize_mode"):
|
||||
with gr.TabItem('Scale by'):
|
||||
upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
|
||||
upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4)
|
||||
with gr.TabItem('Scale to'):
|
||||
with gr.Group():
|
||||
with gr.Row():
|
||||
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0)
|
||||
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0)
|
||||
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
|
||||
|
||||
|
||||
with gr.Group():
|
||||
extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
||||
extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
||||
|
||||
with gr.Group():
|
||||
extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
||||
|
|
@ -1119,17 +1078,10 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer)
|
||||
codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer)
|
||||
|
||||
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
|
||||
|
||||
with gr.Column(variant='panel'):
|
||||
result_images = gr.Gallery(label="Result", show_label=False)
|
||||
html_info_x = gr.HTML()
|
||||
html_info = gr.HTML()
|
||||
extras_send_to_img2img = gr.Button('Send to img2img')
|
||||
extras_send_to_inpaint = gr.Button('Send to inpaint')
|
||||
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
|
||||
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
|
||||
with gr.Group():
|
||||
upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False)
|
||||
|
||||
result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples)
|
||||
|
||||
submit.click(
|
||||
fn=wrap_gradio_gpu_call(modules.extras.run_extras),
|
||||
|
|
@ -1152,6 +1104,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
extras_upscaler_1,
|
||||
extras_upscaler_2,
|
||||
extras_upscaler_2_visibility,
|
||||
upscale_before_face_fix,
|
||||
],
|
||||
outputs=[
|
||||
result_images,
|
||||
|
|
@ -1159,19 +1112,11 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
html_info,
|
||||
]
|
||||
)
|
||||
parameters_copypaste.add_paste_fields("extras", extras_image, None)
|
||||
|
||||
extras_send_to_img2img.click(
|
||||
fn=lambda x: image_from_url_text(x),
|
||||
_js="extract_image_from_gallery_img2img",
|
||||
inputs=[result_images],
|
||||
outputs=[init_img],
|
||||
)
|
||||
|
||||
extras_send_to_inpaint.click(
|
||||
fn=lambda x: image_from_url_text(x),
|
||||
_js="extract_image_from_gallery_inpaint",
|
||||
inputs=[result_images],
|
||||
outputs=[init_img_with_mask],
|
||||
extras_image.change(
|
||||
fn=modules.extras.clear_cache,
|
||||
inputs=[], outputs=[]
|
||||
)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
|
||||
|
|
@ -1183,17 +1128,16 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
html = gr.HTML()
|
||||
generation_info = gr.Textbox(visible=False)
|
||||
html2 = gr.HTML()
|
||||
|
||||
with gr.Row():
|
||||
pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
|
||||
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
||||
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
|
||||
parameters_copypaste.bind_buttons(buttons, image, generation_info)
|
||||
|
||||
image.change(
|
||||
fn=wrap_gradio_call(modules.extras.run_pnginfo),
|
||||
inputs=[image],
|
||||
outputs=[html, generation_info, html2],
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks() as modelmerger_interface:
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Column(variant='panel'):
|
||||
|
|
@ -1238,8 +1182,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
new_hypernetwork_name = gr.Textbox(label="Name")
|
||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
||||
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
|
||||
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys)
|
||||
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Normal is default, for experiments, relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
|
||||
new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys)
|
||||
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
|
||||
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
||||
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
|
||||
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
|
||||
|
|
@ -1491,28 +1435,6 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
script_callbacks.ui_settings_callback()
|
||||
opts.reorder()
|
||||
|
||||
def open_folder(f):
|
||||
if not os.path.exists(f):
|
||||
print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
|
||||
return
|
||||
elif not os.path.isdir(f):
|
||||
print(f"""
|
||||
WARNING
|
||||
An open_folder request was made with an argument that is not a folder.
|
||||
This could be an error or a malicious attempt to run code on your computer.
|
||||
Requested path was: {f}
|
||||
""", file=sys.stderr)
|
||||
return
|
||||
|
||||
if not shared.cmd_opts.hide_ui_dir_config:
|
||||
path = os.path.normpath(f)
|
||||
if platform.system() == "Windows":
|
||||
os.startfile(path)
|
||||
elif platform.system() == "Darwin":
|
||||
sp.Popen(["open", path])
|
||||
else:
|
||||
sp.Popen(["xdg-open", path])
|
||||
|
||||
def run_settings(*args):
|
||||
changed = 0
|
||||
|
||||
|
|
@ -1584,8 +1506,9 @@ Requested path was: {f}
|
|||
column = None
|
||||
with gr.Row(elem_id="settings").style(equal_height=False):
|
||||
for i, (k, item) in enumerate(opts.data_labels.items()):
|
||||
section_must_be_skipped = item.section[0] is None
|
||||
|
||||
if previous_section != item.section:
|
||||
if previous_section != item.section and not section_must_be_skipped:
|
||||
if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None):
|
||||
if column is not None:
|
||||
column.__exit__()
|
||||
|
|
@ -1604,6 +1527,8 @@ Requested path was: {f}
|
|||
if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
|
||||
quicksettings_list.append((i, k, item))
|
||||
components.append(dummy_component)
|
||||
elif section_must_be_skipped:
|
||||
components.append(dummy_component)
|
||||
else:
|
||||
component = create_setting_component(k)
|
||||
component_dict[k] = component
|
||||
|
|
@ -1639,19 +1564,19 @@ Requested path was: {f}
|
|||
reload_script_bodies.click(
|
||||
fn=reload_scripts,
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
_js='function(){}'
|
||||
outputs=[]
|
||||
)
|
||||
|
||||
def request_restart():
|
||||
shared.state.interrupt()
|
||||
settings_interface.gradio_ref.do_restart = True
|
||||
shared.state.need_restart = True
|
||||
|
||||
restart_gradio.click(
|
||||
|
||||
fn=request_restart,
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
_js='function(){restart_reload()}'
|
||||
_js='restart_reload'
|
||||
)
|
||||
|
||||
if column is not None:
|
||||
|
|
@ -1666,10 +1591,6 @@ Requested path was: {f}
|
|||
(train_interface, "Train", "ti"),
|
||||
]
|
||||
|
||||
interfaces += script_callbacks.ui_tabs_callback()
|
||||
|
||||
interfaces += [(settings_interface, "Settings", "settings")]
|
||||
|
||||
css = ""
|
||||
|
||||
for cssfile in modules.scripts.list_files_with_name("style.css"):
|
||||
|
|
@ -1686,13 +1607,20 @@ Requested path was: {f}
|
|||
if not cmd_opts.no_progressbar_hiding:
|
||||
css += css_hide_progressbar
|
||||
|
||||
interfaces += script_callbacks.ui_tabs_callback()
|
||||
interfaces += [(settings_interface, "Settings", "settings")]
|
||||
|
||||
extensions_interface = ui_extensions.create_ui()
|
||||
interfaces += [(extensions_interface, "Extensions", "extensions")]
|
||||
|
||||
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
|
||||
with gr.Row(elem_id="quicksettings"):
|
||||
for i, k, item in quicksettings_list:
|
||||
component = create_setting_component(k, is_quicksettings=True)
|
||||
component_dict[k] = component
|
||||
|
||||
settings_interface.gradio_ref = demo
|
||||
parameters_copypaste.integrate_settings_paste_fields(component_dict)
|
||||
parameters_copypaste.run_bind()
|
||||
|
||||
with gr.Tabs(elem_id="tabs") as tabs:
|
||||
for interface, label, ifid in interfaces:
|
||||
|
|
@ -1747,85 +1675,6 @@ Requested path was: {f}
|
|||
component_dict['sd_model_checkpoint'],
|
||||
]
|
||||
)
|
||||
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2']
|
||||
txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names]
|
||||
img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names]
|
||||
send_to_img2img.click(
|
||||
fn=lambda img, *args: (image_from_url_text(img),*args),
|
||||
_js="(gallery, ...args) => [extract_image_from_gallery_img2img(gallery), ...args]",
|
||||
inputs=[txt2img_gallery] + txt2img_fields,
|
||||
outputs=[init_img] + img2img_fields,
|
||||
)
|
||||
|
||||
send_to_inpaint.click(
|
||||
fn=lambda x, *args: (image_from_url_text(x), *args),
|
||||
_js="(gallery, ...args) => [extract_image_from_gallery_inpaint(gallery), ...args]",
|
||||
inputs=[txt2img_gallery] + txt2img_fields,
|
||||
outputs=[init_img_with_mask] + img2img_fields,
|
||||
)
|
||||
|
||||
img2img_send_to_img2img.click(
|
||||
fn=lambda x: image_from_url_text(x),
|
||||
_js="extract_image_from_gallery_img2img",
|
||||
inputs=[img2img_gallery],
|
||||
outputs=[init_img],
|
||||
)
|
||||
|
||||
img2img_send_to_inpaint.click(
|
||||
fn=lambda x: image_from_url_text(x),
|
||||
_js="extract_image_from_gallery_inpaint",
|
||||
inputs=[img2img_gallery],
|
||||
outputs=[init_img_with_mask],
|
||||
)
|
||||
|
||||
send_to_extras.click(
|
||||
fn=lambda x: image_from_url_text(x),
|
||||
_js="extract_image_from_gallery_extras",
|
||||
inputs=[txt2img_gallery],
|
||||
outputs=[extras_image],
|
||||
)
|
||||
|
||||
open_txt2img_folder.click(
|
||||
fn=lambda: open_folder(opts.outdir_samples or opts.outdir_txt2img_samples),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
open_img2img_folder.click(
|
||||
fn=lambda: open_folder(opts.outdir_samples or opts.outdir_img2img_samples),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
open_extras_folder.click(
|
||||
fn=lambda: open_folder(opts.outdir_samples or opts.outdir_extras_samples),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
img2img_send_to_extras.click(
|
||||
fn=lambda x: image_from_url_text(x),
|
||||
_js="extract_image_from_gallery_extras",
|
||||
inputs=[img2img_gallery],
|
||||
outputs=[extras_image],
|
||||
)
|
||||
|
||||
settings_map = {
|
||||
'sd_hypernetwork': 'Hypernet',
|
||||
'CLIP_stop_at_last_layers': 'Clip skip',
|
||||
'sd_model_checkpoint': 'Model hash',
|
||||
}
|
||||
|
||||
settings_paste_fields = [
|
||||
(component_dict[k], lambda d, k=k, v=v: apply_setting(k, d.get(v, None)))
|
||||
for k, v in settings_map.items()
|
||||
]
|
||||
|
||||
modules.generation_parameters_copypaste.connect_paste(txt2img_paste, txt2img_paste_fields + settings_paste_fields, txt2img_prompt)
|
||||
modules.generation_parameters_copypaste.connect_paste(img2img_paste, img2img_paste_fields + settings_paste_fields, img2img_prompt)
|
||||
|
||||
modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_txt2img')
|
||||
modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_img2img_img2img')
|
||||
|
||||
ui_config_file = cmd_opts.ui_config_file
|
||||
ui_settings = {}
|
||||
|
|
@ -1845,7 +1694,7 @@ Requested path was: {f}
|
|||
def apply_field(obj, field, condition=None, init_field=None):
|
||||
key = path + "/" + field
|
||||
|
||||
if getattr(obj,'custom_script_source',None) is not None:
|
||||
if getattr(obj, 'custom_script_source', None) is not None:
|
||||
key = 'customscript/' + obj.custom_script_source + '/' + key
|
||||
|
||||
if getattr(obj, 'do_not_save_to_config', False):
|
||||
|
|
@ -1905,7 +1754,7 @@ def load_javascript(raw_response):
|
|||
javascript = f'<script>{jsfile.read()}</script>'
|
||||
|
||||
scripts_list = modules.scripts.list_scripts("javascript", ".js")
|
||||
|
||||
|
||||
for basedir, filename, path in scripts_list:
|
||||
with open(path, "r", encoding="utf8") as jsfile:
|
||||
javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
|
||||
|
|
@ -1926,4 +1775,3 @@ def load_javascript(raw_response):
|
|||
|
||||
|
||||
reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
|
||||
reload_javascript()
|
||||
|
|
|
|||
268
modules/ui_extensions.py
Normal file
268
modules/ui_extensions.py
Normal file
|
|
@ -0,0 +1,268 @@
|
|||
import json
|
||||
import os.path
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import git
|
||||
|
||||
import gradio as gr
|
||||
import html
|
||||
|
||||
from modules import extensions, shared, paths
|
||||
|
||||
|
||||
available_extensions = {"extensions": []}
|
||||
|
||||
|
||||
def check_access():
|
||||
assert not shared.cmd_opts.disable_extension_access, "extension access disabed because of commandline flags"
|
||||
|
||||
|
||||
def apply_and_restart(disable_list, update_list):
|
||||
check_access()
|
||||
|
||||
disabled = json.loads(disable_list)
|
||||
assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
|
||||
|
||||
update = json.loads(update_list)
|
||||
assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
|
||||
|
||||
update = set(update)
|
||||
|
||||
for ext in extensions.extensions:
|
||||
if ext.name not in update:
|
||||
continue
|
||||
|
||||
try:
|
||||
ext.pull()
|
||||
except Exception:
|
||||
print(f"Error pulling updates for {ext.name}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
shared.opts.disabled_extensions = disabled
|
||||
shared.opts.save(shared.config_filename)
|
||||
|
||||
shared.state.interrupt()
|
||||
shared.state.need_restart = True
|
||||
|
||||
|
||||
def check_updates():
|
||||
check_access()
|
||||
|
||||
for ext in extensions.extensions:
|
||||
if ext.remote is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
ext.check_updates()
|
||||
except Exception:
|
||||
print(f"Error checking updates for {ext.name}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
return extension_table()
|
||||
|
||||
|
||||
def extension_table():
|
||||
code = f"""<!-- {time.time()} -->
|
||||
<table id="extensions">
|
||||
<thead>
|
||||
<tr>
|
||||
<th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
|
||||
<th>URL</th>
|
||||
<th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
"""
|
||||
|
||||
for ext in extensions.extensions:
|
||||
if ext.can_update:
|
||||
ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
|
||||
else:
|
||||
ext_status = ext.status
|
||||
|
||||
code += f"""
|
||||
<tr>
|
||||
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
|
||||
<td><a href="{html.escape(ext.remote or '')}">{html.escape(ext.remote or '')}</a></td>
|
||||
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
|
||||
</tr>
|
||||
"""
|
||||
|
||||
code += """
|
||||
</tbody>
|
||||
</table>
|
||||
"""
|
||||
|
||||
return code
|
||||
|
||||
|
||||
def normalize_git_url(url):
|
||||
if url is None:
|
||||
return ""
|
||||
|
||||
url = url.replace(".git", "")
|
||||
return url
|
||||
|
||||
|
||||
def install_extension_from_url(dirname, url):
|
||||
check_access()
|
||||
|
||||
assert url, 'No URL specified'
|
||||
|
||||
if dirname is None or dirname == "":
|
||||
*parts, last_part = url.split('/')
|
||||
last_part = normalize_git_url(last_part)
|
||||
|
||||
dirname = last_part
|
||||
|
||||
target_dir = os.path.join(extensions.extensions_dir, dirname)
|
||||
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
|
||||
|
||||
normalized_url = normalize_git_url(url)
|
||||
assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
|
||||
|
||||
tmpdir = os.path.join(paths.script_path, "tmp", dirname)
|
||||
|
||||
try:
|
||||
shutil.rmtree(tmpdir, True)
|
||||
|
||||
repo = git.Repo.clone_from(url, tmpdir)
|
||||
repo.remote().fetch()
|
||||
|
||||
os.rename(tmpdir, target_dir)
|
||||
|
||||
extensions.list_extensions()
|
||||
return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
|
||||
finally:
|
||||
shutil.rmtree(tmpdir, True)
|
||||
|
||||
|
||||
def install_extension_from_index(url):
|
||||
ext_table, message = install_extension_from_url(None, url)
|
||||
|
||||
return refresh_available_extensions_from_data(), ext_table, message
|
||||
|
||||
|
||||
def refresh_available_extensions(url):
|
||||
global available_extensions
|
||||
|
||||
import urllib.request
|
||||
with urllib.request.urlopen(url) as response:
|
||||
text = response.read()
|
||||
|
||||
available_extensions = json.loads(text)
|
||||
|
||||
return url, refresh_available_extensions_from_data(), ''
|
||||
|
||||
|
||||
def refresh_available_extensions_from_data():
|
||||
extlist = available_extensions["extensions"]
|
||||
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
|
||||
|
||||
code = f"""<!-- {time.time()} -->
|
||||
<table id="available_extensions">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Extension</th>
|
||||
<th>Description</th>
|
||||
<th>Action</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
"""
|
||||
|
||||
for ext in extlist:
|
||||
name = ext.get("name", "noname")
|
||||
url = ext.get("url", None)
|
||||
description = ext.get("description", "")
|
||||
|
||||
if url is None:
|
||||
continue
|
||||
|
||||
existing = installed_extension_urls.get(normalize_git_url(url), None)
|
||||
|
||||
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
|
||||
|
||||
code += f"""
|
||||
<tr>
|
||||
<td><a href="{html.escape(url)}">{html.escape(name)}</a></td>
|
||||
<td>{html.escape(description)}</td>
|
||||
<td>{install_code}</td>
|
||||
</tr>
|
||||
"""
|
||||
|
||||
code += """
|
||||
</tbody>
|
||||
</table>
|
||||
"""
|
||||
|
||||
return code
|
||||
|
||||
|
||||
def create_ui():
|
||||
import modules.ui
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as ui:
|
||||
with gr.Tabs(elem_id="tabs_extensions") as tabs:
|
||||
with gr.TabItem("Installed"):
|
||||
|
||||
with gr.Row():
|
||||
apply = gr.Button(value="Apply and restart UI", variant="primary")
|
||||
check = gr.Button(value="Check for updates")
|
||||
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
|
||||
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
|
||||
|
||||
extensions_table = gr.HTML(lambda: extension_table())
|
||||
|
||||
apply.click(
|
||||
fn=apply_and_restart,
|
||||
_js="extensions_apply",
|
||||
inputs=[extensions_disabled_list, extensions_update_list],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
check.click(
|
||||
fn=check_updates,
|
||||
_js="extensions_check",
|
||||
inputs=[],
|
||||
outputs=[extensions_table],
|
||||
)
|
||||
|
||||
with gr.TabItem("Available"):
|
||||
with gr.Row():
|
||||
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
|
||||
available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
|
||||
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
|
||||
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
|
||||
|
||||
install_result = gr.HTML()
|
||||
available_extensions_table = gr.HTML()
|
||||
|
||||
refresh_available_extensions_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update()]),
|
||||
inputs=[available_extensions_index],
|
||||
outputs=[available_extensions_index, available_extensions_table, install_result],
|
||||
)
|
||||
|
||||
install_extension_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
|
||||
inputs=[extension_to_install],
|
||||
outputs=[available_extensions_table, extensions_table, install_result],
|
||||
)
|
||||
|
||||
with gr.TabItem("Install from URL"):
|
||||
install_url = gr.Text(label="URL for extension's git repository")
|
||||
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
|
||||
install_button = gr.Button(value="Install", variant="primary")
|
||||
install_result = gr.HTML(elem_id="extension_install_result")
|
||||
|
||||
install_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
|
||||
inputs=[install_dirname, install_url],
|
||||
outputs=[extensions_table, install_result],
|
||||
)
|
||||
|
||||
return ui
|
||||
|
|
@ -10,6 +10,7 @@ import modules.shared
|
|||
from modules import modelloader, shared
|
||||
|
||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
|
||||
from modules.paths import models_path
|
||||
|
||||
|
||||
|
|
@ -57,7 +58,7 @@ class Upscaler:
|
|||
dest_w = img.width * scale
|
||||
dest_h = img.height * scale
|
||||
for i in range(3):
|
||||
if img.width >= dest_w and img.height >= dest_h:
|
||||
if img.width > dest_w and img.height > dest_h:
|
||||
break
|
||||
img = self.do_upscale(img, selected_model)
|
||||
if img.width != dest_w or img.height != dest_h:
|
||||
|
|
@ -120,3 +121,17 @@ class UpscalerLanczos(Upscaler):
|
|||
self.name = "Lanczos"
|
||||
self.scalers = [UpscalerData("Lanczos", None, self)]
|
||||
|
||||
|
||||
class UpscalerNearest(Upscaler):
|
||||
scalers = []
|
||||
|
||||
def do_upscale(self, img, selected_model=None):
|
||||
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
|
||||
|
||||
def load_model(self, _):
|
||||
pass
|
||||
|
||||
def __init__(self, dirname=None):
|
||||
super().__init__(False)
|
||||
self.name = "Nearest"
|
||||
self.scalers = [UpscalerData("Nearest", None, self)]
|
||||
Loading…
Add table
Add a link
Reference in a new issue