Live previews run on cudastream

This commit is contained in:
drhead 2024-05-18 19:56:55 -04:00 committed by GitHub
parent ddb28b33a3
commit 387bcd8e4a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -59,15 +59,14 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
return x_sample
lp_stream = torch.cuda.Stream()
def single_sample_to_image(sample, approximation=None):
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)
with torch.cuda.stream(lp_stream):
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
x_sample = 255. * x_sample.permute(1, 2, 0)
x_sample = x_sample.to(device='cpu', dtype=torch.uint8, non_blocking=True)
return x_sample
def decode_first_stage(model, x):
@ -81,7 +80,9 @@ def sample_to_image(samples, index=0, approximation=None):
def samples_to_image_grid(samples, approximation=None):
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
sample_tensors = [single_sample_to_image(sample, approximation) for sample in samples]
lp_stream.synchronize()
return images.image_grid([Image.fromarray(sample.numpy()) for sample in sample_tensors])
def images_tensor_to_samples(image, approximation=None, model=None):