From 329ad48b307e782b0e23fce80ae9087a4003e73d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 30 Nov 2022 14:02:35 +0100
Subject: Update

---
 dreambooth.py                                      | 19 +++++++--------
 .../stable_diffusion/vlpn_stable_diffusion.py      | 27 +++++++++++++++++++++-
 2 files changed, 34 insertions(+), 12 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index 49d4447..3dd0920 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -115,7 +115,7 @@ def parse_args():
     parser.add_argument(
         "--resolution",
         type=int,
-        default=512,
+        default=768,
         help=(
             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
             " resolution"
@@ -267,7 +267,7 @@ def parse_args():
     parser.add_argument(
         "--sample_image_size",
         type=int,
-        default=512,
+        default=768,
         help="Size of sample images",
     )
     parser.add_argument(
@@ -297,7 +297,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=25,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -459,7 +459,7 @@ class Checkpointer:
         torch.cuda.empty_cache()
 
     @torch.no_grad()
-    def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
+    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         samples_path = Path(self.output_dir).joinpath("samples")
 
         unwrapped_unet = self.accelerator.unwrap_model(
@@ -474,13 +474,14 @@ class Checkpointer:
             scheduler=self.scheduler,
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
+        pipeline.enable_vae_slicing()
 
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
 
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
         stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
+            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
             device=pipeline.device,
             generator=generator,
         )
@@ -875,9 +876,7 @@ def main():
     )
 
     if accelerator.is_main_process:
-        checkpointer.save_samples(
-            0,
-            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+        checkpointer.save_samples(0, args.sample_steps)
 
     local_progress_bar = tqdm(
         range(num_update_steps_per_epoch + num_val_steps_per_epoch),
@@ -1089,9 +1088,7 @@ def main():
                     max_acc_val = avg_acc_val
 
             if sample_checkpoint and accelerator.is_main_process:
-                checkpointer.save_samples(
-                    global_step,
-                    args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+                checkpointer.save_samples(global_step, args.sample_steps)
 
     # Create the pipeline using the trained modules and save it.
     if accelerator.is_main_process:
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 85b0216..c77c4d1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,6 +20,7 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.models.vae import DecoderOutput
 from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
 from models.clip.prompt import PromptProcessor
@@ -69,6 +70,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler._internal_dict = FrozenDict(new_config)
 
         self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
+        self.use_slicing = False
 
         self.register_modules(
             vae=vae,
@@ -136,6 +138,21 @@ class VlpnStableDiffusion(DiffusionPipeline):
         if cpu_offloaded_model is not None:
             cpu_offload(cpu_offloaded_model, device)
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding.
+        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+        steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.use_slicing = False
+
     @property
     def execution_device(self):
         r"""
@@ -280,12 +297,20 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
     def decode_latents(self, latents):
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae_decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
+    def vae_decode(self, latents):
+        if self.use_slicing:
+            decoded_slices = [self.vae.decode(latents_slice).sample for latents_slice in latents.split(1)]
+            decoded = torch.cat(decoded_slices)
+            return DecoderOutput(sample=decoded)
+        else:
+            return self.vae.decode(latents)
+
     @torch.no_grad()
     def __call__(
         self,
-- 
cgit v1.2.3-54-g00ecf
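
Note: the sliced decoding added in vae_decode() above amounts to decoding the latent
batch one sample at a time and re-concatenating the results, so peak VAE memory is
that of a batch of one. A minimal standalone sketch of the same idea, assuming the
VAE is a diffusers AutoencoderKL (the function name sliced_vae_decode is illustrative,
not part of this patch):

    import torch
    from diffusers.models.vae import DecoderOutput

    def sliced_vae_decode(vae, latents: torch.Tensor) -> DecoderOutput:
        # Split the (batch, channels, h, w) latent tensor along the batch
        # dimension and decode each single-sample slice separately.
        decoded_slices = [vae.decode(latents_slice).sample for latents_slice in latents.split(1)]
        # Stitch the per-sample images back into one batch.
        return DecoderOutput(sample=torch.cat(decoded_slices))

With pipeline.enable_vae_slicing() called before sampling, as save_samples() now does,
this trades a little decode speed for a memory footprint that no longer grows with the
sample batch size.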