mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-12-06 02:30:30 -08:00
allow baking in VAE in checkpoint merger tab
do not save config if it's the default for checkpoint merger tab change file naming scheme for checkpoint merger tab allow just saving A without any merging for checkpoint merger tab some stylistic changes for UI in checkpoint merger tab
This commit is contained in:
parent
c7e50425f6
commit
0f5dbfffd0
7 changed files with 102 additions and 59 deletions
|
|
@ -15,7 +15,7 @@ from typing import Callable, List, OrderedDict, Tuple
|
|||
from functools import partial
|
||||
from dataclasses import dataclass
|
||||
|
||||
from modules import processing, shared, images, devices, sd_models, sd_samplers
|
||||
from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae
|
||||
from modules.shared import opts
|
||||
import modules.gfpgan_model
|
||||
from modules.ui import plaintext_to_html
|
||||
|
|
@ -251,7 +251,8 @@ def run_pnginfo(image):
|
|||
|
||||
def create_config(ckpt_result, config_source, a, b, c):
|
||||
def config(x):
|
||||
return sd_models.find_checkpoint_config(x) if x else None
|
||||
res = sd_models.find_checkpoint_config(x) if x else None
|
||||
return res if res != shared.sd_default_config else None
|
||||
|
||||
if config_source == 0:
|
||||
cfg = config(a) or config(b) or config(c)
|
||||
|
|
@ -274,10 +275,12 @@ def create_config(ckpt_result, config_source, a, b, c):
|
|||
shutil.copyfile(cfg, checkpoint_filename)
|
||||
|
||||
|
||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
|
||||
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
||||
|
||||
|
||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae):
|
||||
shared.state.begin()
|
||||
shared.state.job = 'model-merge'
|
||||
shared.state.job_count = 1
|
||||
|
||||
def fail(message):
|
||||
shared.state.textinfo = message
|
||||
|
|
@ -293,41 +296,68 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||
def add_difference(theta0, theta1_2_diff, alpha):
|
||||
return theta0 + (alpha * theta1_2_diff)
|
||||
|
||||
def filename_weighed_sum():
|
||||
a = primary_model_info.model_name
|
||||
b = secondary_model_info.model_name
|
||||
Ma = round(1 - multiplier, 2)
|
||||
Mb = round(multiplier, 2)
|
||||
|
||||
return f"{Ma}({a}) + {Mb}({b})"
|
||||
|
||||
def filename_add_differnece():
|
||||
a = primary_model_info.model_name
|
||||
b = secondary_model_info.model_name
|
||||
c = tertiary_model_info.model_name
|
||||
M = round(multiplier, 2)
|
||||
|
||||
return f"{a} + {M}({b} - {c})"
|
||||
|
||||
def filename_nothing():
|
||||
return primary_model_info.model_name
|
||||
|
||||
theta_funcs = {
|
||||
"Weighted sum": (filename_weighed_sum, None, weighted_sum),
|
||||
"Add difference": (filename_add_differnece, get_difference, add_difference),
|
||||
"No interpolation": (filename_nothing, None, None),
|
||||
}
|
||||
filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
|
||||
shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
|
||||
|
||||
if not primary_model_name:
|
||||
return fail("Failed: Merging requires a primary model.")
|
||||
|
||||
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
||||
|
||||
if not secondary_model_name:
|
||||
if theta_func2 and not secondary_model_name:
|
||||
return fail("Failed: Merging requires a secondary model.")
|
||||
|
||||
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
|
||||
|
||||
theta_funcs = {
|
||||
"Weighted sum": (None, weighted_sum),
|
||||
"Add difference": (get_difference, add_difference),
|
||||
}
|
||||
theta_func1, theta_func2 = theta_funcs[interp_method]
|
||||
secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
|
||||
|
||||
if theta_func1 and not tertiary_model_name:
|
||||
return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
|
||||
|
||||
|
||||
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
|
||||
|
||||
result_is_inpainting_model = False
|
||||
|
||||
shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
|
||||
print(f"Loading {secondary_model_info.filename}...")
|
||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||
if theta_func2:
|
||||
shared.state.textinfo = f"Loading B"
|
||||
print(f"Loading {secondary_model_info.filename}...")
|
||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||
else:
|
||||
theta_1 = None
|
||||
|
||||
if theta_func1:
|
||||
shared.state.job_count += 1
|
||||
|
||||
shared.state.textinfo = f"Loading C"
|
||||
print(f"Loading {tertiary_model_info.filename}...")
|
||||
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
||||
|
||||
shared.state.textinfo = 'Merging B and C'
|
||||
shared.state.sampling_steps = len(theta_1.keys())
|
||||
for key in tqdm.tqdm(theta_1.keys()):
|
||||
if key in chckpoint_dict_skip_on_merge:
|
||||
continue
|
||||
|
||||
if 'model' in key:
|
||||
if key in theta_2:
|
||||
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
|
||||
|
|
@ -345,12 +375,10 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
||||
|
||||
print("Merging...")
|
||||
|
||||
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
||||
|
||||
shared.state.textinfo = 'Merging A and B'
|
||||
shared.state.sampling_steps = len(theta_0.keys())
|
||||
for key in tqdm.tqdm(theta_0.keys()):
|
||||
if 'model' in key and key in theta_1:
|
||||
if theta_1 and 'model' in key and key in theta_1:
|
||||
|
||||
if key in chckpoint_dict_skip_on_merge:
|
||||
continue
|
||||
|
|
@ -358,7 +386,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||
a = theta_0[key]
|
||||
b = theta_1[key]
|
||||
|
||||
shared.state.textinfo = f'Merging layer {key}'
|
||||
# this enables merging an inpainting model (A) with another one (B);
|
||||
# where normal model would have 4 channels, for latenst space, inpainting model would
|
||||
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
|
||||
|
|
@ -378,34 +405,31 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||
|
||||
shared.state.sampling_step += 1
|
||||
|
||||
# I believe this part should be discarded, but I'll leave it for now until I am sure
|
||||
for key in theta_1.keys():
|
||||
if 'model' in key and key not in theta_0:
|
||||
|
||||
if key in chckpoint_dict_skip_on_merge:
|
||||
continue
|
||||
|
||||
theta_0[key] = theta_1[key]
|
||||
if save_as_half:
|
||||
theta_0[key] = theta_0[key].half()
|
||||
del theta_1
|
||||
|
||||
bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
|
||||
if bake_in_vae_filename is not None:
|
||||
print(f"Baking in VAE from {bake_in_vae_filename}")
|
||||
shared.state.textinfo = 'Baking in VAE'
|
||||
vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
|
||||
|
||||
for key in vae_dict.keys():
|
||||
theta_0_key = 'first_stage_model.' + key
|
||||
if theta_0_key in theta_0:
|
||||
theta_0[theta_0_key] = vae_dict[key].half() if save_as_half else vae_dict[key]
|
||||
|
||||
del vae_dict
|
||||
|
||||
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
||||
|
||||
filename = \
|
||||
primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
|
||||
secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
|
||||
interp_method.replace(" ", "_") + \
|
||||
'-merged.' + \
|
||||
("inpainting." if result_is_inpainting_model else "") + \
|
||||
checkpoint_format
|
||||
|
||||
filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
|
||||
filename = filename_generator() if custom_name == '' else custom_name
|
||||
filename += ".inpainting" if result_is_inpainting_model else ""
|
||||
filename += "." + checkpoint_format
|
||||
|
||||
output_modelname = os.path.join(ckpt_dir, filename)
|
||||
|
||||
shared.state.nextjob()
|
||||
shared.state.textinfo = f"Saving to {output_modelname}..."
|
||||
shared.state.textinfo = "Saving"
|
||||
print(f"Saving to {output_modelname}...")
|
||||
|
||||
_, extension = os.path.splitext(output_modelname)
|
||||
|
|
@ -418,8 +442,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||
|
||||
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
||||
|
||||
print("Checkpoint saved.")
|
||||
shared.state.textinfo = "Checkpoint saved to " + output_modelname
|
||||
print(f"Checkpoint saved to {output_modelname}.")
|
||||
shared.state.textinfo = "Checkpoint saved"
|
||||
shared.state.end()
|
||||
|
||||
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue