Only use cudastreams for live preview when cuda available

This commit is contained in:
drhead 2024-05-18 20:50:38 -04:00 committed by GitHub
parent 72c5966e48
commit 044494d914
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -58,14 +58,11 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
return x_sample
lp_stream = torch.cuda.Stream()
def single_sample_to_image(sample, approximation=None):
with torch.cuda.stream(lp_stream):
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
x_sample = 255. * x_sample.permute(1, 2, 0)
x_sample = x_sample.to(device='cpu', dtype=torch.uint8, non_blocking=True)
return x_sample
def single_sample_to_image(sample, approximation=None, non_blocking=False):
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
x_sample = 255. * x_sample.permute(1, 2, 0)
return x_sample.to(device='cpu', dtype=torch.uint8, non_blocking=non_blocking)
def decode_first_stage(model, x):
@ -78,9 +75,18 @@ def sample_to_image(samples, index=0, approximation=None):
return single_sample_to_image(samples[index], approximation)
if torch.cuda.is_available():
lp_stream = torch.cuda.Stream()
live_preview_stream_context = torch.cuda.stream(lp_stream)
else:
lp_stream = None
live_preview_stream_context = nullcontext()
def samples_to_image_grid(samples, approximation=None):
sample_tensors = [single_sample_to_image(sample, approximation) for sample in samples]
lp_stream.synchronize()
with live_preview_stream_context:
sample_tensors = [single_sample_to_image(sample, approximation, non_blocking=True) for sample in samples]
if lp_stream is not None:
lp_stream.synchronize()
return images.image_grid([Image.fromarray(sample.numpy()) for sample in sample_tensors])