From f23fd5184b8ba4ec04506495f4a61726e50756f7 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 3 Oct 2022 17:38:44 +0200
Subject: Small perf improvements

---
 dreambooth.py | 89 ++++++++++++++++++++++++++++++-----------------------------
 1 file changed, 46 insertions(+), 43 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index 5fbf172..9d6b8d6 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -13,7 +13,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
 from pipelines.stable_diffusion.no_check import NoCheck
@@ -30,6 +30,9 @@ from data.dreambooth.prompt import PromptDataset
 logger = get_logger(__name__)
 
 
+torch.backends.cuda.matmul.allow_tf32 = True
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
@@ -346,7 +349,7 @@ class Checkpointer:
         print("Saving model...")
 
         unwrapped = self.accelerator.unwrap_model(self.unet)
-        pipeline = StableDiffusionPipeline(
+        pipeline = VlpnStableDiffusion(
             text_encoder=self.text_encoder,
             vae=self.vae,
             unet=self.accelerator.unwrap_model(self.unet),
@@ -354,8 +357,6 @@
             scheduler=PNDMScheduler(
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
             ),
-            safety_checker=NoCheck(),
-            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         )
         pipeline.enable_attention_slicing()
         pipeline.save_pretrained(f"{self.output_dir}/model")
@@ -381,7 +382,6 @@
             unet=unwrapped,
             tokenizer=self.tokenizer,
             scheduler=scheduler,
-            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         ).to(self.accelerator.device)
 
         pipeline.enable_attention_slicing()
@@ -459,44 +459,6 @@ def main():
     if args.seed is not None:
         set_seed(args.seed)
 
-    if args.with_prior_preservation:
-        class_images_dir = Path(args.class_data_dir)
-        class_images_dir.mkdir(parents=True, exist_ok=True)
-        cur_class_images = len(list(class_images_dir.iterdir()))
-
-        if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float32
-            if accelerator.device.type == "cuda":
-                torch_dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.mixed_precision]
-
-            pipeline = StableDiffusionPipeline.from_pretrained(
-                args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
-            pipeline.enable_attention_slicing()
-            pipeline.set_progress_bar_config(disable=True)
-            pipeline.to(accelerator.device)
-
-            num_new_images = args.num_class_images - cur_class_images
-            logger.info(f"Number of class images to sample: {num_new_images}.")
-
-            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
-            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
-
-            sample_dataloader = accelerator.prepare(sample_dataloader)
-
-            for example in tqdm(
-                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
-            ):
-                with accelerator.autocast():
-                    images = pipeline(example["prompt"]).images
-
-                for i, image in enumerate(images):
-                    image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
-
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
@@ -526,6 +488,47 @@ def main():
     freeze_params(vae.parameters())
     freeze_params(text_encoder.parameters())
 
+    # Generate class images, if necessary
+    if args.with_prior_preservation:
+        class_images_dir = Path(args.class_data_dir)
+        class_images_dir.mkdir(parents=True, exist_ok=True)
+        cur_class_images = len(list(class_images_dir.iterdir()))
+
+        if cur_class_images < args.num_class_images:
+            scheduler = EulerAScheduler(
+                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+            )
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=vae,
+                unet=unet,
+                tokenizer=tokenizer,
+                scheduler=scheduler,
+            ).to(accelerator.device)
+            pipeline.enable_attention_slicing()
+            pipeline.set_progress_bar_config(disable=True)
+
+            num_new_images = args.num_class_images - cur_class_images
+            logger.info(f"Number of class images to sample: {num_new_images}.")
+
+            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+            sample_dataloader = accelerator.prepare(sample_dataloader)
+
+            for example in tqdm(sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process):
+                with accelerator.autocast():
+                    images = pipeline(example["prompt"]).images
+
+                for i, image in enumerate(images):
+                    image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
+
+            del pipeline
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
-- 
cgit v1.2.3-54-g00ecf