| author | Volpeon <git@volpeon.ink> | 2022-09-27 22:16:13 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-09-27 22:16:13 +0200 |
| commit | a8a5abae42f6f42056cc27e0cf5313aab080c3a7 (patch) | |
| tree | 32c163bbc58aa2f827c5ba5108df81dc14fbe130 /dreambooth.py | |
| parent | Incorporate upstream changes (diff) | |
Various improvements, added inference script
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 90 |
1 files changed, 48 insertions, 42 deletions
diff --git a/dreambooth.py b/dreambooth.py
index c01cbe3..bc7a472 100644
--- a/dreambooth.py
+++ b/dreambooth.py
| @@ -1,5 +1,4 @@ | |||
| 1 | import argparse | 1 | import argparse |
| 2 | import itertools | ||
| 3 | import math | 2 | import math |
| 4 | import os | 3 | import os |
| 5 | import datetime | 4 | import datetime |
| @@ -61,7 +60,8 @@ def parse_args(): | |||
| 61 | "--repeats", | 60 | "--repeats", |
| 62 | type=int, | 61 | type=int, |
| 63 | default=100, | 62 | default=100, |
| 64 | help="How many times to repeat the training data.") | 63 | help="How many times to repeat the training data." |
| 64 | ) | ||
| 65 | parser.add_argument( | 65 | parser.add_argument( |
| 66 | "--output_dir", | 66 | "--output_dir", |
| 67 | type=str, | 67 | type=str, |
| @@ -72,7 +72,8 @@ def parse_args(): | |||
| 72 | "--seed", | 72 | "--seed", |
| 73 | type=int, | 73 | type=int, |
| 74 | default=None, | 74 | default=None, |
| 75 | help="A seed for reproducible training.") | 75 | help="A seed for reproducible training." |
| 76 | ) | ||
| 76 | parser.add_argument( | 77 | parser.add_argument( |
| 77 | "--resolution", | 78 | "--resolution", |
| 78 | type=int, | 79 | type=int, |
| @@ -94,7 +95,7 @@ def parse_args(): | |||
| 94 | parser.add_argument( | 95 | parser.add_argument( |
| 95 | "--max_train_steps", | 96 | "--max_train_steps", |
| 96 | type=int, | 97 | type=int, |
| 97 | default=5000, | 98 | default=1000, |
| 98 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 99 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
| 99 | ) | 100 | ) |
| 100 | parser.add_argument( | 101 | parser.add_argument( |
| @@ -184,7 +185,7 @@ def parse_args(): | |||
| 184 | parser.add_argument( | 185 | parser.add_argument( |
| 185 | "--checkpoint_frequency", | 186 | "--checkpoint_frequency", |
| 186 | type=int, | 187 | type=int, |
| 187 | default=500, | 188 | default=200, |
| 188 | help="How often to save a checkpoint and sample image", | 189 | help="How often to save a checkpoint and sample image", |
| 189 | ) | 190 | ) |
| 190 | parser.add_argument( | 191 | parser.add_argument( |
| @@ -220,7 +221,7 @@ def parse_args(): | |||
| 220 | parser.add_argument( | 221 | parser.add_argument( |
| 221 | "--sample_steps", | 222 | "--sample_steps", |
| 222 | type=int, | 223 | type=int, |
| 223 | default=50, | 224 | default=80, |
| 224 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 225 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
| 225 | ) | 226 | ) |
| 226 | parser.add_argument( | 227 | parser.add_argument( |
| @@ -381,7 +382,6 @@ class Checkpointer: | |||
| 381 | def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps): | 382 | def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps): |
| 382 | samples_path = f"{self.output_dir}/samples/{mode}" | 383 | samples_path = f"{self.output_dir}/samples/{mode}" |
| 383 | os.makedirs(samples_path, exist_ok=True) | 384 | os.makedirs(samples_path, exist_ok=True) |
| 384 | checker = NoCheck() | ||
| 385 | 385 | ||
| 386 | unwrapped = self.accelerator.unwrap_model(self.unet) | 386 | unwrapped = self.accelerator.unwrap_model(self.unet) |
| 387 | pipeline = StableDiffusionPipeline( | 387 | pipeline = StableDiffusionPipeline( |
| @@ -507,6 +507,7 @@ def main(): | |||
| 507 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 | 507 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 |
| 508 | pipeline = StableDiffusionPipeline.from_pretrained( | 508 | pipeline = StableDiffusionPipeline.from_pretrained( |
| 509 | args.pretrained_model_name_or_path, torch_dtype=torch_dtype) | 509 | args.pretrained_model_name_or_path, torch_dtype=torch_dtype) |
| 510 | pipeline.enable_attention_slicing() | ||
| 510 | pipeline.set_progress_bar_config(disable=True) | 511 | pipeline.set_progress_bar_config(disable=True) |
| 511 | 512 | ||
| 512 | num_new_images = args.num_class_images - cur_class_images | 513 | num_new_images = args.num_class_images - cur_class_images |
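The hunk above enables attention slicing on the frozen pipeline that pre-generates class images for prior preservation, trading a little speed for a much lower peak-VRAM footprint. A minimal sketch of the same pattern in isolation (the model path is a placeholder, and the dtype check is simplified from the script's accelerator-based check):

```python
import torch
from diffusers import StableDiffusionPipeline

# Frozen pipeline used only to pre-generate class images.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
pipeline = StableDiffusionPipeline.from_pretrained(
    "path/to/pretrained-model",  # placeholder, normally args.pretrained_model_name_or_path
    torch_dtype=torch_dtype)
pipeline.enable_attention_slicing()            # compute attention in slices to cut peak memory
pipeline.set_progress_bar_config(disable=True)  # keep the console output clean
```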
| @@ -589,7 +590,11 @@ def main(): | |||
| 589 | 590 | ||
| 590 | # TODO (patil-suraj): laod scheduler using args | 591 | # TODO (patil-suraj): laod scheduler using args |
| 591 | noise_scheduler = DDPMScheduler( | 592 | noise_scheduler = DDPMScheduler( |
| 592 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt" | 593 | beta_start=0.00085, |
| 594 | beta_end=0.012, | ||
| 595 | beta_schedule="scaled_linear", | ||
| 596 | num_train_timesteps=1000, | ||
| 597 | tensor_format="pt" | ||
| 593 | ) | 598 | ) |
| 594 | 599 | ||
| 595 | def collate_fn(examples): | 600 | def collate_fn(examples): |
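The scheduler change above only spreads the DDPMScheduler arguments across lines; its behavior is unchanged. For reference, the `add_noise` call used later in the training loop applies the closed-form forward-diffusion step. A hand-rolled sketch of that step (illustrative only, not the diffusers implementation; names are made up):

```python
import torch

def add_noise_sketch(latents, noise, timesteps, alphas_cumprod):
    # q(x_t | x_0): x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    sqrt_alpha_bar = alphas_cumprod[timesteps].sqrt().view(-1, 1, 1, 1)
    sqrt_one_minus = (1.0 - alphas_cumprod[timesteps]).sqrt().view(-1, 1, 1, 1)
    return sqrt_alpha_bar * latents + sqrt_one_minus * noise
```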
| @@ -709,7 +714,7 @@ def main(): | |||
| 709 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | 714 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) |
| 710 | 715 | ||
| 711 | local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process) | 716 | local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process) |
| 712 | local_progress_bar.set_description("Steps") | 717 | local_progress_bar.set_description("Steps ") |
| 713 | 718 | ||
| 714 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) | 719 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) |
| 715 | progress_bar.set_description("Global steps") | 720 | progress_bar.set_description("Global steps") |
| @@ -723,31 +728,31 @@ def main(): | |||
| 723 | 728 | ||
| 724 | for step, batch in enumerate(train_dataloader): | 729 | for step, batch in enumerate(train_dataloader): |
| 725 | with accelerator.accumulate(unet): | 730 | with accelerator.accumulate(unet): |
| 726 | with accelerator.autocast(): | 731 | # Convert images to latent space |
| 727 | # Convert images to latent space | 732 | with torch.no_grad(): |
| 728 | with torch.no_grad(): | 733 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
| 729 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 734 | latents = latents * 0.18215 |
| 730 | latents = latents * 0.18215 | ||
| 731 | 735 | ||
| 732 | # Sample noise that we'll add to the latents | 736 | # Sample noise that we'll add to the latents |
| 733 | noise = torch.randn(latents.shape).to(latents.device) | 737 | noise = torch.randn(latents.shape).to(latents.device) |
| 734 | bsz = latents.shape[0] | 738 | bsz = latents.shape[0] |
| 735 | # Sample a random timestep for each image | 739 | # Sample a random timestep for each image |
| 736 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | 740 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, |
| 737 | (bsz,), device=latents.device) | 741 | (bsz,), device=latents.device) |
| 738 | timesteps = timesteps.long() | 742 | timesteps = timesteps.long() |
| 739 | 743 | ||
| 740 | # Add noise to the latents according to the noise magnitude at each timestep | 744 | # Add noise to the latents according to the noise magnitude at each timestep |
| 741 | # (this is the forward diffusion process) | 745 | # (this is the forward diffusion process) |
| 742 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 746 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
| 743 | 747 | ||
| 744 | # Get the text embedding for conditioning | 748 | # Get the text embedding for conditioning |
| 745 | with torch.no_grad(): | 749 | with torch.no_grad(): |
| 746 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | 750 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] |
| 747 | 751 | ||
| 748 | # Predict the noise residual | 752 | # Predict the noise residual |
| 749 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 753 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
| 750 | 754 | ||
| 755 | with accelerator.autocast(): | ||
| 751 | if args.with_prior_preservation: | 756 | if args.with_prior_preservation: |
| 752 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. | 757 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. |
| 753 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | 758 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) |
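The hunk above narrows the autocast region: the frozen VAE and text-encoder passes and the UNet forward now run outside it, and mixed precision is entered only for the loss computation. Condensed, the training step reads roughly as follows (a sketch reusing the script's own objects and imports; the prior-preservation branch is shown further below):

```python
with accelerator.accumulate(unet):
    with torch.no_grad():                                    # frozen VAE, no gradients needed
        latents = vae.encode(batch["pixel_values"]).latent_dist.sample() * 0.18215

    noise = torch.randn(latents.shape).to(latents.device)    # target noise
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps,
        (latents.shape[0],), device=latents.device).long()
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    with torch.no_grad():                                    # frozen text encoder
        encoder_hidden_states = text_encoder(batch["input_ids"])[0]

    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    with accelerator.autocast():                             # mixed precision only for the loss
        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
```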
| @@ -766,12 +771,12 @@ def main(): | |||
| 766 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 771 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
| 767 | 772 | ||
| 768 | accelerator.backward(loss) | 773 | accelerator.backward(loss) |
| 769 | accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) | 774 | if accelerator.sync_gradients: |
| 770 | 775 | accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) | |
| 771 | optimizer.step() | 776 | optimizer.step() |
| 772 | if not accelerator.optimizer_step_was_skipped: | 777 | if not accelerator.optimizer_step_was_skipped: |
| 773 | lr_scheduler.step() | 778 | lr_scheduler.step() |
| 774 | optimizer.zero_grad() | 779 | optimizer.zero_grad(set_to_none=True) |
| 775 | 780 | ||
| 776 | loss = loss.detach().item() | 781 | loss = loss.detach().item() |
| 777 | train_loss += loss | 782 | train_loss += loss |
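This hunk gates gradient clipping on `accelerator.sync_gradients`, so under gradient accumulation the norm is clipped only on the step where gradients are actually synchronized, and it switches to `zero_grad(set_to_none=True)` to drop gradient tensors entirely between steps. The update, taken in isolation with the script's names:

```python
accelerator.backward(loss)
if accelerator.sync_gradients:                     # true only on a real (synchronized) step
    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
optimizer.step()
if not accelerator.optimizer_step_was_skipped:     # skipped e.g. after an fp16 inf/NaN overflow
    lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)              # free gradient memory between steps
```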
| @@ -804,7 +809,7 @@ def main(): | |||
| 804 | val_loss = 0.0 | 809 | val_loss = 0.0 |
| 805 | 810 | ||
| 806 | for step, batch in enumerate(val_dataloader): | 811 | for step, batch in enumerate(val_dataloader): |
| 807 | with torch.no_grad(), accelerator.autocast(): | 812 | with torch.no_grad(): |
| 808 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 813 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
| 809 | latents = latents * 0.18215 | 814 | latents = latents * 0.18215 |
| 810 | 815 | ||
| @@ -822,18 +827,19 @@ def main(): | |||
| 822 | 827 | ||
| 823 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 828 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) |
| 824 | 829 | ||
| 825 | if args.with_prior_preservation: | 830 | with accelerator.autocast(): |
| 826 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | 831 | if args.with_prior_preservation: |
| 827 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | 832 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) |
| 833 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | ||
| 828 | 834 | ||
| 829 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 835 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
| 830 | 836 | ||
| 831 | prior_loss = F.mse_loss(noise_pred_prior, noise_prior, | 837 | prior_loss = F.mse_loss(noise_pred_prior, noise_prior, |
| 832 | reduction="none").mean([1, 2, 3]).mean() | 838 | reduction="none").mean([1, 2, 3]).mean() |
| 833 | 839 | ||
| 834 | loss = loss + args.prior_loss_weight * prior_loss | 840 | loss = loss + args.prior_loss_weight * prior_loss |
| 835 | else: | 841 | else: |
| 836 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 842 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
| 837 | 843 | ||
| 838 | loss = loss.detach().item() | 844 | loss = loss.detach().item() |
| 839 | val_loss += loss | 845 | val_loss += loss |
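As in the training loop, the validation loss is now wrapped in autocast while the surrounding forward passes stay outside it. The prior-preservation computation itself is unchanged: each batch concatenates instance and class examples, so predictions and targets are split in half and the class half contributes a weighted regularization term. A condensed sketch using the script's names:

```python
if args.with_prior_preservation:
    # First half of the batch: instance images; second half: class (prior) images.
    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
    noise, noise_prior = torch.chunk(noise, 2, dim=0)

    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
    prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
                            reduction="none").mean([1, 2, 3]).mean()

    # The class-image term keeps the fine-tuned model anchored to the class prior.
    loss = loss + args.prior_loss_weight * prior_loss
else:
    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
```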
