diff options
author | Volpeon <git@volpeon.ink> | 2023-01-13 18:59:26 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-13 18:59:26 +0100 |
commit | 127ec21e5bd4e7df21e36c561d070f8b9a0e19f5 (patch) | |
tree | 61cb98adbf33ed08506601f8b70f1b62bc42c4ee /train_dreambooth.py | |
parent | Simplified step calculations (diff) | |
download | textual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.tar.gz textual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.tar.bz2 textual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.zip |
More modularization
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 272 |
1 file changed, 65 insertions, 207 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index fbbe6c2..c892ebf 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -1,6 +1,5 @@ | |||
1 | import argparse | 1 | import argparse |
2 | import itertools | 2 | import itertools |
3 | import math | ||
4 | import datetime | 3 | import datetime |
5 | import logging | 4 | import logging |
6 | from pathlib import Path | 5 | from pathlib import Path |
@@ -16,16 +15,15 @@ from accelerate.utils import LoggerType, set_seed | |||
16 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel | 15 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel |
17 | import matplotlib.pyplot as plt | 16 | import matplotlib.pyplot as plt |
18 | from diffusers.training_utils import EMAModel | 17 | from diffusers.training_utils import EMAModel |
19 | from tqdm.auto import tqdm | ||
20 | from transformers import CLIPTextModel | 18 | from transformers import CLIPTextModel |
21 | from slugify import slugify | 19 | from slugify import slugify |
22 | 20 | ||
23 | from util import load_config, load_embeddings_from_dir | 21 | from util import load_config, load_embeddings_from_dir |
24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 22 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
25 | from data.csv import VlpnDataModule, VlpnDataItem | 23 | from data.csv import VlpnDataModule, VlpnDataItem |
26 | from training.common import loss_step, generate_class_images, get_scheduler | 24 | from training.common import loss_step, train_loop, generate_class_images, get_scheduler |
27 | from training.lr import LRFinder | 25 | from training.lr import LRFinder |
28 | from training.util import AverageMeter, CheckpointerBase, save_args | 26 | from training.util import CheckpointerBase, save_args |
29 | from models.clip.embeddings import patch_managed_embeddings | 27 | from models.clip.embeddings import patch_managed_embeddings |
30 | from models.clip.tokenizer import MultiCLIPTokenizer | 28 | from models.clip.tokenizer import MultiCLIPTokenizer |
31 | 29 | ||
@@ -292,7 +290,7 @@ def parse_args(): | |||
292 | parser.add_argument( | 290 | parser.add_argument( |
293 | "--lr_min_lr", | 291 | "--lr_min_lr", |
294 | type=float, | 292 | type=float, |
295 | default=None, | 293 | default=0.04, |
296 | help="Minimum learning rate in the lr scheduler." | 294 | help="Minimum learning rate in the lr scheduler." |
297 | ) | 295 | ) |
298 | parser.add_argument( | 296 | parser.add_argument( |
@@ -787,14 +785,6 @@ def main(): | |||
787 | args.sample_steps | 785 | args.sample_steps |
788 | ) | 786 | ) |
789 | 787 | ||
790 | # Scheduler and math around the number of training steps. | ||
791 | overrode_max_train_steps = False | ||
792 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | ||
793 | if args.max_train_steps is None: | ||
794 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | ||
795 | overrode_max_train_steps = True | ||
796 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | ||
797 | |||
798 | if args.find_lr: | 788 | if args.find_lr: |
799 | lr_scheduler = None | 789 | lr_scheduler = None |
800 | else: | 790 | else: |
@@ -802,15 +792,14 @@ def main(): | |||
802 | args.lr_scheduler, | 792 | args.lr_scheduler, |
803 | optimizer=optimizer, | 793 | optimizer=optimizer, |
804 | min_lr=args.lr_min_lr, | 794 | min_lr=args.lr_min_lr, |
805 | lr=args.learning_rate, | ||
806 | warmup_func=args.lr_warmup_func, | 795 | warmup_func=args.lr_warmup_func, |
807 | annealing_func=args.lr_annealing_func, | 796 | annealing_func=args.lr_annealing_func, |
808 | warmup_exp=args.lr_warmup_exp, | 797 | warmup_exp=args.lr_warmup_exp, |
809 | annealing_exp=args.lr_annealing_exp, | 798 | annealing_exp=args.lr_annealing_exp, |
810 | cycles=args.lr_cycles, | 799 | cycles=args.lr_cycles, |
800 | train_epochs=args.num_train_epochs, | ||
811 | warmup_epochs=args.lr_warmup_epochs, | 801 | warmup_epochs=args.lr_warmup_epochs, |
812 | max_train_steps=args.max_train_steps, | 802 | num_training_steps_per_epoch=len(train_dataloader), |
813 | num_update_steps_per_epoch=num_update_steps_per_epoch, | ||
814 | gradient_accumulation_steps=args.gradient_accumulation_steps | 803 | gradient_accumulation_steps=args.gradient_accumulation_steps |
815 | ) | 804 | ) |
816 | 805 | ||
@@ -827,19 +816,16 @@ def main(): | |||
827 | if args.use_ema: | 816 | if args.use_ema: |
828 | ema_unet.to(accelerator.device) | 817 | ema_unet.to(accelerator.device) |
829 | 818 | ||
830 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. | ||
831 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | ||
832 | if overrode_max_train_steps: | ||
833 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | ||
834 | |||
835 | num_val_steps_per_epoch = len(val_dataloader) | ||
836 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | ||
837 | val_steps = num_val_steps_per_epoch * num_epochs | ||
838 | |||
839 | @contextmanager | 819 | @contextmanager |
840 | def on_train(): | 820 | def on_train(epoch: int): |
841 | try: | 821 | try: |
842 | tokenizer.train() | 822 | tokenizer.train() |
823 | |||
824 | if epoch < args.train_text_encoder_epochs: | ||
825 | text_encoder.train() | ||
826 | elif epoch == args.train_text_encoder_epochs: | ||
827 | text_encoder.requires_grad_(False) | ||
828 | |||
843 | yield | 829 | yield |
844 | finally: | 830 | finally: |
845 | pass | 831 | pass |
@@ -848,6 +834,7 @@ def main(): | |||
848 | def on_eval(): | 834 | def on_eval(): |
849 | try: | 835 | try: |
850 | tokenizer.eval() | 836 | tokenizer.eval() |
837 | text_encoder.eval() | ||
851 | 838 | ||
852 | ema_context = ema_unet.apply_temporary(unet.parameters()) if args.use_ema else nullcontext() | 839 | ema_context = ema_unet.apply_temporary(unet.parameters()) if args.use_ema else nullcontext() |
853 | 840 | ||
@@ -856,7 +843,7 @@ def main(): | |||
856 | finally: | 843 | finally: |
857 | pass | 844 | pass |
858 | 845 | ||
859 | def on_before_optimize(): | 846 | def on_before_optimize(epoch: int): |
860 | if accelerator.sync_gradients: | 847 | if accelerator.sync_gradients: |
861 | params_to_clip = [unet.parameters()] | 848 | params_to_clip = [unet.parameters()] |
862 | if args.train_text_encoder and epoch < args.train_text_encoder_epochs: | 849 | if args.train_text_encoder and epoch < args.train_text_encoder_epochs: |
@@ -866,9 +853,17 @@ def main(): | |||
866 | @torch.no_grad() | 853 | @torch.no_grad() |
867 | def on_after_optimize(lr: float): | 854 | def on_after_optimize(lr: float): |
868 | if not args.train_text_encoder: | 855 | if not args.train_text_encoder: |
869 | text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr)) | 856 | text_encoder.text_model.embeddings.normalize( |
857 | args.decay_target, | ||
858 | min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start)))) | ||
859 | ) | ||
860 | |||
861 | def on_log(): | ||
862 | if args.use_ema: | ||
863 | return {"ema_decay": ema_unet.decay} | ||
864 | return {} | ||
870 | 865 | ||
871 | loop = partial( | 866 | loss_step_ = partial( |
872 | loss_step, | 867 | loss_step, |
873 | vae, | 868 | vae, |
874 | noise_scheduler, | 869 | noise_scheduler, |
@@ -879,8 +874,25 @@ def main(): | |||
879 | args.seed, | 874 | args.seed, |
880 | ) | 875 | ) |
881 | 876 | ||
882 | # We need to initialize the trackers we use, and also store our configuration. | 877 | checkpointer = Checkpointer( |
883 | # The trackers initialize automatically on the main process. | 878 | weight_dtype=weight_dtype, |
879 | datamodule=datamodule, | ||
880 | accelerator=accelerator, | ||
881 | vae=vae, | ||
882 | unet=unet, | ||
883 | ema_unet=ema_unet, | ||
884 | tokenizer=tokenizer, | ||
885 | text_encoder=text_encoder, | ||
886 | scheduler=checkpoint_scheduler, | ||
887 | output_dir=basepath, | ||
888 | placeholder_token=args.placeholder_token, | ||
889 | placeholder_token_id=placeholder_token_id, | ||
890 | sample_image_size=args.sample_image_size, | ||
891 | sample_batch_size=args.sample_batch_size, | ||
892 | sample_batches=args.sample_batches, | ||
893 | seed=args.seed | ||
894 | ) | ||
895 | |||
884 | if accelerator.is_main_process: | 896 | if accelerator.is_main_process: |
885 | config = vars(args).copy() | 897 | config = vars(args).copy() |
886 | config["initializer_token"] = " ".join(config["initializer_token"]) | 898 | config["initializer_token"] = " ".join(config["initializer_token"]) |
@@ -898,9 +910,9 @@ def main(): | |||
898 | optimizer, | 910 | optimizer, |
899 | train_dataloader, | 911 | train_dataloader, |
900 | val_dataloader, | 912 | val_dataloader, |
901 | loop, | 913 | loss_step_, |
902 | on_train=tokenizer.train, | 914 | on_train=on_train, |
903 | on_eval=tokenizer.eval, | 915 | on_eval=on_eval, |
904 | on_before_optimize=on_before_optimize, | 916 | on_before_optimize=on_before_optimize, |
905 | on_after_optimize=on_after_optimize, | 917 | on_after_optimize=on_after_optimize, |
906 | ) | 918 | ) |
@@ -909,182 +921,28 @@ def main(): | |||
909 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) | 921 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) |
910 | plt.close() | 922 | plt.close() |
911 | 923 | ||
912 | quit() | 924 | return |
913 | |||
914 | # Train! | ||
915 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps | ||
916 | |||
917 | logger.info("***** Running training *****") | ||
918 | logger.info(f" Num Epochs = {num_epochs}") | ||
919 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") | ||
920 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") | ||
921 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") | ||
922 | logger.info(f" Total optimization steps = {args.max_train_steps}") | ||
923 | # Only show the progress bar once on each machine. | ||
924 | |||
925 | global_step = 0 | ||
926 | |||
927 | avg_loss = AverageMeter() | ||
928 | avg_acc = AverageMeter() | ||
929 | 925 | ||
930 | avg_loss_val = AverageMeter() | 926 | train_loop( |
931 | avg_acc_val = AverageMeter() | ||
932 | |||
933 | max_acc_val = 0.0 | ||
934 | |||
935 | checkpointer = Checkpointer( | ||
936 | weight_dtype=weight_dtype, | ||
937 | datamodule=datamodule, | ||
938 | accelerator=accelerator, | 927 | accelerator=accelerator, |
939 | vae=vae, | 928 | optimizer=optimizer, |
940 | unet=unet, | 929 | lr_scheduler=lr_scheduler, |
941 | ema_unet=ema_unet, | 930 | model=unet, |
942 | tokenizer=tokenizer, | 931 | checkpointer=checkpointer, |
943 | text_encoder=text_encoder, | 932 | train_dataloader=train_dataloader, |
944 | scheduler=checkpoint_scheduler, | 933 | val_dataloader=val_dataloader, |
945 | output_dir=basepath, | 934 | loss_step=loss_step_, |
946 | placeholder_token=args.placeholder_token, | 935 | sample_frequency=args.sample_frequency, |
947 | placeholder_token_id=placeholder_token_id, | 936 | sample_steps=args.sample_steps, |
948 | sample_image_size=args.sample_image_size, | 937 | checkpoint_frequency=args.checkpoint_frequency, |
949 | sample_batch_size=args.sample_batch_size, | 938 | global_step_offset=0, |
950 | sample_batches=args.sample_batches, | 939 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
951 | seed=args.seed | 940 | num_epochs=args.num_train_epochs, |
952 | ) | 941 | on_log=on_log, |
953 | 942 | on_train=on_train, | |
954 | local_progress_bar = tqdm( | 943 | on_after_optimize=on_after_optimize, |
955 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), | 944 | on_eval=on_eval |
956 | disable=not accelerator.is_local_main_process, | ||
957 | dynamic_ncols=True | ||
958 | ) | ||
959 | local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") | ||
960 | |||
961 | global_progress_bar = tqdm( | ||
962 | range(args.max_train_steps + val_steps), | ||
963 | disable=not accelerator.is_local_main_process, | ||
964 | dynamic_ncols=True | ||
965 | ) | 945 | ) |
966 | global_progress_bar.set_description("Total progress") | ||
967 | |||
968 | try: | ||
969 | for epoch in range(num_epochs): | ||
970 | if accelerator.is_main_process: | ||
971 | if epoch % args.sample_frequency == 0: | ||
972 | checkpointer.save_samples(global_step, args.sample_steps) | ||
973 | |||
974 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | ||
975 | local_progress_bar.reset() | ||
976 | |||
977 | unet.train() | ||
978 | if epoch < args.train_text_encoder_epochs: | ||
979 | text_encoder.train() | ||
980 | elif epoch == args.train_text_encoder_epochs: | ||
981 | text_encoder.requires_grad_(False) | ||
982 | |||
983 | with on_train(): | ||
984 | for step, batch in enumerate(train_dataloader): | ||
985 | with accelerator.accumulate(unet): | ||
986 | loss, acc, bsz = loop(step, batch) | ||
987 | |||
988 | accelerator.backward(loss) | ||
989 | |||
990 | on_before_optimize() | ||
991 | |||
992 | optimizer.step() | ||
993 | if not accelerator.optimizer_step_was_skipped: | ||
994 | lr_scheduler.step() | ||
995 | if args.use_ema: | ||
996 | ema_unet.step(unet.parameters()) | ||
997 | optimizer.zero_grad(set_to_none=True) | ||
998 | |||
999 | avg_loss.update(loss.detach_(), bsz) | ||
1000 | avg_acc.update(acc.detach_(), bsz) | ||
1001 | |||
1002 | # Checks if the accelerator has performed an optimization step behind the scenes | ||
1003 | if accelerator.sync_gradients: | ||
1004 | on_after_optimize(lr_scheduler.get_last_lr()[0]) | ||
1005 | |||
1006 | local_progress_bar.update(1) | ||
1007 | global_progress_bar.update(1) | ||
1008 | |||
1009 | global_step += 1 | ||
1010 | |||
1011 | logs = { | ||
1012 | "train/loss": avg_loss.avg.item(), | ||
1013 | "train/acc": avg_acc.avg.item(), | ||
1014 | "train/cur_loss": loss.item(), | ||
1015 | "train/cur_acc": acc.item(), | ||
1016 | "lr": lr_scheduler.get_last_lr()[0] | ||
1017 | } | ||
1018 | if args.use_ema: | ||
1019 | logs["ema_decay"] = 1 - ema_unet.decay | ||
1020 | |||
1021 | accelerator.log(logs, step=global_step) | ||
1022 | |||
1023 | local_progress_bar.set_postfix(**logs) | ||
1024 | |||
1025 | if global_step >= args.max_train_steps: | ||
1026 | break | ||
1027 | |||
1028 | accelerator.wait_for_everyone() | ||
1029 | |||
1030 | unet.eval() | ||
1031 | text_encoder.eval() | ||
1032 | |||
1033 | cur_loss_val = AverageMeter() | ||
1034 | cur_acc_val = AverageMeter() | ||
1035 | |||
1036 | with torch.inference_mode(): | ||
1037 | with on_eval(): | ||
1038 | for step, batch in enumerate(val_dataloader): | ||
1039 | loss, acc, bsz = loop(step, batch, True) | ||
1040 | |||
1041 | loss = loss.detach_() | ||
1042 | acc = acc.detach_() | ||
1043 | |||
1044 | cur_loss_val.update(loss, bsz) | ||
1045 | cur_acc_val.update(acc, bsz) | ||
1046 | |||
1047 | avg_loss_val.update(loss, bsz) | ||
1048 | avg_acc_val.update(acc, bsz) | ||
1049 | |||
1050 | local_progress_bar.update(1) | ||
1051 | global_progress_bar.update(1) | ||
1052 | |||
1053 | logs = { | ||
1054 | "val/loss": avg_loss_val.avg.item(), | ||
1055 | "val/acc": avg_acc_val.avg.item(), | ||
1056 | "val/cur_loss": loss.item(), | ||
1057 | "val/cur_acc": acc.item(), | ||
1058 | } | ||
1059 | local_progress_bar.set_postfix(**logs) | ||
1060 | |||
1061 | logs["val/cur_loss"] = cur_loss_val.avg.item() | ||
1062 | logs["val/cur_acc"] = cur_acc_val.avg.item() | ||
1063 | |||
1064 | accelerator.log(logs, step=global_step) | ||
1065 | |||
1066 | local_progress_bar.clear() | ||
1067 | global_progress_bar.clear() | ||
1068 | |||
1069 | if accelerator.is_main_process: | ||
1070 | if avg_acc_val.avg.item() > max_acc_val: | ||
1071 | accelerator.print( | ||
1072 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | ||
1073 | max_acc_val = avg_acc_val.avg.item() | ||
1074 | |||
1075 | # Create the pipeline using the trained modules and save it. | ||
1076 | if accelerator.is_main_process: | ||
1077 | print("Finished! Saving final checkpoint and resume state.") | ||
1078 | checkpointer.save_samples(global_step, args.sample_steps) | ||
1079 | checkpointer.save_model() | ||
1080 | accelerator.end_training() | ||
1081 | |||
1082 | except KeyboardInterrupt: | ||
1083 | if accelerator.is_main_process: | ||
1084 | print("Interrupted, saving checkpoint and resume state...") | ||
1085 | checkpointer.save_model() | ||
1086 | accelerator.end_training() | ||
1087 | quit() | ||
1088 | 946 | ||
1089 | 947 | ||
1090 | if __name__ == "__main__": | 948 | if __name__ == "__main__": |