From 5b80eb8dac50941c05209df9bb560959ab81bdb0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 4 Mar 2023 08:17:31 +0100 Subject: Pipeline: Improved initial image generation --- .../stable_diffusion/vlpn_stable_diffusion.py | 49 ++++++++++++---------- 1 file changed, 26 insertions(+), 23 deletions(-) (limited to 'pipelines') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 242be29..2251848 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -295,16 +295,14 @@ class VlpnStableDiffusion(DiffusionPipeline): def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - t_start = max(num_inference_steps - init_timestep + offset, 0) + t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start:] timesteps = timesteps.to(device) - return timesteps + return timesteps, num_inference_steps - t_start def prepare_image(self, batch_size, width, height, max_offset, dtype, device, generator=None): offset = (max_offset * (2 * torch.rand( @@ -312,12 +310,16 @@ class VlpnStableDiffusion(DiffusionPipeline): dtype=dtype, device=device, generator=generator - ) - 1)).expand(batch_size, 3, width, height) - image = (.1 * torch.normal( - mean=offset, - std=1, - generator=generator - )).clamp(-1, 1) + ) - 1)).expand(batch_size, 1, 2, 2) + image = F.interpolate( + torch.normal( + mean=offset, + std=0.3, + generator=generator + ).clamp(-1, 1), + size=(width, height), + mode="bicubic" + ).expand(batch_size, 3, width, height) return image def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None): @@ -382,7 +384,7 @@ class VlpnStableDiffusion(DiffusionPipeline): eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, - max_image_offset: float = 1.0, + max_init_offset: float = 0.7, output_type: str = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -464,11 +466,7 @@ class VlpnStableDiffusion(DiffusionPipeline): device ) - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.get_timesteps(num_inference_steps, strength, device) - - # 5. Prepare latent variables + # 4. Prepare latent variables if isinstance(image, PIL.Image.Image): image = preprocess(image) elif image is None: @@ -476,13 +474,18 @@ class VlpnStableDiffusion(DiffusionPipeline): batch_size * num_images_per_prompt, width, height, - max_image_offset, + max_init_offset, prompt_embeds.dtype, device, generator ) + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables latents = self.prepare_latents( image, latent_timestep, @@ -492,10 +495,10 @@ class VlpnStableDiffusion(DiffusionPipeline): generator ) - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7. Denoising loo + # 8. Denoising loo if do_self_attention_guidance: store_processor = CrossAttnStoreProcessor() self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor @@ -559,13 +562,13 @@ class VlpnStableDiffusion(DiffusionPipeline): if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 8. Post-processing + # 9. Post-processing image = self.decode_latents(latents) - # 9. Run safety checker + # 10. Run safety checker has_nsfw_concept = None - # 10. Convert to PIL + # 11. Convert to PIL if output_type == "pil": image = self.numpy_to_pil(image) -- cgit v1.2.3-54-g00ecf