From 6b8a93f46f053668c8023520225a18445d48d8f1 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 25 Mar 2023 16:34:48 +0100 Subject: Update --- .../stable_diffusion/vlpn_stable_diffusion.py | 61 ++++++++++++---------- 1 file changed, 33 insertions(+), 28 deletions(-) (limited to 'pipelines/stable_diffusion') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index ea2a656..127ca50 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -307,39 +307,45 @@ class VlpnStableDiffusion(DiffusionPipeline): return timesteps, num_inference_steps - t_start - def prepare_image(self, batch_size, width, height, dtype, device, generator=None): - return (1.4 * perlin_noise( + def prepare_brightness_offset(self, batch_size, height, width, dtype, device, generator=None): + offset_image = perlin_noise( (batch_size, 1, width, height), res=1, - octaves=4, generator=generator, dtype=dtype, device=device - )).clamp(-1, 1).expand(batch_size, 3, width, height) + ) + offset_latents = self.vae.encode(offset_image).latent_dist.sample(generator=generator) + offset_latents = self.vae.config.scaling_factor * offset_latents + return offset_latents - def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None): + def prepare_latents_from_image(self, init_image, timestep, batch_size, brightness_offset, dtype, device, generator=None): init_image = init_image.to(device=device, dtype=dtype) - init_latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) - init_latents = self.vae.config.scaling_factor * init_latents + latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) + latents = self.vae.config.scaling_factor * latents - if batch_size % init_latents.shape[0] != 0: + if batch_size % latents.shape[0] != 0: raise ValueError( - f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + f"Cannot duplicate `init_image` of batch size {latents.shape[0]} to {batch_size} text prompts." ) else: - batch_multiplier = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * batch_multiplier, dim=0) + batch_multiplier = batch_size // latents.shape[0] + latents = torch.cat([latents] * batch_multiplier, dim=0) # add noise to latents using the timesteps - noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) + noise = torch.randn(latents.shape, generator=generator, device=device, dtype=dtype) + + if brightness_offset != 0: + noise += brightness_offset * self.prepare_brightness_offset( + batch_size, init_image.shape[3], init_image.shape[2], dtype, device, generator + ) # get latents - init_latents = self.scheduler.add_noise(init_latents, noise, timestep) - latents = init_latents + latents = self.scheduler.add_noise(latents, noise, timestep) return latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + def prepare_latents(self, batch_size, num_channels_latents, height, width, brightness_offset, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -352,6 +358,11 @@ class VlpnStableDiffusion(DiffusionPipeline): else: latents = latents.to(device) + if brightness_offset != 0: + latents += brightness_offset * self.prepare_brightness_offset( + batch_size, height, width, dtype, device, generator + ) + # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @@ -395,7 +406,8 @@ class VlpnStableDiffusion(DiffusionPipeline): sag_scale: float = 0.75, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = None, + image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, + brightness_offset: Union[float, torch.FloatTensor] = 0, output_type: str = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -468,7 +480,7 @@ class VlpnStableDiffusion(DiffusionPipeline): num_channels_latents = self.unet.in_channels do_classifier_free_guidance = guidance_scale > 1.0 do_self_attention_guidance = sag_scale > 0.0 - prep_from_image = isinstance(image, PIL.Image.Image) or image == "noise" + prep_from_image = isinstance(image, PIL.Image.Image) # 3. Encode input prompt prompt_embeds = self.encode_prompt( @@ -482,15 +494,6 @@ class VlpnStableDiffusion(DiffusionPipeline): # 4. Prepare latent variables if isinstance(image, PIL.Image.Image): image = preprocess(image) - elif image == "noise": - image = self.prepare_image( - batch_size * num_images_per_prompt, - width, - height, - prompt_embeds.dtype, - device, - generator - ) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -503,9 +506,10 @@ class VlpnStableDiffusion(DiffusionPipeline): image, latent_timestep, batch_size * num_images_per_prompt, + brightness_offset, prompt_embeds.dtype, device, - generator + generator, ) else: latents = self.prepare_latents( @@ -513,10 +517,11 @@ class VlpnStableDiffusion(DiffusionPipeline): num_channels_latents, height, width, + brightness_offset, prompt_embeds.dtype, device, generator, - image + image, ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline -- cgit v1.2.3-54-g00ecf