From 220806dbd21da3ba83c14096225c31824dfe81df Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 3 Mar 2023 22:09:24 +0100
Subject: Removed offset noise from training, added init offset to pipeline

---
 .../stable_diffusion/vlpn_stable_diffusion.py | 95 +++++++++-------
 1 file changed, 39 insertions(+), 56 deletions(-)

(limited to 'pipelines/stable_diffusion/vlpn_stable_diffusion.py')

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index cb09fe1..c4f7401 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -293,53 +293,39 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return prompt_embeds
 
-    def get_timesteps(self, latents_are_image, num_inference_steps, strength, device):
-        if latents_are_image:
-            # get the original timestep using init_timestep
-            offset = self.scheduler.config.get("steps_offset", 0)
-            init_timestep = int(num_inference_steps * strength) + offset
-            init_timestep = min(init_timestep, num_inference_steps)
-
-            t_start = max(num_inference_steps - init_timestep + offset, 0)
-            timesteps = self.scheduler.timesteps[t_start:]
-        else:
-            timesteps = self.scheduler.timesteps
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep + offset, 0)
+        timesteps = self.scheduler.timesteps[t_start:]
 
         timesteps = timesteps.to(device)
 
         return timesteps
 
-    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
-
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
+    def prepare_image(self, batch_size, width, height, dtype, device, generator=None):
+        return torch.randn(
+            (batch_size, 1, 1, 1),
+            dtype=dtype,
+            device=device,
+            generator=generator
+        ).expand(batch_size, 3, width, height)
 
-        if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        else:
-            latents = latents.to(device=device, dtype=dtype)
-
-        # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * self.scheduler.init_noise_sigma
-
-        return latents
-
-    def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None):
+    def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
-        init_latent_dist = self.vae.encode(init_image).latent_dist
-        init_latents = init_latent_dist.sample(generator=generator)
-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.encode(init_image).latent_dist.sample(generator=generator)
+        init_latents = self.vae.config.scaling_factor * init_latents
 
-        if batch_size > init_latents.shape[0]:
+        if batch_size % init_latents.shape[0] != 0:
             raise ValueError(
                 f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
             )
         else:
-            init_latents = torch.cat([init_latents] * batch_size, dim=0)
+            batch_multiplier = batch_size // init_latents.shape[0]
+            init_latents = torch.cat([init_latents] * batch_multiplier, dim=0)
 
         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
@@ -368,7 +354,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         return extra_step_kwargs
 
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
@@ -381,7 +367,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         prompt: Union[str, List[str], List[int], List[List[int]]],
         negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None,
         num_images_per_prompt: int = 1,
-        strength: float = 0.8,
+        strength: float = 1.0,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
@@ -461,7 +447,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
         device = self.execution_device
         do_classifier_free_guidance = guidance_scale > 1.0
         do_self_attention_guidance = sag_scale > 0.0
-        latents_are_image = isinstance(image, PIL.Image.Image)
 
         # 3. Encode input prompt
         prompt_embeds = self.encode_prompt(
@@ -474,33 +459,31 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.get_timesteps(latents_are_image, num_inference_steps, strength, device)
+        timesteps = self.get_timesteps(num_inference_steps, strength, device)
 
         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
-        if latents_are_image:
+        if isinstance(image, PIL.Image.Image):
             image = preprocess(image)
-            latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-            latents = self.prepare_latents_from_image(
-                image,
-                latent_timestep,
+        elif image is None:
+            image = self.prepare_image(
                 batch_size * num_images_per_prompt,
-                prompt_embeds.dtype,
-                device,
-                generator
-            )
-        else:
-            latents = self.prepare_latents(
-                batch_size * num_images_per_prompt,
-                num_channels_latents,
-                height, width,
+                height,
                 prompt_embeds.dtype,
                 device,
-                generator,
-                image,
+                generator
             )
 
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        latents = self.prepare_latents(
+            image,
+            latent_timestep,
+            batch_size * num_images_per_prompt,
+            prompt_embeds.dtype,
+            device,
+            generator
+        )
+
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-- 
cgit v1.2.3-54-g00ecf
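
Illustration (not part of the patch): a minimal standalone sketch of the init-offset idea that the new prepare_image() introduces above. One random value is drawn per sample and expanded to a full-resolution RGB image, so the init image is flat within each sample and only the per-sample offset varies; the pipeline then encodes that image with the VAE and noises it at the new default strength of 1.0. Shapes and seed below are illustrative.

import torch

# Standalone sketch of the per-sample init offset (not part of the patch itself).
batch_size, width, height = 2, 64, 64
generator = torch.Generator().manual_seed(0)

# One random value per sample, broadcast to a full-resolution 3-channel image,
# mirroring prepare_image() in the diff above.
offset_image = torch.randn(
    (batch_size, 1, 1, 1), generator=generator
).expand(batch_size, 3, width, height)

# Every pixel of a given sample shares the same value; only the per-sample
# offset differs between samples.
print(offset_image[:, 0, 0, 0])   # the per-sample offsets
print(offset_image[0].std())      # 0.0 -- flat within a sample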