path: root/train_dreambooth.py
author     Volpeon <git@volpeon.ink>  2023-01-01 19:19:52 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-01 19:19:52 +0100
commit     adc52fb8821a496bc8d78235bf10466b39df03e0 (patch)
tree       8a6337a6ac10cbe76c55514ab559c647e69fb1aa /train_dreambooth.py
parent     Fixed accuracy calc, other improvements (diff)
download   textual-inversion-diff-adc52fb8821a496bc8d78235bf10466b39df03e0.tar.gz
           textual-inversion-diff-adc52fb8821a496bc8d78235bf10466b39df03e0.tar.bz2
           textual-inversion-diff-adc52fb8821a496bc8d78235bf10466b39df03e0.zip
Updates
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  228
1 file changed, 124 insertions(+), 104 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1ebcfe3..b07de31 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -3,7 +3,6 @@ import itertools
 import math
 import datetime
 import logging
-import json
 from pathlib import Path
 
 import torch
@@ -15,18 +14,21 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
+import matplotlib.pyplot as plt
 from diffusers.training_utils import EMAModel
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
-from common import load_text_embeddings, load_config
+from common import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
-from training.ti import patch_trainable_embeddings
+from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
+from models.clip.embeddings import patch_managed_embeddings, unpatch_managed_embeddings
 from models.clip.prompt import PromptProcessor
+from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
 
@@ -106,6 +108,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--vector_shuffle",
+        type=str,
+        default="auto",
+        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
+    )
+    parser.add_argument(
         "--num_class_images",
         type=int,
         default=1,
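
Note: MultiCLIPTokenizer is not part of this diff, so the exact shuffle semantics are not visible here. A minimal sketch of what the listed modes plausibly do, assuming a multi-vector placeholder expands to a list of token ids and shuffling permutes them as a training-time augmentation (the function name and the "auto" default below are assumptions):

```python
import random

def shuffle_vectors(ids, mode="auto"):
    # Assumed semantics: "auto" picks a sensible default; here we treat it
    # like "between" (first and last vector stay anchored).
    if mode == "auto":
        mode = "between"
    if mode == "off" or len(ids) <= 1:
        return list(ids)
    head, mid, tail = [], list(ids), []
    if mode in ("trailing", "between"):  # keep the first vector in place
        head, mid = mid[:1], mid[1:]
    if mode in ("leading", "between"):   # keep the last vector in place
        mid, tail = mid[:-1], mid[-1:]
    random.shuffle(mid)                  # "all" shuffles every vector
    return head + mid + tail

# e.g. shuffle_vectors([101, 102, 103, 104], "trailing") always keeps 101 first
```
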
@@ -193,13 +201,12 @@ def parse_args():
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
     parser.add_argument(
-        "--learning_rate_unet",
-        type=float,
-        default=2e-6,
-        help="Initial learning rate (after the potential warmup period) to use.",
+        "--find_lr",
+        action="store_true",
+        help="Automatically find a learning rate (no training).",
     )
     parser.add_argument(
-        "--learning_rate_text",
+        "--learning_rate",
         type=float,
         default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
@@ -546,9 +553,9 @@ def main():
 
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
@@ -558,6 +565,8 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
@@ -576,46 +585,42 @@ def main():
         device=accelerator.device
     )
 
-    # Freeze text_encoder and vae
-    vae.requires_grad_(False)
+    embeddings = patch_managed_embeddings(text_encoder)
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
-        added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
-        print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
+
+        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
 
     if len(args.placeholder_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
-        initializer_token_ids = torch.stack([
-            torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+        initializer_token_ids = [
+            tokenizer.encode(token, add_special_tokens=False)
             for token in args.initializer_token
-        ])
-
-        num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-        print(f"Added {num_added_tokens} new tokens.")
+        ]
 
-        placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+        new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
+        embeddings.resize(len(tokenizer))
 
-        # Resize the token embeddings as we are adding new special tokens to the tokenizer
-        text_encoder.resize_token_embeddings(len(tokenizer))
+        for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
+            embeddings.add_embed(new_token.ids, init_ids)
 
-        token_embeds = text_encoder.get_input_embeddings().weight.data
-        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-
-        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
-            token_embeds[token_id] = embeddings
+        print(f"Added {len(new_tokens)} new tokens.")
     else:
         placeholder_token_id = []
 
+    vae.requires_grad_(False)
+
     if args.train_text_encoder:
         print(f"Training entire text encoder.")
+
+        unpatch_managed_embeddings(text_encoder)
     else:
         print(f"Training added text embeddings")
 
-        patch_trainable_embeddings(text_encoder, placeholder_token_id)
-
     text_encoder.text_model.encoder.requires_grad_(False)
     text_encoder.text_model.final_layer_norm.requires_grad_(False)
     text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
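
For context, a toy illustration of the multi-vector flow this hunk introduces, assuming add_multi_tokens expands one placeholder into num_vectors fresh ids and add_embed seeds each new row from the mean of the initializer's pretrained embeddings (MultiCLIPTokenizer and models/clip/embeddings.py are not shown in this diff, so these semantics are assumptions):

```python
import torch

vocab_size, dim, num_vectors = 49408, 768, 3

# Resized token table: pretrained rows plus num_vectors fresh rows,
# standing in for embeddings.resize(len(tokenizer))
token_table = torch.randn(vocab_size + num_vectors, dim)

# One placeholder expands to consecutive new ids (assumed behaviour of
# tokenizer.add_multi_tokens)
new_ids = list(range(vocab_size, vocab_size + num_vectors))

# Seed every new row from the mean of the initializer token's embeddings,
# mirroring embeddings.add_embed(new_token.ids, init_ids)
init_ids = [1125, 2368]  # hypothetical ids from encoding the initializer
init_vector = token_table[init_ids].mean(dim=0)
for token_id in new_ids:
    token_table[token_id] = init_vector
```
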
@@ -624,15 +629,14 @@ def main():
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
-        args.learning_rate_unet = (
-            args.learning_rate_unet * args.gradient_accumulation_steps *
-            args.train_batch_size * accelerator.num_processes
-        )
-        args.learning_rate_text = (
-            args.learning_rate_text * args.gradient_accumulation_steps *
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
 
+    if args.find_lr:
+        args.learning_rate = 1e2
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
@@ -647,20 +651,19 @@ def main():
     if args.train_text_encoder:
         text_encoder_params_to_optimize = text_encoder.parameters()
     else:
-        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.trainable_embedding.parameters()
+        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.temp_token_embedding.parameters()
 
     # Initialize the optimizer
     optimizer = optimizer_class(
         [
             {
                 'params': unet.parameters(),
-                'lr': args.learning_rate_unet,
             },
             {
                 'params': text_encoder_params_to_optimize,
-                'lr': args.learning_rate_text,
             }
         ],
+        lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
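
With the per-group 'lr' entries removed, both parameter groups inherit the optimizer-level default, so the UNet and the text encoder now share a single learning rate. A quick self-contained check of that PyTorch behaviour (toy modules stand in for unet and the text encoder):

```python
import torch

unet_stub, text_stub = torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(
    [
        {'params': unet_stub.parameters()},
        {'params': text_stub.parameters()},
    ],
    lr=2e-6,  # single learning rate shared by both groups
)
assert all(group['lr'] == 2e-6 for group in optimizer.param_groups)
```
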
@@ -824,6 +827,58 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs
 
+    def loop(batch):
+        # Convert images to latent space
+        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+        latents = latents * 0.18215
+
+        # Sample noise that we'll add to the latents
+        noise = torch.randn_like(latents)
+        bsz = latents.shape[0]
+        # Sample a random timestep for each image
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                  (bsz,), device=latents.device)
+        timesteps = timesteps.long()
+
+        # Add noise to the latents according to the noise magnitude at each timestep
+        # (this is the forward diffusion process)
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+        noisy_latents = noisy_latents.to(dtype=unet.dtype)
+
+        # Get the text embedding for conditioning
+        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+
+        # Predict the noise residual
+        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        # Get the target for loss depending on the prediction type
+        if noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif noise_scheduler.config.prediction_type == "v_prediction":
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+        if args.num_class_images != 0:
+            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+            target, target_prior = torch.chunk(target, 2, dim=0)
+
+            # Compute instance loss
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+            # Compute prior loss
+            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+            # Add the prior loss to the instance loss.
+            loss = loss + args.prior_loss_weight * prior_loss
+        else:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+        acc = (model_pred == target).float().mean()
+
+        return loss, acc, bsz
+
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
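
The prior-preservation branch in loop relies on each batch being collated as instance examples followed by class examples along dim 0 (the collate code sits outside this diff). A minimal sketch of that layout, with random tensors standing in for latents:

```python
import torch

instance = torch.randn(2, 4, 64, 64)  # hypothetical instance latents
prior = torch.randn(2, 4, 64, 64)     # hypothetical class (prior) latents
batch = torch.cat([instance, prior], dim=0)

# torch.chunk(..., 2, dim=0) splits the halves back apart, as loop() assumes
pred_instance, pred_prior = torch.chunk(batch, 2, dim=0)
assert torch.equal(pred_instance, instance) and torch.equal(pred_prior, prior)
```
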
@@ -836,6 +891,15 @@ def main():
         config["exclude_collections"] = " ".join(config["exclude_collections"])
         accelerator.init_trackers("dreambooth", config=config)
 
+    if args.find_lr:
+        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder.run(min_lr=1e-4)
+
+        plt.savefig(basepath.joinpath("lr.png"))
+        plt.close()
+
+        quit()
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
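
training/lr.py is not shown in this commit; LRFinder presumably runs a standard LR range test, sweeping the learning rate upward while recording the loss and drawing onto the current matplotlib figure (which is why main() can simply call plt.savefig afterwards). A minimal sketch under those assumptions:

```python
import itertools
import matplotlib.pyplot as plt

def lr_range_test(optimizer, dataloader, loss_fn, min_lr=1e-4, max_lr=1e2, num_steps=100):
    gamma = (max_lr / min_lr) ** (1 / num_steps)  # exponential ramp per step
    lr, lrs, losses = min_lr, [], []
    batches = itertools.cycle(dataloader)
    for _ in range(num_steps):
        for group in optimizer.param_groups:
            group["lr"] = lr
        loss, _, _ = loss_fn(next(batches))  # same signature as loop(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lrs.append(lr)
        losses.append(loss.item())
        lr *= gamma
    plt.plot(lrs, losses)   # left on the current figure for plt.savefig()
    plt.xscale("log")
    plt.xlabel("learning rate")
    plt.ylabel("loss")
```
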
@@ -893,58 +957,6 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
-    def loop(batch):
-        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-        latents = latents * 0.18215
-
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-        bsz = latents.shape[0]
-        # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                  (bsz,), device=latents.device)
-        timesteps = timesteps.long()
-
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-        noisy_latents = noisy_latents.to(dtype=unet.dtype)
-
-        # Get the text embedding for conditioning
-        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-
-        # Predict the noise residual
-        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-        # Get the target for loss depending on the prediction type
-        if noise_scheduler.config.prediction_type == "epsilon":
-            target = noise
-        elif noise_scheduler.config.prediction_type == "v_prediction":
-            target = noise_scheduler.get_velocity(latents, noise, timesteps)
-        else:
-            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-        if args.num_class_images != 0:
-            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-            target, target_prior = torch.chunk(target, 2, dim=0)
-
-            # Compute instance loss
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-            # Compute prior loss
-            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-            # Add the prior loss to the instance loss.
-            loss = loss + args.prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-        acc = (model_pred == target).float().mean()
-
-        return loss, acc, bsz
-
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -993,8 +1005,7 @@ def main():
993 "train/acc": avg_acc.avg.item(), 1005 "train/acc": avg_acc.avg.item(),
994 "train/cur_loss": loss.item(), 1006 "train/cur_loss": loss.item(),
995 "train/cur_acc": acc.item(), 1007 "train/cur_acc": acc.item(),
996 "lr/unet": lr_scheduler.get_last_lr()[0], 1008 "lr": lr_scheduler.get_last_lr()[0]
997 "lr/text": lr_scheduler.get_last_lr()[1]
998 } 1009 }
999 if args.use_ema: 1010 if args.use_ema:
1000 logs["ema_decay"] = 1 - ema_unet.decay 1011 logs["ema_decay"] = 1 - ema_unet.decay
@@ -1011,12 +1022,21 @@ def main():
             unet.eval()
             text_encoder.eval()
 
+            cur_loss_val = AverageMeter()
+            cur_acc_val = AverageMeter()
+
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
                     loss, acc, bsz = loop(batch)
 
-                    avg_loss_val.update(loss.detach_(), bsz)
-                    avg_acc_val.update(acc.detach_(), bsz)
+                    loss = loss.detach_()
+                    acc = acc.detach_()
+
+                    cur_loss_val.update(loss, bsz)
+                    cur_acc_val.update(acc, bsz)
+
+                    avg_loss_val.update(loss, bsz)
+                    avg_acc_val.update(acc, bsz)
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
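
cur_loss_val/cur_acc_val track only the current validation pass, while avg_loss_val/avg_acc_val keep accumulating across epochs. AverageMeter lives in training/util.py, outside this diff; the conventional implementation looks like this (an assumption, not necessarily the project's exact code):

```python
class AverageMeter:
    """Running weighted average; .avg stays a tensor if updates are tensors."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count
```
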
@@ -1029,20 +1049,20 @@ def main():
                     }
                     local_progress_bar.set_postfix(**logs)
 
-            accelerator.log({
-                "val/loss": avg_loss_val.avg.item(),
-                "val/acc": avg_acc_val.avg.item(),
-            }, step=global_step)
+            logs["val/cur_loss"] = cur_loss_val.avg.item()
+            logs["val/cur_acc"] = cur_acc_val.avg.item()
+
+            accelerator.log(logs, step=global_step)
 
             local_progress_bar.clear()
             global_progress_bar.clear()
 
-            if avg_acc_val.avg.item() > max_acc_val:
-                accelerator.print(
-                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                max_acc_val = avg_acc_val.avg.item()
-
             if accelerator.is_main_process:
+                if avg_acc_val.avg.item() > max_acc_val:
+                    accelerator.print(
+                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                    max_acc_val = avg_acc_val.avg.item()
+
                 if (epoch + 1) % args.sample_frequency == 0:
                     checkpointer.save_samples(global_step, args.sample_steps)
 