diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 69bf9be6d..8bbd8caba 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -58,14 +58,11 @@ def samples_to_images_tensor(sample, approximation=None, model=None): return x_sample -lp_stream = torch.cuda.Stream() -def single_sample_to_image(sample, approximation=None): - with torch.cuda.stream(lp_stream): - x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5 - x_sample = torch.clamp(x_sample, min=0.0, max=1.0) - x_sample = 255. * x_sample.permute(1, 2, 0) - x_sample = x_sample.to(device='cpu', dtype=torch.uint8, non_blocking=True) - return x_sample +def single_sample_to_image(sample, approximation=None, non_blocking=False): + x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5 + x_sample = torch.clamp(x_sample, min=0.0, max=1.0) + x_sample = 255. * x_sample.permute(1, 2, 0) + return x_sample.to(device='cpu', dtype=torch.uint8, non_blocking=non_blocking) def decode_first_stage(model, x): @@ -78,9 +75,18 @@ def sample_to_image(samples, index=0, approximation=None): return single_sample_to_image(samples[index], approximation) +if torch.cuda.is_available(): + lp_stream = torch.cuda.Stream() + live_preview_stream_context = torch.cuda.stream(lp_stream) +else: + lp_stream = None + live_preview_stream_context = nullcontext() + def samples_to_image_grid(samples, approximation=None): - sample_tensors = [single_sample_to_image(sample, approximation) for sample in samples] - lp_stream.synchronize() + with live_preview_stream_context: + sample_tensors = [single_sample_to_image(sample, approximation, non_blocking=True) for sample in samples] + if lp_stream is not None: + lp_stream.synchronize() return images.image_grid([Image.fromarray(sample.numpy()) for sample in sample_tensors])