Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git (synced 2025-12-15 14:50:55 -08:00)
initial SD3 support

commit 5b2a60b8e2 (parent a7116aa9a1)
14 changed files with 333 additions and 44 deletions
modules/sd_models.py

@@ -1,7 +1,9 @@
 import collections
+import importlib
 import os
 import sys
 import threading
+import enum

 import torch
 import re
@@ -10,8 +12,6 @@ from omegaconf import OmegaConf, ListConfig
 from urllib import request
 import ldm.modules.midas as midas

-from ldm.util import instantiate_from_config
-
 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
 from modules.timer import Timer
 from modules.shared import opts
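The `ldm.util` import is dropped because the commit defines a local `instantiate_from_config` later in this same file (the @@ -715,6 +753,25 @@ hunk below), which can additionally inject the checkpoint's state dict into the model constructor.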
@@ -27,6 +27,14 @@ checkpoint_alisases = checkpoint_aliases  # for compatibility with old name
 checkpoints_loaded = collections.OrderedDict()


+class ModelType(enum.Enum):
+    SD1 = 1
+    SD2 = 2
+    SDXL = 3
+    SSD = 4
+    SD3 = 5
+
+
 def replace_key(d, key, new_key, value):
     keys = list(d.keys())

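The enum becomes the single authoritative architecture tag, while the per-architecture booleans (`model.is_sd1` … `model.is_sd3`) stay around for existing call sites. A hypothetical consumer of the tag (the helper and its table are illustrative, not part of the commit; 0.18215/0.13025 are the usual LDM/SDXL latent scale factors and 1.5305 is SD3's published one):

    def latent_scale(model_type: ModelType) -> float:
        # Hypothetical helper: branch on one enum instead of four booleans.
        return {
            ModelType.SD1: 0.18215,
            ModelType.SD2: 0.18215,
            ModelType.SDXL: 0.13025,
            ModelType.SSD: 0.13025,  # SSD-1B keeps SDXL's VAE
            ModelType.SD3: 1.5305,
        }[model_type]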
@@ -368,6 +376,36 @@ def check_fp8(model):
     return enable_fp8


+def set_model_type(model, state_dict):
+    model.is_sd1 = False
+    model.is_sd2 = False
+    model.is_sdxl = False
+    model.is_ssd = False
+    model.is_sd3 = False
+
+    if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
+        model.is_sd3 = True
+        model.model_type = ModelType.SD3
+    elif hasattr(model, 'conditioner'):
+        model.is_sdxl = True
+
+        if 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys():
+            model.is_ssd = True
+            model.model_type = ModelType.SSD
+        else:
+            model.model_type = ModelType.SDXL
+    elif hasattr(model.cond_stage_model, 'model'):
+        model.is_sd2 = True
+        model.model_type = ModelType.SD2
+    else:
+        model.is_sd1 = True
+        model.model_type = ModelType.SD1
+
+
+def set_model_fields(model):
+    if not hasattr(model, 'latent_channels'):
+        model.latent_channels = 4
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
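Detection is ordered most-specific-first and keys off the state dict before any module attributes: the MMDiT patch embedder (`model.diffusion_model.x_embedder.proj.weight`) only exists in SD3 checkpoints, and SSD-1B is an SDXL variant with the middle block's transformer pruned. A standalone sketch of the same decision tree over raw checkpoint keys (the function name and the SDXL/SD2 key-prefix tests are illustrative stand-ins for the attribute checks above):

    def classify_checkpoint_keys(keys):
        # SD3: the MMDiT patch-embedding weight is unique to SD3 checkpoints.
        if "model.diffusion_model.x_embedder.proj.weight" in keys:
            return ModelType.SD3
        # SDXL family: checkpoints use a "conditioner." prefix rather than
        # "cond_stage_model." (stands in for hasattr(model, 'conditioner')).
        if any(k.startswith("conditioner.") for k in keys):
            if "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight" not in keys:
                return ModelType.SSD  # SSD-1B prunes these attention blocks
            return ModelType.SDXL
        # SD2 ships an OpenCLIP text encoder under cond_stage_model.model.*
        if any(k.startswith("cond_stage_model.model.") for k in keys):
            return ModelType.SD2
        return ModelType.SD1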
@@ -382,10 +420,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     if state_dict is None:
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

-    model.is_sdxl = hasattr(model, 'conditioner')
-    model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
-    model.is_sd1 = not model.is_sdxl and not model.is_sd2
-    model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
+    set_model_type(model, state_dict)
+    set_model_fields(model)
+
     if model.is_sdxl:
         sd_models_xl.extend_sdxl(model)

@@ -552,8 +589,7 @@ def patch_given_betas():
     original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)


-def repair_config(sd_config):
-
+def repair_config(sd_config, state_dict=None):
     if not hasattr(sd_config.model.params, "use_ema"):
         sd_config.model.params.use_ema = False

@@ -563,8 +599,9 @@ def repair_config(sd_config):
     elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
         sd_config.model.params.unet_config.params.use_fp16 = True

-    if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
-        sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
+    if hasattr(sd_config.model.params, 'first_stage_config'):
+        if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
+            sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"

     # For UnCLIP-L, override the hardcoded karlo directory
     if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
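The new `hasattr` guard is what lets repair_config accept SD3 configs at all: SD3's config has no LDM-style `first_stage_config` under `model.params`, so the old unconditional attribute chain would raise while the config was being repaired. A minimal reproduction, with made-up config layouts (OmegaConf raising on missing keys is the same behavior the surrounding `hasattr` checks already rely on):

    from omegaconf import OmegaConf

    # LDM-style config carries the nested VAE section; an SD3-style one does not.
    ldm_cfg = OmegaConf.create(
        {"model": {"params": {"first_stage_config": {"params": {"ddconfig": {"attn_type": "vanilla-xformers"}}}}}}
    )
    sd3_cfg = OmegaConf.create({"model": {"params": {"shift": 3.0}}})  # illustrative SD3-style params

    for cfg in (ldm_cfg, sd3_cfg):
        if hasattr(cfg.model.params, "first_stage_config"):  # False for sd3_cfg
            print(cfg.model.params.first_stage_config.params.ddconfig.attn_type)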
@@ -580,6 +617,7 @@ def repair_config(sd_config):
         sd_config.model.params.unet_config.params.use_checkpoint = False

+

 def rescale_zero_terminal_snr_abar(alphas_cumprod):
     alphas_bar_sqrt = alphas_cumprod.sqrt()

@@ -715,6 +753,25 @@ def send_model_to_trash(m):
     devices.torch_gc()


+def instantiate_from_config(config, state_dict=None):
+    constructor = get_obj_from_str(config["target"])
+
+    params = {**config.get("params", {})}
+
+    if state_dict and "state_dict" in params and params["state_dict"] is None:
+        params["state_dict"] = state_dict
+
+    return constructor(**params)
+
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     from modules import sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
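The local `instantiate_from_config` replaces the ldm.util version for one reason: when a config declares a `state_dict` parameter left as null, the checkpoint's state dict is injected into the constructor call, so a wrapper class can consume the weights while it is being built rather than via load_state_dict afterwards. A self-contained illustration (`DemoWrapper` and its config are made up for this sketch; they are not the commit's SD3 classes):

    import torch

    class DemoWrapper(torch.nn.Module):
        def __init__(self, state_dict=None, shift=3.0):
            super().__init__()
            self.shift = shift
            # The injected checkpoint is available at construction time.
            self.got_weights = state_dict is not None

    config = {"target": "__main__.DemoWrapper", "params": {"state_dict": None, "shift": 3.0}}
    model = instantiate_from_config(config, {"w": torch.zeros(1)})
    assert model.got_weights and model.shift == 3.0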
@@ -739,7 +796,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     timer.record("find config")

     sd_config = OmegaConf.load(checkpoint_config)
-    repair_config(sd_config)
+    repair_config(sd_config, state_dict)

     timer.record("load config")
@@ -749,7 +806,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     try:
         with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
             with sd_disable_initialization.InitializeOnMeta():
-                sd_model = instantiate_from_config(sd_config.model)
+                sd_model = instantiate_from_config(sd_config.model, state_dict)

     except Exception as e:
         errors.display(e, "creating model quickly", full_traceback=True)
@@ -758,7 +815,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
         print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)

         with sd_disable_initialization.InitializeOnMeta():
-            sd_model = instantiate_from_config(sd_config.model)
+            sd_model = instantiate_from_config(sd_config.model, state_dict)

     sd_model.used_config = checkpoint_config

@@ -775,6 +832,10 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):

     with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+
+    if hasattr(sd_model, "after_load_weights"):
+        sd_model.after_load_weights()
+
     timer.record("load weights from state dict")

     send_model_to_device(sd_model)
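`after_load_weights` is a duck-typed, opt-in hook: the loader calls it only if the freshly loaded model defines it, once real tensors have replaced the meta-device placeholders, and existing LDM classes are untouched because they simply never define it. Sketch of a model opting in (illustrative class, not the commit's SD3 wrapper):

    import torch

    class HookedModel(torch.nn.Module):
        def after_load_weights(self):
            # Runs after load_model_weights: a place to rebuild anything that
            # depends on the final weights (caches, tied tensors, etc.).
            print("weights are in place")

    sd_model = HookedModel()
    if hasattr(sd_model, "after_load_weights"):  # same duck-typing as the loader
        sd_model.after_load_weights()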