From 2ad46871e2ead985445da2848a4eb7072b6e48aa Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 14 Nov 2022 17:09:58 +0100
Subject: Update

---
 dreambooth.py | 71 ++++++++++++++++++++++++++++++++++-------------------------
 1 file changed, 41 insertions(+), 30 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index 8c4bf50..7b34fce 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -15,7 +15,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
 from PIL import Image
@@ -23,7 +23,6 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
-from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
@@ -144,7 +143,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=6000,
+        default=None,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -211,7 +210,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=7 / 8
+        default=6/7
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -283,6 +282,12 @@ def parse_args():
         default=1,
         help="Number of samples to generate per batch",
     )
+    parser.add_argument(
+        "--valid_set_size",
+        type=int,
+        default=None,
+        help="Number of images in the validation dataset."
+    )
     parser.add_argument(
         "--train_batch_size",
         type=int,
@@ -292,7 +297,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=30,
+        default=25,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -461,7 +466,7 @@ class Checkpointer:
             self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
         unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        scheduler = EulerAncestralDiscreteScheduler(
+        scheduler = DPMSolverMultistepScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
 
@@ -487,23 +492,30 @@ class Checkpointer:
         with torch.inference_mode():
             for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
                 all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.png")
+                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
                 file_path.parent.mkdir(parents=True, exist_ok=True)
 
                 data_enum = enumerate(data)
 
+                batches = [
+                    batch
+                    for j, batch in data_enum
+                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
+                ]
+                prompts = [
+                    prompt.format(identifier=self.instance_identifier)
+                    for batch in batches
+                    for prompt in batch["prompts"]
+                ]
+                nprompts = [
+                    prompt
+                    for batch in batches
+                    for prompt in batch["nprompts"]
+                ]
+
                 for i in range(self.sample_batches):
-                    batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                    prompt = [
-                        prompt.format(identifier=self.instance_identifier)
-                        for batch in batches
-                        for prompt in batch["prompts"]
-                    ][:self.sample_batch_size]
-                    nprompt = [
-                        prompt
-                        for batch in batches
-                        for prompt in batch["nprompts"]
-                    ][:self.sample_batch_size]
+                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
 
                     samples = pipeline(
                         prompt=prompt,
@@ -523,7 +535,7 @@ class Checkpointer:
                     del samples
 
                 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-                image_grid.save(file_path)
+                image_grid.save(file_path, quality=85)
 
                 del all_samples
                 del image_grid
@@ -576,6 +588,12 @@ def main():
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
 
+    unet.set_use_memory_efficient_attention_xformers(True)
+
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
+
     ema_unet = None
     if args.use_ema:
         ema_unet = EMAModel(
@@ -586,12 +604,6 @@ def main():
             device=accelerator.device
         )
 
-    unet.set_use_memory_efficient_attention_xformers(True)
-
-    if args.gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
-        text_encoder.gradient_checkpointing_enable()
-
     # Freeze text_encoder and vae
     freeze_params(vae.parameters())
 
@@ -726,7 +738,7 @@ def main():
         size=args.resolution,
         repeats=args.repeats,
         center_crop=args.center_crop,
-        valid_set_size=args.sample_batch_size*args.sample_batches,
+        valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
         collate_fn=collate_fn
     )
@@ -743,7 +755,7 @@ def main():
         for i in range(0, len(missing_data), args.sample_batch_size)
     ]
 
-    scheduler = EulerAncestralDiscreteScheduler(
+    scheduler = DPMSolverMultistepScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
     )
 
@@ -962,6 +974,8 @@ def main():
                     optimizer.step()
                     if not accelerator.optimizer_step_was_skipped:
                         lr_scheduler.step()
+                    if args.use_ema:
+                        ema_unet.step(unet)
                     optimizer.zero_grad(set_to_none=True)
 
                 loss = loss.detach().item()
@@ -969,9 +983,6 @@ def main():
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
-                if args.use_ema:
-                    ema_unet.step(unet)
-
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
--
cgit v1.2.3-54-g00ecf