From a63cf10650f52f60044cb9b72f7292166959cc62 Mon Sep 17 00:00:00 2001 From: arrmansa <41120982+arrmansa@users.noreply.github.com> Date: Mon, 30 Dec 2024 23:33:43 +0530 Subject: [PATCH] Update img2imgalt.py WIP --- scripts/img2imgalt.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index 109c4a2ab..fa0612aaa 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -49,7 +49,12 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps): t = dnw.sigma_to_t(sigma_in) if shared.sd_model.is_sdxl: - eps = shared.sd_model.model(x_in * c_in, t, {"crossattn": cond_in["c_crossattn"][0]} ) + num_classes_hack = shared.sd_model.model.diffusion_model.num_classes + shared.sd_model.model.diffusion_model.num_classes = None + try: + eps = shared.sd_model.model(x_in * c_in, t, {"crossattn": cond_in["c_crossattn"][0]} ) + finally: + shared.sd_model.model.diffusion_model.num_classes = num_classes_hack else: eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in) @@ -78,13 +83,6 @@ Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "origina # Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736 def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps): - if shared.sd_model.is_sdxl: - cond_tensor = cond['crossattn'] - uncond_tensor = uncond['crossattn'] - cond_in = torch.cat([uncond_tensor, cond_tensor]) - else: - cond_in = torch.cat([uncond, cond]) - x = p.init_latent s_in = x.new_ones([x.shape[0]]) @@ -124,7 +122,12 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps): if shared.sd_model.is_sdxl: - eps = shared.sd_model.model(x_in * c_in, t, {"crossattn": cond_in["c_crossattn"][0]} ) + num_classes_hack = shared.sd_model.model.diffusion_model.num_classes + shared.sd_model.model.diffusion_model.num_classes = None + try: + eps = shared.sd_model.model(x_in * c_in, t, {"crossattn": cond_in["c_crossattn"][0]} ) + finally: + shared.sd_model.model.diffusion_model.num_classes = num_classes_hack else: eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in) @@ -211,9 +214,19 @@ class Script(scripts.Script): and self.cache.sigma_adjustment == sigma_adjustment same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100 + rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p) + if same_everything: rec_noise = self.cache.noise else: + # This prevents a crash, because I don't know how to access the underlying .diffusion_model yet when controlnet is enabled. WIP + # modules.sd_unet -> we're good + # scripts.hook -> we're cooked + if "scripts.hook" in str(shared.sd_model.model.diffusion_model.forward.__module__): + print("turn off any controlnets, do 1 pass and then turn controlnet back on to cache noise") + p.steps = 1 + return sd_samplers.create_sampler(p.sampler_name, p.sd_model).sample_img2img(p, p.init_latent, rand_noise, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning) + shared.state.job_count += 1 cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt]) uncond = p.sd_model.get_learned_conditioning(p.batch_size * [original_negative_prompt]) @@ -223,8 +236,6 @@ class Script(scripts.Script): rec_noise = find_noise_for_image(p, cond, uncond, cfg, st) self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment) - rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p) - combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5) sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)