path: root/train_ti.py
author    Volpeon <git@volpeon.ink>  2023-01-13 13:49:35 +0100
committer Volpeon <git@volpeon.ink>  2023-01-13 13:49:35 +0100
commit    7b149930bb53b93db74106ad20a30abf4b114f9b
tree      67c2ccbce2a9838ad8a020ee527b19113e67e30a /train_ti.py
parent    Added TI decay start offset
Removed PromptProcessor, modularized training loop
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  268
1 file changed, 53 insertions(+), 215 deletions(-)
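
Note on the PromptProcessor removal: as the hunks below show, the wrapper from models/clip/prompt.py (which bundled the tokenizer and text encoder into one object) is dropped; VlpnDataModule now receives the tokenizer directly and loss_step receives the text encoder. A rough sketch of that split, for orientation only — the helper name and signature below are illustrative assumptions, not code from this repository:

import torch
from transformers import CLIPTextModel, CLIPTokenizer

def encode_prompt(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, prompts: list[str]) -> torch.Tensor:
    # Tokenization can live in the data pipeline (the datamodule now holds the tokenizer) ...
    input_ids = tokenizer(
        prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids
    # ... while the text-encoder forward pass stays in the loss step, so the trained
    # placeholder embeddings keep receiving gradients.
    return text_encoder(input_ids.to(text_encoder.device))[0]
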
diff --git a/train_ti.py b/train_ti.py
index e18ee38..8c86586 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -21,11 +21,10 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, generate_class_images, get_scheduler
+from training.common import loss_step, train_loop, generate_class_images, get_scheduler
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
 from models.clip.embeddings import patch_managed_embeddings
-from models.clip.prompt import PromptProcessor
 from models.clip.tokenizer import MultiCLIPTokenizer

 logger = get_logger(__name__)
@@ -198,12 +197,6 @@ def parse_args():
         default=100
     )
     parser.add_argument(
-        "--max_train_steps",
-        type=int,
-        default=None,
-        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
-    )
-    parser.add_argument(
         "--gradient_accumulation_steps",
         type=int,
         default=1,
@@ -409,7 +402,7 @@ def parse_args():
     )
     parser.add_argument(
         "--decay_target",
-        default=0.4,
+        default=None,
         type=float,
         help="Embedding decay target."
     )
@@ -668,8 +661,6 @@ def main():
     text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
     text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)

-    prompt_processor = PromptProcessor(tokenizer, text_encoder)
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -722,7 +713,7 @@ def main():
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
-        prompt_processor=prompt_processor,
+        tokenizer=tokenizer,
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
@@ -759,13 +750,7 @@ def main():
             args.sample_steps
         )

-    # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
-    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

     if args.find_lr:
         lr_scheduler = None
@@ -781,7 +766,7 @@ def main():
             annealing_exp=args.lr_annealing_exp,
             cycles=args.lr_cycles,
             warmup_epochs=args.lr_warmup_epochs,
-            max_train_steps=args.max_train_steps,
+            num_train_epochs=args.num_train_epochs,
             num_update_steps_per_epoch=num_update_steps_per_epoch,
             gradient_accumulation_steps=args.gradient_accumulation_steps
         )
@@ -805,15 +790,6 @@ def main():
     else:
         unet.eval()

-    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-
-    num_val_steps_per_epoch = len(val_dataloader)
-    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
-    val_steps = num_val_steps_per_epoch * num_epochs
-
     @contextmanager
     def on_train():
         try:
@@ -842,19 +818,44 @@ def main():
             min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start))))
         )

+        if args.use_ema:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+    def on_log():
+        if args.use_ema:
+            return {"ema_decay": ema_embeddings.decay}
+        return {}
+
     loop = partial(
         loss_step,
         vae,
         noise_scheduler,
         unet,
-        prompt_processor,
+        text_encoder,
         args.num_class_images,
         args.prior_loss_weight,
         args.seed,
     )

-    # We need to initialize the trackers we use, and also store our configuration.
-    # The trackers initializes automatically on the main process.
+    checkpointer = Checkpointer(
+        weight_dtype=weight_dtype,
+        datamodule=datamodule,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        ema_embeddings=ema_embeddings,
+        scheduler=checkpoint_scheduler,
+        placeholder_token=args.placeholder_token,
+        new_ids=new_ids,
+        output_dir=basepath,
+        sample_image_size=args.sample_image_size,
+        sample_batch_size=args.sample_batch_size,
+        sample_batches=args.sample_batches,
+        seed=args.seed
+    )
+
     if accelerator.is_main_process:
         config = vars(args).copy()
         config["initializer_token"] = " ".join(config["initializer_token"])
@@ -882,190 +883,27 @@ def main():

         plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()
-
-    quit()
-
-    # Train!
-    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-
-    logger.info("***** Running training *****")
-    logger.info(f" Num Epochs = {num_epochs}")
-    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
-    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
-    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
-    logger.info(f" Total optimization steps = {args.max_train_steps}")
-    # Only show the progress bar once on each machine.
-
-    global_step = 0
-
-    avg_loss = AverageMeter()
-    avg_acc = AverageMeter()
-
-    avg_loss_val = AverageMeter()
-    avg_acc_val = AverageMeter()
-
-    max_acc_val = 0.0
-
-    checkpointer = Checkpointer(
-        weight_dtype=weight_dtype,
-        datamodule=datamodule,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        ema_embeddings=ema_embeddings,
-        scheduler=checkpoint_scheduler,
-        placeholder_token=args.placeholder_token,
-        new_ids=new_ids,
-        output_dir=basepath,
-        sample_image_size=args.sample_image_size,
-        sample_batch_size=args.sample_batch_size,
-        sample_batches=args.sample_batches,
-        seed=args.seed
-    )
-
-    local_progress_bar = tqdm(
-        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
-        disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
-    )
-    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
-
-    global_progress_bar = tqdm(
-        range(args.max_train_steps + val_steps),
-        disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
-    )
-    global_progress_bar.set_description("Total progress")
-
-    try:
-        for epoch in range(num_epochs):
-            if accelerator.is_main_process:
-                if epoch % args.sample_frequency == 0:
-                    checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)
-
-            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
-            local_progress_bar.reset()
-
-            text_encoder.train()
-
-            with on_train():
-                for step, batch in enumerate(train_dataloader):
-                    with accelerator.accumulate(text_encoder):
-                        loss, acc, bsz = loop(step, batch)
-
-                        accelerator.backward(loss)
-
-                        optimizer.step()
-                        lr_scheduler.step()
-                        optimizer.zero_grad(set_to_none=True)
-
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
-
-                    # Checks if the accelerator has performed an optimization step behind the scenes
-                    if accelerator.sync_gradients:
-                        on_after_optimize(lr_scheduler.get_last_lr()[0])
-
-                        if args.use_ema:
-                            ema_embeddings.step(
-                                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
-
-                        global_step += 1
-
-                    logs = {
-                        "train/loss": avg_loss.avg.item(),
-                        "train/acc": avg_acc.avg.item(),
-                        "train/cur_loss": loss.item(),
-                        "train/cur_acc": acc.item(),
-                        "lr": lr_scheduler.get_last_lr()[0],
-                    }
-                    if args.use_ema:
-                        logs["ema_decay"] = ema_embeddings.decay
-
-                    accelerator.log(logs, step=global_step)
-
-                    local_progress_bar.set_postfix(**logs)
-
-                    if global_step >= args.max_train_steps:
-                        break
-
-            accelerator.wait_for_everyone()
-
-            text_encoder.eval()
-
-            cur_loss_val = AverageMeter()
-            cur_acc_val = AverageMeter()
-
-            with torch.inference_mode():
-                with on_eval():
-                    for step, batch in enumerate(val_dataloader):
-                        loss, acc, bsz = loop(step, batch, True)
-
-                        loss = loss.detach_()
-                        acc = acc.detach_()
-
-                        cur_loss_val.update(loss, bsz)
-                        cur_acc_val.update(acc, bsz)
-
-                        avg_loss_val.update(loss, bsz)
-                        avg_acc_val.update(acc, bsz)
-
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
-
-                        logs = {
-                            "val/loss": avg_loss_val.avg.item(),
-                            "val/acc": avg_acc_val.avg.item(),
-                            "val/cur_loss": loss.item(),
-                            "val/cur_acc": acc.item(),
-                        }
-                        local_progress_bar.set_postfix(**logs)
-
-            logs["val/cur_loss"] = cur_loss_val.avg.item()
-            logs["val/cur_acc"] = cur_acc_val.avg.item()
-
-            accelerator.log(logs, step=global_step)
-
-            local_progress_bar.clear()
-            global_progress_bar.clear()
-
-            if accelerator.is_main_process:
-                if avg_acc_val.avg.item() > max_acc_val:
-                    accelerator.print(
-                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
-                    max_acc_val = avg_acc_val.avg.item()
-
-                if (epoch + 1) % args.checkpoint_frequency == 0:
-                    checkpointer.checkpoint(global_step + global_step_offset, "training")
-                    save_args(basepath, args, {
-                        "global_step": global_step + global_step_offset
-                    })
-
-        # Create the pipeline using using the trained modules and save it.
-        if accelerator.is_main_process:
-            print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.checkpoint(global_step + global_step_offset, "end")
-            checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)
-            save_args(basepath, args, {
-                "global_step": global_step + global_step_offset
-            })
-            accelerator.end_training()
-
-    except KeyboardInterrupt:
-        if accelerator.is_main_process:
-            print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.checkpoint(global_step + global_step_offset, "end")
-            save_args(basepath, args, {
-                "global_step": global_step + global_step_offset
-            })
-            accelerator.end_training()
-            quit()
+    else:
+        train_loop(
+            accelerator=accelerator,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            model=text_encoder,
+            checkpointer=checkpointer,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loop,
+            sample_frequency=args.sample_frequency,
+            sample_steps=args.sample_steps,
+            checkpoint_frequency=args.checkpoint_frequency,
+            global_step_offset=global_step_offset,
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            num_epochs=args.num_train_epochs,
+            on_log=on_log,
+            on_train=on_train,
+            on_after_optimize=on_after_optimize,
+            on_eval=on_eval
+        )


 if __name__ == "__main__":
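
For reference, the modularized loop that replaces the inline code deleted above lives in training/common.py, which is not part of this diff. Below is a minimal sketch of how its callback hooks plausibly fit together, reconstructed from the deleted loop and the keyword arguments at the call site; anything beyond those names is an assumption, not the repository's actual implementation:

import torch

def train_loop(accelerator, optimizer, lr_scheduler, model, checkpointer,
               train_dataloader, val_dataloader, loss_step, sample_frequency,
               sample_steps, checkpoint_frequency, global_step_offset,
               gradient_accumulation_steps, num_epochs, on_log, on_train,
               on_after_optimize, on_eval):
    # gradient_accumulation_steps is configured on the Accelerator itself;
    # it is kept here only for parity with the call site in train_ti.py.
    global_step = 0
    for epoch in range(num_epochs):
        if accelerator.is_main_process and epoch % sample_frequency == 0:
            checkpointer.save_samples(global_step + global_step_offset, sample_steps)

        model.train()
        with on_train():  # context hook supplied by the training script
            for step, batch in enumerate(train_dataloader):
                with accelerator.accumulate(model):
                    loss, acc, bsz = loss_step(step, batch)
                    accelerator.backward(loss)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
                if accelerator.sync_gradients:
                    on_after_optimize(lr_scheduler.get_last_lr()[0])
                    global_step += 1
                logs = {"train/loss": loss.item(), "train/acc": acc.item(),
                        "lr": lr_scheduler.get_last_lr()[0]}
                logs.update(on_log())  # e.g. {"ema_decay": ...} when EMA is enabled
                accelerator.log(logs, step=global_step)

        model.eval()
        with torch.inference_mode(), on_eval():  # eval hook, e.g. to swap in EMA weights
            for step, batch in enumerate(val_dataloader):
                loss_step(step, batch, True)

        if accelerator.is_main_process and (epoch + 1) % checkpoint_frequency == 0:
            checkpointer.checkpoint(global_step + global_step_offset, "training")

    if accelerator.is_main_process:
        checkpointer.checkpoint(global_step + global_step_offset, "end")
    accelerator.end_training()

With this factoring, train_ti.py only supplies the model, the loss closure, and the hooks; the epoch/step bookkeeping, progress reporting, and checkpoint cadence live in one shared place that other training scripts in the repository can reuse.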