From 89d471652644f449966a0cd944041c98dab7f66c Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 13 Jan 2023 07:25:24 +0100 Subject: Code deduplication --- .../stable_diffusion/vlpn_stable_diffusion.py | 32 ++++++---------------- 1 file changed, 9 insertions(+), 23 deletions(-) (limited to 'pipelines/stable_diffusion') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index cb300d1..6bc40e9 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -20,7 +20,7 @@ from diffusers import ( PNDMScheduler, ) from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput -from diffusers.utils import logging +from diffusers.utils import logging, randn_tensor from transformers import CLIPTextModel, CLIPTokenizer from models.clip.prompt import PromptProcessor @@ -250,8 +250,8 @@ class VlpnStableDiffusion(DiffusionPipeline): return timesteps - def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -260,28 +260,16 @@ class VlpnStableDiffusion(DiffusionPipeline): ) if latents is None: - rand_device = "cpu" if device.type == "mps" else device - - if isinstance(generator, list): - shape = (1,) + shape[1:] - latents = [ - torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) - for i in range(batch_size) - ] - latents = torch.cat(latents, dim=0).to(device) - else: - latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - latents = latents.to(device) + latents = latents.to(device=device, dtype=dtype) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents - def prepare_latents_from_image(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None): init_image = init_image.to(device=device, dtype=dtype) init_latent_dist = self.vae.encode(init_image).latent_dist init_latents = init_latent_dist.sample(generator=generator) @@ -292,7 +280,7 @@ class VlpnStableDiffusion(DiffusionPipeline): f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." ) else: - init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + init_latents = torch.cat([init_latents] * batch_size, dim=0) # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) @@ -430,16 +418,14 @@ class VlpnStableDiffusion(DiffusionPipeline): latents = self.prepare_latents_from_image( image, latent_timestep, - batch_size, - num_images_per_prompt, + batch_size * num_images_per_prompt, text_embeddings.dtype, device, generator ) else: latents = self.prepare_latents( - batch_size, - num_images_per_prompt, + batch_size * num_images_per_prompt, num_channels_latents, height, width, -- cgit v1.2.3-54-g00ecf