From a8a5abae42f6f42056cc27e0cf5313aab080c3a7 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 27 Sep 2022 22:16:13 +0200
Subject: Various improvements, added inference script

---
 dreambooth.py | 98 +++++++++++++++++++++++++++++++----------------------------
 1 file changed, 52 insertions(+), 46 deletions(-)

(limited to 'dreambooth.py')

diff --git a/dreambooth.py b/dreambooth.py
index c01cbe3..bc7a472 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -1,5 +1,4 @@
 import argparse
-import itertools
 import math
 import os
 import datetime
@@ -61,7 +60,8 @@ def parse_args():
         "--repeats",
         type=int,
         default=100,
-        help="How many times to repeat the training data.")
+        help="How many times to repeat the training data."
+    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -72,7 +72,8 @@ def parse_args():
         "--seed",
         type=int,
         default=None,
-        help="A seed for reproducible training.")
+        help="A seed for reproducible training."
+    )
     parser.add_argument(
         "--resolution",
         type=int,
@@ -94,7 +95,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=5000,
+        default=1000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -184,7 +185,7 @@ def parse_args():
     parser.add_argument(
         "--checkpoint_frequency",
         type=int,
-        default=500,
+        default=200,
         help="How often to save a checkpoint and sample image",
     )
     parser.add_argument(
@@ -220,7 +221,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
        type=int,
-        default=50,
+        default=80,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -381,7 +382,6 @@ class Checkpointer:
     def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = f"{self.output_dir}/samples/{mode}"
         os.makedirs(samples_path, exist_ok=True)
-        checker = NoCheck()
 
         unwrapped = self.accelerator.unwrap_model(self.unet)
         pipeline = StableDiffusionPipeline(
@@ -507,6 +507,7 @@ def main():
             torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
             pipeline = StableDiffusionPipeline.from_pretrained(
                 args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
+            pipeline.enable_attention_slicing()
             pipeline.set_progress_bar_config(disable=True)
 
             num_new_images = args.num_class_images - cur_class_images
@@ -589,7 +590,11 @@ def main():
 
     # TODO (patil-suraj): laod scheduler using args
     noise_scheduler = DDPMScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        num_train_timesteps=1000,
+        tensor_format="pt"
     )
 
     def collate_fn(examples):
@@ -709,7 +714,7 @@ def main():
                               args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
     local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
-    local_progress_bar.set_description("Steps")
+    local_progress_bar.set_description("Steps ")
 
     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
     progress_bar.set_description("Global steps")
@@ -723,31 +728,31 @@ def main():
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
-                with accelerator.autocast():
-                    # Convert images to latent space
-                    with torch.no_grad():
-                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                        latents = latents * 0.18215
-
-                    # Sample noise that we'll add to the latents
-                    noise = torch.randn(latents.shape).to(latents.device)
-                    bsz = latents.shape[0]
-                    # Sample a random timestep for each image
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    # Add noise to the latents according to the noise magnitude at each timestep
-                    # (this is the forward diffusion process)
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                    # Get the text embedding for conditioning
-                    with torch.no_grad():
-                        encoder_hidden_states = text_encoder(batch["input_ids"])[0]
-
-                    # Predict the noise residual
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                # Convert images to latent space
+                with torch.no_grad():
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
+
+                # Sample noise that we'll add to the latents
+                noise = torch.randn(latents.shape).to(latents.device)
+                bsz = latents.shape[0]
+                # Sample a random timestep for each image
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                          (bsz,), device=latents.device)
+                timesteps = timesteps.long()
+
+                # Add noise to the latents according to the noise magnitude at each timestep
+                # (this is the forward diffusion process)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                # Get the text embedding for conditioning
+                with torch.no_grad():
+                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+                # Predict the noise residual
+                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
+                with accelerator.autocast():
                     if args.with_prior_preservation:
                         # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                         noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
@@ -766,12 +771,12 @@ def main():
                         loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 accelerator.backward(loss)
-                accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
-
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
                     lr_scheduler.step()
-                optimizer.zero_grad()
+                optimizer.zero_grad(set_to_none=True)
 
                 loss = loss.detach().item()
                 train_loss += loss
@@ -804,7 +809,7 @@ def main():
         val_loss = 0.0
 
         for step, batch in enumerate(val_dataloader):
-            with torch.no_grad(), accelerator.autocast():
+            with torch.no_grad():
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
 
@@ -822,18 +827,19 @@ def main():
 
                 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
 
-                if args.with_prior_preservation:
-                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                with accelerator.autocast():
+                    if args.with_prior_preservation:
+                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
 
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
-                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
-                                            reduction="none").mean([1, 2, 3]).mean()
+                        prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
+                                                reduction="none").mean([1, 2, 3]).mean()
 
-                    loss = loss + args.prior_loss_weight * prior_loss
-                else:
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                        loss = loss + args.prior_loss_weight * prior_loss
+                    else:
+                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 loss = loss.detach().item()
                 val_loss += loss
-- 
cgit v1.2.3-54-g00ecf
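
For readers following the training-loop changes above, the step this commit moves to boils down to the pattern below. This is a minimal sketch, not the full script: accelerator, unet, vae, text_encoder, noise_scheduler, optimizer, lr_scheduler, and args are assumed to be set up as in the rest of dreambooth.py, and the batch layout (instance samples concatenated with class samples when prior preservation is on) is assumed to follow the script's collate_fn.

    import torch
    import torch.nn.functional as F

    def training_step(batch):
        # One gradient-accumulation-aware step, mirroring the structure after
        # this commit. All names except `batch` come from the surrounding
        # script and are assumed to be prepared via accelerator.prepare().
        with accelerator.accumulate(unet):
            # The forward diffusion setup now runs outside accelerator.autocast().
            with torch.no_grad():
                latents = vae.encode(batch["pixel_values"]).latent_dist.sample() * 0.18215
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

            noise = torch.randn(latents.shape, device=latents.device)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (latents.shape[0],), device=latents.device,
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            # Only the loss computation stays under autocast.
            with accelerator.autocast():
                if args.with_prior_preservation:
                    # First half of the batch holds instance samples, second
                    # half holds class (prior) samples.
                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
                                            reduction="none").mean([1, 2, 3]).mean()
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

            accelerator.backward(loss)
            # Clip only when gradients are synced, i.e. when a real optimizer
            # update follows; accumulation micro-steps are left unclipped.
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
            optimizer.step()
            # Under mixed precision the scaler can skip a step on inf/NaN
            # gradients; the LR schedule should not advance in that case.
            if not accelerator.optimizer_step_was_skipped:
                lr_scheduler.step()
            # set_to_none=True releases gradient storage instead of zeroing it.
            optimizer.zero_grad(set_to_none=True)

        return loss.detach().item()

The same narrowing of autocast appears in the validation loop above, where the prior-preservation loss chunking now also runs under accelerator.autocast() while the VAE/text-encoder/UNet forward passes sit only under torch.no_grad().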