summary | refs | log | tree | commit | diff | stats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
author: Volpeon <git@volpeon.ink> 2023-01-13 18:59:26 +0100
committer: Volpeon <git@volpeon.ink> 2023-01-13 18:59:26 +0100
commit127ec21e5bd4e7df21e36c561d070f8b9a0e19f5 (patch)
tree61cb98adbf33ed08506601f8b70f1b62bc42c4ee /train_dreambooth.py
parentSimplified step calculations (diff)
downloadtextual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.tar.gz
textual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.tar.bz2
textual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.zip
More modularization
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  272
1 files changed, 65 insertions, 207 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index fbbe6c2..c892ebf 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -1,6 +1,5 @@
1import argparse 1import argparse
2import itertools 2import itertools
3import math
4import datetime 3import datetime
5import logging 4import logging
6from pathlib import Path 5from pathlib import Path
@@ -16,16 +15,15 @@ from accelerate.utils import LoggerType, set_seed
16from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel 15from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
17import matplotlib.pyplot as plt 16import matplotlib.pyplot as plt
18from diffusers.training_utils import EMAModel 17from diffusers.training_utils import EMAModel
19from tqdm.auto import tqdm
20from transformers import CLIPTextModel 18from transformers import CLIPTextModel
21from slugify import slugify 19from slugify import slugify
22 20
23from util import load_config, load_embeddings_from_dir 21from util import load_config, load_embeddings_from_dir
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 22from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import VlpnDataModule, VlpnDataItem 23from data.csv import VlpnDataModule, VlpnDataItem
26from training.common import loss_step, generate_class_images, get_scheduler 24from training.common import loss_step, train_loop, generate_class_images, get_scheduler
27from training.lr import LRFinder 25from training.lr import LRFinder
28from training.util import AverageMeter, CheckpointerBase, save_args 26from training.util import CheckpointerBase, save_args
29from models.clip.embeddings import patch_managed_embeddings 27from models.clip.embeddings import patch_managed_embeddings
30from models.clip.tokenizer import MultiCLIPTokenizer 28from models.clip.tokenizer import MultiCLIPTokenizer
31 29
@@ -292,7 +290,7 @@ def parse_args():
292 parser.add_argument( 290 parser.add_argument(
293 "--lr_min_lr", 291 "--lr_min_lr",
294 type=float, 292 type=float,
295 default=None, 293 default=0.04,
296 help="Minimum learning rate in the lr scheduler." 294 help="Minimum learning rate in the lr scheduler."
297 ) 295 )
298 parser.add_argument( 296 parser.add_argument(
@@ -787,14 +785,6 @@ def main():
787 args.sample_steps 785 args.sample_steps
788 ) 786 )
789 787
790 # Scheduler and math around the number of training steps.
791 overrode_max_train_steps = False
792 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
793 if args.max_train_steps is None:
794 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
795 overrode_max_train_steps = True
796 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
797
798 if args.find_lr: 788 if args.find_lr:
799 lr_scheduler = None 789 lr_scheduler = None
800 else: 790 else:
@@ -802,15 +792,14 @@ def main():
802 args.lr_scheduler, 792 args.lr_scheduler,
803 optimizer=optimizer, 793 optimizer=optimizer,
804 min_lr=args.lr_min_lr, 794 min_lr=args.lr_min_lr,
805 lr=args.learning_rate,
806 warmup_func=args.lr_warmup_func, 795 warmup_func=args.lr_warmup_func,
807 annealing_func=args.lr_annealing_func, 796 annealing_func=args.lr_annealing_func,
808 warmup_exp=args.lr_warmup_exp, 797 warmup_exp=args.lr_warmup_exp,
809 annealing_exp=args.lr_annealing_exp, 798 annealing_exp=args.lr_annealing_exp,
810 cycles=args.lr_cycles, 799 cycles=args.lr_cycles,
800 train_epochs=args.num_train_epochs,
811 warmup_epochs=args.lr_warmup_epochs, 801 warmup_epochs=args.lr_warmup_epochs,
812 max_train_steps=args.max_train_steps, 802 num_training_steps_per_epoch=len(train_dataloader),
813 num_update_steps_per_epoch=num_update_steps_per_epoch,
814 gradient_accumulation_steps=args.gradient_accumulation_steps 803 gradient_accumulation_steps=args.gradient_accumulation_steps
815 ) 804 )
816 805
@@ -827,19 +816,16 @@ def main():
827 if args.use_ema: 816 if args.use_ema:
828 ema_unet.to(accelerator.device) 817 ema_unet.to(accelerator.device)
829 818
830 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
831 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
832 if overrode_max_train_steps:
833 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
834
835 num_val_steps_per_epoch = len(val_dataloader)
836 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
837 val_steps = num_val_steps_per_epoch * num_epochs
838
839 @contextmanager 819 @contextmanager
840 def on_train(): 820 def on_train(epoch: int):
841 try: 821 try:
842 tokenizer.train() 822 tokenizer.train()
823
824 if epoch < args.train_text_encoder_epochs:
825 text_encoder.train()
826 elif epoch == args.train_text_encoder_epochs:
827 text_encoder.requires_grad_(False)
828
843 yield 829 yield
844 finally: 830 finally:
845 pass 831 pass
@@ -848,6 +834,7 @@ def main():
848 def on_eval(): 834 def on_eval():
849 try: 835 try:
850 tokenizer.eval() 836 tokenizer.eval()
837 text_encoder.eval()
851 838
852 ema_context = ema_unet.apply_temporary(unet.parameters()) if args.use_ema else nullcontext() 839 ema_context = ema_unet.apply_temporary(unet.parameters()) if args.use_ema else nullcontext()
853 840
@@ -856,7 +843,7 @@ def main():
856 finally: 843 finally:
857 pass 844 pass
858 845
859 def on_before_optimize(): 846 def on_before_optimize(epoch: int):
860 if accelerator.sync_gradients: 847 if accelerator.sync_gradients:
861 params_to_clip = [unet.parameters()] 848 params_to_clip = [unet.parameters()]
862 if args.train_text_encoder and epoch < args.train_text_encoder_epochs: 849 if args.train_text_encoder and epoch < args.train_text_encoder_epochs:
@@ -866,9 +853,17 @@ def main():
866 @torch.no_grad() 853 @torch.no_grad()
867 def on_after_optimize(lr: float): 854 def on_after_optimize(lr: float):
868 if not args.train_text_encoder: 855 if not args.train_text_encoder:
869 text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr)) 856 text_encoder.text_model.embeddings.normalize(
857 args.decay_target,
858 min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start))))
859 )
860
861 def on_log():
862 if args.use_ema:
863 return {"ema_decay": ema_unet.decay}
864 return {}
870 865
871 loop = partial( 866 loss_step_ = partial(
872 loss_step, 867 loss_step,
873 vae, 868 vae,
874 noise_scheduler, 869 noise_scheduler,
@@ -879,8 +874,25 @@ def main():
879 args.seed, 874 args.seed,
880 ) 875 )
881 876
882 # We need to initialize the trackers we use, and also store our configuration. 877 checkpointer = Checkpointer(
883 # The trackers initializes automatically on the main process. 878 weight_dtype=weight_dtype,
879 datamodule=datamodule,
880 accelerator=accelerator,
881 vae=vae,
882 unet=unet,
883 ema_unet=ema_unet,
884 tokenizer=tokenizer,
885 text_encoder=text_encoder,
886 scheduler=checkpoint_scheduler,
887 output_dir=basepath,
888 placeholder_token=args.placeholder_token,
889 placeholder_token_id=placeholder_token_id,
890 sample_image_size=args.sample_image_size,
891 sample_batch_size=args.sample_batch_size,
892 sample_batches=args.sample_batches,
893 seed=args.seed
894 )
895
884 if accelerator.is_main_process: 896 if accelerator.is_main_process:
885 config = vars(args).copy() 897 config = vars(args).copy()
886 config["initializer_token"] = " ".join(config["initializer_token"]) 898 config["initializer_token"] = " ".join(config["initializer_token"])
@@ -898,9 +910,9 @@ def main():
898 optimizer, 910 optimizer,
899 train_dataloader, 911 train_dataloader,
900 val_dataloader, 912 val_dataloader,
901 loop, 913 loss_step_,
902 on_train=tokenizer.train, 914 on_train=on_train,
903 on_eval=tokenizer.eval, 915 on_eval=on_eval,
904 on_before_optimize=on_before_optimize, 916 on_before_optimize=on_before_optimize,
905 on_after_optimize=on_after_optimize, 917 on_after_optimize=on_after_optimize,
906 ) 918 )
@@ -909,182 +921,28 @@ def main():
909 plt.savefig(basepath.joinpath("lr.png"), dpi=300) 921 plt.savefig(basepath.joinpath("lr.png"), dpi=300)
910 plt.close() 922 plt.close()
911 923
912 quit() 924 return
913
914 # Train!
915 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
916
917 logger.info("***** Running training *****")
918 logger.info(f" Num Epochs = {num_epochs}")
919 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
920 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
921 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
922 logger.info(f" Total optimization steps = {args.max_train_steps}")
923 # Only show the progress bar once on each machine.
924
925 global_step = 0
926
927 avg_loss = AverageMeter()
928 avg_acc = AverageMeter()
929 925
930 avg_loss_val = AverageMeter() 926 train_loop(
931 avg_acc_val = AverageMeter()
932
933 max_acc_val = 0.0
934
935 checkpointer = Checkpointer(
936 weight_dtype=weight_dtype,
937 datamodule=datamodule,
938 accelerator=accelerator, 927 accelerator=accelerator,
939 vae=vae, 928 optimizer=optimizer,
940 unet=unet, 929 lr_scheduler=lr_scheduler,
941 ema_unet=ema_unet, 930 model=unet,
942 tokenizer=tokenizer, 931 checkpointer=checkpointer,
943 text_encoder=text_encoder, 932 train_dataloader=train_dataloader,
944 scheduler=checkpoint_scheduler, 933 val_dataloader=val_dataloader,
945 output_dir=basepath, 934 loss_step=loss_step_,
946 placeholder_token=args.placeholder_token, 935 sample_frequency=args.sample_frequency,
947 placeholder_token_id=placeholder_token_id, 936 sample_steps=args.sample_steps,
948 sample_image_size=args.sample_image_size, 937 checkpoint_frequency=args.checkpoint_frequency,
949 sample_batch_size=args.sample_batch_size, 938 global_step_offset=0,
950 sample_batches=args.sample_batches, 939 gradient_accumulation_steps=args.gradient_accumulation_steps,
951 seed=args.seed 940 num_epochs=args.num_train_epochs,
952 ) 941 on_log=on_log,
953 942 on_train=on_train,
954 local_progress_bar = tqdm( 943 on_after_optimize=on_after_optimize,
955 range(num_update_steps_per_epoch + num_val_steps_per_epoch), 944 on_eval=on_eval
956 disable=not accelerator.is_local_main_process,
957 dynamic_ncols=True
958 )
959 local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
960
961 global_progress_bar = tqdm(
962 range(args.max_train_steps + val_steps),
963 disable=not accelerator.is_local_main_process,
964 dynamic_ncols=True
965 ) 945 )
966 global_progress_bar.set_description("Total progress")
967
968 try:
969 for epoch in range(num_epochs):
970 if accelerator.is_main_process:
971 if epoch % args.sample_frequency == 0:
972 checkpointer.save_samples(global_step, args.sample_steps)
973
974 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
975 local_progress_bar.reset()
976
977 unet.train()
978 if epoch < args.train_text_encoder_epochs:
979 text_encoder.train()
980 elif epoch == args.train_text_encoder_epochs:
981 text_encoder.requires_grad_(False)
982
983 with on_train():
984 for step, batch in enumerate(train_dataloader):
985 with accelerator.accumulate(unet):
986 loss, acc, bsz = loop(step, batch)
987
988 accelerator.backward(loss)
989
990 on_before_optimize()
991
992 optimizer.step()
993 if not accelerator.optimizer_step_was_skipped:
994 lr_scheduler.step()
995 if args.use_ema:
996 ema_unet.step(unet.parameters())
997 optimizer.zero_grad(set_to_none=True)
998
999 avg_loss.update(loss.detach_(), bsz)
1000 avg_acc.update(acc.detach_(), bsz)
1001
1002 # Checks if the accelerator has performed an optimization step behind the scenes
1003 if accelerator.sync_gradients:
1004 on_after_optimize(lr_scheduler.get_last_lr()[0])
1005
1006 local_progress_bar.update(1)
1007 global_progress_bar.update(1)
1008
1009 global_step += 1
1010
1011 logs = {
1012 "train/loss": avg_loss.avg.item(),
1013 "train/acc": avg_acc.avg.item(),
1014 "train/cur_loss": loss.item(),
1015 "train/cur_acc": acc.item(),
1016 "lr": lr_scheduler.get_last_lr()[0]
1017 }
1018 if args.use_ema:
1019 logs["ema_decay"] = 1 - ema_unet.decay
1020
1021 accelerator.log(logs, step=global_step)
1022
1023 local_progress_bar.set_postfix(**logs)
1024
1025 if global_step >= args.max_train_steps:
1026 break
1027
1028 accelerator.wait_for_everyone()
1029
1030 unet.eval()
1031 text_encoder.eval()
1032
1033 cur_loss_val = AverageMeter()
1034 cur_acc_val = AverageMeter()
1035
1036 with torch.inference_mode():
1037 with on_eval():
1038 for step, batch in enumerate(val_dataloader):
1039 loss, acc, bsz = loop(step, batch, True)
1040
1041 loss = loss.detach_()
1042 acc = acc.detach_()
1043
1044 cur_loss_val.update(loss, bsz)
1045 cur_acc_val.update(acc, bsz)
1046
1047 avg_loss_val.update(loss, bsz)
1048 avg_acc_val.update(acc, bsz)
1049
1050 local_progress_bar.update(1)
1051 global_progress_bar.update(1)
1052
1053 logs = {
1054 "val/loss": avg_loss_val.avg.item(),
1055 "val/acc": avg_acc_val.avg.item(),
1056 "val/cur_loss": loss.item(),
1057 "val/cur_acc": acc.item(),
1058 }
1059 local_progress_bar.set_postfix(**logs)
1060
1061 logs["val/cur_loss"] = cur_loss_val.avg.item()
1062 logs["val/cur_acc"] = cur_acc_val.avg.item()
1063
1064 accelerator.log(logs, step=global_step)
1065
1066 local_progress_bar.clear()
1067 global_progress_bar.clear()
1068
1069 if accelerator.is_main_process:
1070 if avg_acc_val.avg.item() > max_acc_val:
1071 accelerator.print(
1072 f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
1073 max_acc_val = avg_acc_val.avg.item()
1074
1075 # Create the pipeline using using the trained modules and save it.
1076 if accelerator.is_main_process:
1077 print("Finished! Saving final checkpoint and resume state.")
1078 checkpointer.save_samples(global_step, args.sample_steps)
1079 checkpointer.save_model()
1080 accelerator.end_training()
1081
1082 except KeyboardInterrupt:
1083 if accelerator.is_main_process:
1084 print("Interrupted, saving checkpoint and resume state...")
1085 checkpointer.save_model()
1086 accelerator.end_training()
1087 quit()
1088 946
1089 947
1090if __name__ == "__main__": 948if __name__ == "__main__":