author    Volpeon <git@volpeon.ink>  2022-09-27 22:16:13 +0200
committer Volpeon <git@volpeon.ink>  2022-09-27 22:16:13 +0200
commit    a8a5abae42f6f42056cc27e0cf5313aab080c3a7 (patch)
tree      32c163bbc58aa2f827c5ba5108df81dc14fbe130 /dreambooth.py
parent    Incorporate upstream changes (diff)
download  textual-inversion-diff-a8a5abae42f6f42056cc27e0cf5313aab080c3a7.tar.gz
          textual-inversion-diff-a8a5abae42f6f42056cc27e0cf5313aab080c3a7.tar.bz2
          textual-inversion-diff-a8a5abae42f6f42056cc27e0cf5313aab080c3a7.zip
Various improvements, added inference script
Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py  |  98
1 file changed, 52 insertions(+), 46 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index c01cbe3..bc7a472 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -1,5 +1,4 @@
 import argparse
-import itertools
 import math
 import os
 import datetime
@@ -61,7 +60,8 @@ def parse_args():
         "--repeats",
         type=int,
         default=100,
-        help="How many times to repeat the training data.")
+        help="How many times to repeat the training data."
+    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -72,7 +72,8 @@ def parse_args():
         "--seed",
         type=int,
         default=None,
-        help="A seed for reproducible training.")
+        help="A seed for reproducible training."
+    )
     parser.add_argument(
         "--resolution",
         type=int,
@@ -94,7 +95,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=5000,
+        default=1000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -184,7 +185,7 @@ def parse_args():
     parser.add_argument(
         "--checkpoint_frequency",
         type=int,
-        default=500,
+        default=200,
         help="How often to save a checkpoint and sample image",
     )
     parser.add_argument(
@@ -220,7 +221,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=50,
+        default=80,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -381,7 +382,6 @@ class Checkpointer:
     def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = f"{self.output_dir}/samples/{mode}"
         os.makedirs(samples_path, exist_ok=True)
-        checker = NoCheck()

         unwrapped = self.accelerator.unwrap_model(self.unet)
         pipeline = StableDiffusionPipeline(
@@ -507,6 +507,7 @@ def main():
         torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
         pipeline = StableDiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
+        pipeline.enable_attention_slicing()
         pipeline.set_progress_bar_config(disable=True)

         num_new_images = args.num_class_images - cur_class_images
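
Note: the hunk above turns on attention slicing for the pipeline that generates class images for prior preservation. A minimal standalone sketch of the same setup, assuming a placeholder model id and device check that are not taken from this repository:

    import torch
    from diffusers import StableDiffusionPipeline

    # Half precision on CUDA, full precision elsewhere, as in the hunk above.
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    pipeline = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",  # placeholder model id
        torch_dtype=torch_dtype,
    )
    pipeline.enable_attention_slicing()             # compute attention in slices to lower peak VRAM
    pipeline.set_progress_bar_config(disable=True)  # keep generation quiet during training
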
@@ -589,7 +590,11 @@ def main():

     # TODO (patil-suraj): laod scheduler using args
     noise_scheduler = DDPMScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        num_train_timesteps=1000,
+        tensor_format="pt"
     )

     def collate_fn(examples):
@@ -709,7 +714,7 @@ def main():
                               args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)

     local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
-    local_progress_bar.set_description("Steps")
+    local_progress_bar.set_description("Steps ")

     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
     progress_bar.set_description("Global steps")
@@ -723,31 +728,31 @@ def main():

         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
-                with accelerator.autocast():
-                    # Convert images to latent space
-                    with torch.no_grad():
-                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                        latents = latents * 0.18215
-
-                    # Sample noise that we'll add to the latents
-                    noise = torch.randn(latents.shape).to(latents.device)
-                    bsz = latents.shape[0]
-                    # Sample a random timestep for each image
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    # Add noise to the latents according to the noise magnitude at each timestep
-                    # (this is the forward diffusion process)
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                    # Get the text embedding for conditioning
-                    with torch.no_grad():
-                        encoder_hidden_states = text_encoder(batch["input_ids"])[0]
-
-                    # Predict the noise residual
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                # Convert images to latent space
+                with torch.no_grad():
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
+
+                # Sample noise that we'll add to the latents
+                noise = torch.randn(latents.shape).to(latents.device)
+                bsz = latents.shape[0]
+                # Sample a random timestep for each image
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                          (bsz,), device=latents.device)
+                timesteps = timesteps.long()
+
+                # Add noise to the latents according to the noise magnitude at each timestep
+                # (this is the forward diffusion process)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                # Get the text embedding for conditioning
+                with torch.no_grad():
+                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+                # Predict the noise residual
+                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

+                with accelerator.autocast():
                     if args.with_prior_preservation:
                         # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                         noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
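
Note: the hunk above pulls the frozen VAE and text-encoder passes out of the autocast region (they run under torch.no_grad()) and reserves autocast for the loss computation, which begins with the prior-preservation branch. A hedged sketch of the full loss that branch computes follows; the helper name is mine, and prior_loss_weight maps to the script's --prior_loss_weight argument:

    import torch
    import torch.nn.functional as F

    def prior_preservation_loss(noise_pred, noise, prior_loss_weight):
        # The batch concatenates instance and class examples, so the prediction
        # and the target each split into two equal halves along the batch dim.
        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
        noise, noise_prior = torch.chunk(noise, 2, dim=0)

        instance_loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
        prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()

        # The class-image term keeps the model close to its prior for the class.
        return instance_loss + prior_loss_weight * prior_loss
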
@@ -766,12 +771,12 @@ def main():
                         loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

                 accelerator.backward(loss)
-                accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
-
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
                     lr_scheduler.step()
-                optimizer.zero_grad()
+                optimizer.zero_grad(set_to_none=True)

                 loss = loss.detach().item()
                 train_loss += loss
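
Note: with gradient accumulation, Accelerate only synchronizes gradients on the step that actually applies an update, so the hunk above clips only when accelerator.sync_gradients is set, and zero_grad(set_to_none=True) releases gradient buffers instead of zero-filling them. A minimal sketch of that pattern, assuming accelerator, unet, optimizer, lr_scheduler, and args are already prepared; compute_loss is a placeholder, not a function from this script:

    for batch in train_dataloader:
        with accelerator.accumulate(unet):
            loss = compute_loss(unet, batch)  # placeholder for the diffusion loss above
            accelerator.backward(loss)

            # Clip only on sync steps; on accumulation-only steps the gradients
            # are still partial, so clipping them would be wasted work.
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)

            optimizer.step()
            if not accelerator.optimizer_step_was_skipped:
                lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)
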
@@ -804,7 +809,7 @@ def main():
         val_loss = 0.0

         for step, batch in enumerate(val_dataloader):
-            with torch.no_grad(), accelerator.autocast():
+            with torch.no_grad():
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215

@@ -822,18 +827,19 @@ def main():

                 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))

-                if args.with_prior_preservation:
-                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                with accelerator.autocast():
+                    if args.with_prior_preservation:
+                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                        noise, noise_prior = torch.chunk(noise, 2, dim=0)

-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

-                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
-                                            reduction="none").mean([1, 2, 3]).mean()
+                        prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
+                                                reduction="none").mean([1, 2, 3]).mean()

-                    loss = loss + args.prior_loss_weight * prior_loss
-                else:
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                        loss = loss + args.prior_loss_weight * prior_loss
+                    else:
+                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

                 loss = loss.detach().item()
                 val_loss += loss