author    Volpeon <git@volpeon.ink>  2023-01-13 18:59:26 +0100
committer Volpeon <git@volpeon.ink>  2023-01-13 18:59:26 +0100
commit    127ec21e5bd4e7df21e36c561d070f8b9a0e19f5 (patch)
tree      61cb98adbf33ed08506601f8b70f1b62bc42c4ee
parent    Simplified step calculations (diff)
More modularization
-rw-r--r--  models/clip/embeddings.py         6
-rw-r--r--  train_dreambooth.py             272
-rw-r--r--  train_ti.py                     479
-rw-r--r--  training/common.py              260
-rw-r--r--  training/lr.py                   14
-rw-r--r--  training/modules/dreambooth.py    0
-rw-r--r--  training/modules/lora.py          0
-rw-r--r--  training/modules/ti.py          284
-rw-r--r--  training/util.py                 15
9 files changed, 677 insertions, 653 deletions
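
Note (not part of the patch): this commit moves the textual-inversion training path behind two new entry points, train_setup() in training/common.py and train_ti() in training/modules/ti.py, while train_ti.py keeps only argument parsing and a data filter. A minimal sketch of how the two calls compose, mirroring the new train_ti.py main() below; the concrete values are illustrative examples, only the keyword names come from the diff:

    # Hypothetical driver; paths, token names and hyperparameters are examples only.
    from training.common import train_setup
    from training.modules.ti import train_ti

    setup = train_setup(
        output_dir="output/ti",                          # example output root
        project="my-token",                              # example project name
        pretrained_model_name_or_path="models/sd-2-1",   # example checkpoint path
        learning_rate=1e-4,
        data_file="data/train.csv",                      # example dataset file
        placeholder_token=["<my-token>"],
        initializer_token=["person"],
        num_vectors=1,
    )

    train_ti(
        setup=setup,
        num_train_epochs=100,
        emb_decay_target=0.4,      # new default, see models/clip/embeddings.py
        emb_decay_factor=1.0,
        emb_decay_start=1e-4,
        sample_frequency=10,
        checkpoint_frequency=50,
    )
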
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 761efbc..9a23a2a 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -40,8 +40,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
 
-        self.decay_target = self.token_embedding.weight[:, :].norm(dim=-1, keepdim=True).median().item()
-
         self.temp_token_embedding = nn.Embedding(
             self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
@@ -101,9 +99,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return embeds
 
-    def normalize(self, target: Optional[float] = None, lambda_: float = 1.0):
-        if target is None:
-            target = self.decay_target
+    def normalize(self, target: float = 0.4, lambda_: float = 1.0):
         w = self.temp_token_embedding.weight
         pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
         w[self.temp_token_ids] = F.normalize(
diff --git a/train_dreambooth.py b/train_dreambooth.py
index fbbe6c2..c892ebf 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -1,6 +1,5 @@
 import argparse
 import itertools
-import math
 import datetime
 import logging
 from pathlib import Path
@@ -16,16 +15,15 @@ from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 import matplotlib.pyplot as plt
 from diffusers.training_utils import EMAModel
-from tqdm.auto import tqdm
 from transformers import CLIPTextModel
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, generate_class_images, get_scheduler
+from training.common import loss_step, train_loop, generate_class_images, get_scheduler
 from training.lr import LRFinder
-from training.util import AverageMeter, CheckpointerBase, save_args
+from training.util import CheckpointerBase, save_args
 from models.clip.embeddings import patch_managed_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 
@@ -292,7 +290,7 @@ def parse_args():
     parser.add_argument(
         "--lr_min_lr",
         type=float,
-        default=None,
+        default=0.04,
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
@@ -787,14 +785,6 @@ def main():
             args.sample_steps
         )
 
-    # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
-    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
-
     if args.find_lr:
         lr_scheduler = None
     else:
@@ -802,15 +792,14 @@ def main():
             args.lr_scheduler,
             optimizer=optimizer,
             min_lr=args.lr_min_lr,
-            lr=args.learning_rate,
             warmup_func=args.lr_warmup_func,
             annealing_func=args.lr_annealing_func,
             warmup_exp=args.lr_warmup_exp,
             annealing_exp=args.lr_annealing_exp,
             cycles=args.lr_cycles,
+            train_epochs=args.num_train_epochs,
             warmup_epochs=args.lr_warmup_epochs,
-            max_train_steps=args.max_train_steps,
-            num_update_steps_per_epoch=num_update_steps_per_epoch,
+            num_training_steps_per_epoch=len(train_dataloader),
             gradient_accumulation_steps=args.gradient_accumulation_steps
         )
 
@@ -827,19 +816,16 @@ def main():
     if args.use_ema:
         ema_unet.to(accelerator.device)
 
-    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-
-    num_val_steps_per_epoch = len(val_dataloader)
-    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
-    val_steps = num_val_steps_per_epoch * num_epochs
-
     @contextmanager
-    def on_train():
+    def on_train(epoch: int):
         try:
             tokenizer.train()
+
+            if epoch < args.train_text_encoder_epochs:
+                text_encoder.train()
+            elif epoch == args.train_text_encoder_epochs:
+                text_encoder.requires_grad_(False)
+
             yield
         finally:
             pass
@@ -848,6 +834,7 @@ def main():
     def on_eval():
         try:
             tokenizer.eval()
+            text_encoder.eval()
 
             ema_context = ema_unet.apply_temporary(unet.parameters()) if args.use_ema else nullcontext()
 
@@ -856,7 +843,7 @@ def main():
         finally:
             pass
 
-    def on_before_optimize():
+    def on_before_optimize(epoch: int):
         if accelerator.sync_gradients:
             params_to_clip = [unet.parameters()]
             if args.train_text_encoder and epoch < args.train_text_encoder_epochs:
@@ -866,9 +853,17 @@ def main():
     @torch.no_grad()
     def on_after_optimize(lr: float):
         if not args.train_text_encoder:
-            text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
+            text_encoder.text_model.embeddings.normalize(
+                args.decay_target,
+                min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start))))
+            )
+
+    def on_log():
+        if args.use_ema:
+            return {"ema_decay": ema_unet.decay}
+        return {}
 
-    loop = partial(
+    loss_step_ = partial(
         loss_step,
         vae,
         noise_scheduler,
@@ -879,8 +874,25 @@ def main():
         args.seed,
     )
 
-    # We need to initialize the trackers we use, and also store our configuration.
-    # The trackers initializes automatically on the main process.
+    checkpointer = Checkpointer(
+        weight_dtype=weight_dtype,
+        datamodule=datamodule,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        ema_unet=ema_unet,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        scheduler=checkpoint_scheduler,
+        output_dir=basepath,
+        placeholder_token=args.placeholder_token,
+        placeholder_token_id=placeholder_token_id,
+        sample_image_size=args.sample_image_size,
+        sample_batch_size=args.sample_batch_size,
+        sample_batches=args.sample_batches,
+        seed=args.seed
+    )
+
     if accelerator.is_main_process:
         config = vars(args).copy()
         config["initializer_token"] = " ".join(config["initializer_token"])
@@ -898,9 +910,9 @@ def main():
             optimizer,
             train_dataloader,
             val_dataloader,
-            loop,
-            on_train=tokenizer.train,
-            on_eval=tokenizer.eval,
+            loss_step_,
+            on_train=on_train,
+            on_eval=on_eval,
             on_before_optimize=on_before_optimize,
             on_after_optimize=on_after_optimize,
         )
@@ -909,182 +921,28 @@ def main():
         plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()
 
-        quit()
-
-    # Train!
-    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-
-    logger.info("***** Running training *****")
-    logger.info(f"  Num Epochs = {num_epochs}")
-    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
-    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
-    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
-    logger.info(f"  Total optimization steps = {args.max_train_steps}")
-    # Only show the progress bar once on each machine.
-
-    global_step = 0
-
-    avg_loss = AverageMeter()
-    avg_acc = AverageMeter()
-
-    avg_loss_val = AverageMeter()
-    avg_acc_val = AverageMeter()
-
-    max_acc_val = 0.0
-
-    checkpointer = Checkpointer(
-        weight_dtype=weight_dtype,
-        datamodule=datamodule,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        ema_unet=ema_unet,
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        scheduler=checkpoint_scheduler,
-        output_dir=basepath,
-        placeholder_token=args.placeholder_token,
-        placeholder_token_id=placeholder_token_id,
-        sample_image_size=args.sample_image_size,
-        sample_batch_size=args.sample_batch_size,
-        sample_batches=args.sample_batches,
-        seed=args.seed
-    )
-
-    local_progress_bar = tqdm(
-        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
-        disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
-    )
-    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
-
-    global_progress_bar = tqdm(
-        range(args.max_train_steps + val_steps),
-        disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
-    )
-    global_progress_bar.set_description("Total progress")
-
-    try:
-        for epoch in range(num_epochs):
-            if accelerator.is_main_process:
-                if epoch % args.sample_frequency == 0:
-                    checkpointer.save_samples(global_step, args.sample_steps)
-
-            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
-            local_progress_bar.reset()
-
-            unet.train()
-            if epoch < args.train_text_encoder_epochs:
-                text_encoder.train()
-            elif epoch == args.train_text_encoder_epochs:
-                text_encoder.requires_grad_(False)
-
-            with on_train():
-                for step, batch in enumerate(train_dataloader):
-                    with accelerator.accumulate(unet):
-                        loss, acc, bsz = loop(step, batch)
-
-                        accelerator.backward(loss)
-
-                        on_before_optimize()
-
-                        optimizer.step()
-                        if not accelerator.optimizer_step_was_skipped:
-                            lr_scheduler.step()
-                            if args.use_ema:
-                                ema_unet.step(unet.parameters())
-                        optimizer.zero_grad(set_to_none=True)
-
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
-
-                    # Checks if the accelerator has performed an optimization step behind the scenes
-                    if accelerator.sync_gradients:
-                        on_after_optimize(lr_scheduler.get_last_lr()[0])
-
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
-
-                        global_step += 1
-
-                        logs = {
-                            "train/loss": avg_loss.avg.item(),
-                            "train/acc": avg_acc.avg.item(),
-                            "train/cur_loss": loss.item(),
-                            "train/cur_acc": acc.item(),
-                            "lr": lr_scheduler.get_last_lr()[0]
-                        }
-                        if args.use_ema:
-                            logs["ema_decay"] = 1 - ema_unet.decay
-
-                        accelerator.log(logs, step=global_step)
-
-                        local_progress_bar.set_postfix(**logs)
-
-                    if global_step >= args.max_train_steps:
-                        break
-
-            accelerator.wait_for_everyone()
-
-            unet.eval()
-            text_encoder.eval()
-
-            cur_loss_val = AverageMeter()
-            cur_acc_val = AverageMeter()
-
-            with torch.inference_mode():
-                with on_eval():
-                    for step, batch in enumerate(val_dataloader):
-                        loss, acc, bsz = loop(step, batch, True)
-
-                        loss = loss.detach_()
-                        acc = acc.detach_()
-
-                        cur_loss_val.update(loss, bsz)
-                        cur_acc_val.update(acc, bsz)
-
-                        avg_loss_val.update(loss, bsz)
-                        avg_acc_val.update(acc, bsz)
-
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
-
-                        logs = {
-                            "val/loss": avg_loss_val.avg.item(),
-                            "val/acc": avg_acc_val.avg.item(),
-                            "val/cur_loss": loss.item(),
-                            "val/cur_acc": acc.item(),
-                        }
-                        local_progress_bar.set_postfix(**logs)
-
-                    logs["val/cur_loss"] = cur_loss_val.avg.item()
-                    logs["val/cur_acc"] = cur_acc_val.avg.item()
-
-                    accelerator.log(logs, step=global_step)
-
-            local_progress_bar.clear()
-            global_progress_bar.clear()
-
-            if accelerator.is_main_process:
-                if avg_acc_val.avg.item() > max_acc_val:
-                    accelerator.print(
-                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                    max_acc_val = avg_acc_val.avg.item()
-
-        # Create the pipeline using using the trained modules and save it.
-        if accelerator.is_main_process:
-            print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.save_samples(global_step, args.sample_steps)
-            checkpointer.save_model()
-            accelerator.end_training()
-
-    except KeyboardInterrupt:
-        if accelerator.is_main_process:
-            print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.save_model()
-            accelerator.end_training()
-        quit()
+        return
+
+    train_loop(
+        accelerator=accelerator,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        model=unet,
+        checkpointer=checkpointer,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        loss_step=loss_step_,
+        sample_frequency=args.sample_frequency,
+        sample_steps=args.sample_steps,
+        checkpoint_frequency=args.checkpoint_frequency,
+        global_step_offset=0,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        num_epochs=args.num_train_epochs,
+        on_log=on_log,
+        on_train=on_train,
+        on_after_optimize=on_after_optimize,
+        on_eval=on_eval
+    )
 
 
 if __name__ == "__main__":
diff --git a/train_ti.py b/train_ti.py
index 3f4e739..3a55f40 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,31 +1,15 @@
 import argparse
-import math
-import datetime
-import logging
-from functools import partial
-from pathlib import Path
-from contextlib import contextmanager, nullcontext
 
 import torch
 import torch.utils.checkpoint
 
-from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
-import matplotlib.pyplot as plt
-from tqdm.auto import tqdm
-from transformers import CLIPTextModel
-from slugify import slugify
-
-from util import load_config, load_embeddings_from_dir
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, train_loop, generate_class_images, get_scheduler
-from training.lr import LRFinder
-from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
-from models.clip.embeddings import patch_managed_embeddings
-from models.clip.tokenizer import MultiCLIPTokenizer
+
+from util import load_config
+from data.csv import VlpnDataItem
+from training.common import train_setup
+from training.modules.ti import train_ti
+from training.util import save_args
 
 logger = get_logger(__name__)
 
@@ -271,7 +255,7 @@ def parse_args():
     parser.add_argument(
         "--lr_min_lr",
         type=float,
-        default=None,
+        default=0.04,
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
@@ -401,19 +385,19 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--decay_target",
-        default=None,
+        "--emb_decay_target",
+        default=0.4,
         type=float,
         help="Embedding decay target."
     )
     parser.add_argument(
-        "--decay_factor",
+        "--emb_decay_factor",
         default=1,
         type=float,
         help="Embedding decay factor."
     )
     parser.add_argument(
-        "--decay_start",
+        "--emb_decay_start",
         default=1e-4,
         type=float,
         help="Embedding decay start offset."
@@ -491,213 +475,10 @@ def parse_args():
     return args
 
 
-class Checkpointer(CheckpointerBase):
-    def __init__(
-        self,
-        weight_dtype,
-        accelerator: Accelerator,
-        vae: AutoencoderKL,
-        unet: UNet2DConditionModel,
-        tokenizer: MultiCLIPTokenizer,
-        text_encoder: CLIPTextModel,
-        ema_embeddings: EMAModel,
-        scheduler,
-        placeholder_token,
-        new_ids,
-        *args,
-        **kwargs
-    ):
-        super().__init__(*args, **kwargs)
-
-        self.weight_dtype = weight_dtype
-        self.accelerator = accelerator
-        self.vae = vae
-        self.unet = unet
-        self.tokenizer = tokenizer
-        self.text_encoder = text_encoder
-        self.ema_embeddings = ema_embeddings
-        self.scheduler = scheduler
-        self.placeholder_token = placeholder_token
-        self.new_ids = new_ids
-
-    @torch.no_grad()
-    def checkpoint(self, step, postfix):
-        print("Saving checkpoint for step %d..." % step)
-
-        checkpoints_path = self.output_dir.joinpath("checkpoints")
-        checkpoints_path.mkdir(parents=True, exist_ok=True)
-
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        ema_context = self.ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
-
-        with ema_context:
-            for (token, ids) in zip(self.placeholder_token, self.new_ids):
-                text_encoder.text_model.embeddings.save_embed(
-                    ids,
-                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
-                )
-
-        del text_encoder
-
-    @torch.no_grad()
-    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        ema_context = self.ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
-
-        with ema_context:
-            orig_dtype = text_encoder.dtype
-            text_encoder.to(dtype=self.weight_dtype)
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=self.vae,
-                unet=self.unet,
-                tokenizer=self.tokenizer,
-                scheduler=self.scheduler,
-            ).to(self.accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
-
-            text_encoder.to(dtype=orig_dtype)
-
-        del text_encoder
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-
 def main():
     args = parse_args()
 
-    global_step_offset = args.global_step
-    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
-    basepath.mkdir(parents=True, exist_ok=True)
-
-    accelerator = Accelerator(
-        log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{basepath}",
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        mixed_precision=args.mixed_precision
-    )
-
-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
-
-    args.seed = args.seed or (torch.random.seed() >> 32)
-    set_seed(args.seed)
-
-    save_args(basepath, args)
-
-    # Load the tokenizer and add the placeholder token as a additional special token
-    if args.tokenizer_name:
-        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
-    elif args.pretrained_model_name_or_path:
-        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
-    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
-    tokenizer.set_dropout(args.vector_dropout)
-
-    # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
-    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder='scheduler')
-
-    vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.set_use_memory_efficient_attention_xformers(True)
-
-    if args.gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
-        text_encoder.gradient_checkpointing_enable()
-
-    embeddings = patch_managed_embeddings(text_encoder)
-    ema_embeddings = None
-
-    if args.embeddings_dir is not None:
-        embeddings_dir = Path(args.embeddings_dir)
-        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
-            raise ValueError("--embeddings_dir must point to an existing directory")
-
-        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
-
-    # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = [
-        tokenizer.encode(token, add_special_tokens=False)
-        for token in args.initializer_token
-    ]
-
-    new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
-    embeddings.resize(len(tokenizer))
-
-    for (new_id, init_ids) in zip(new_ids, initializer_token_ids):
-        embeddings.add_embed(new_id, init_ids)
-
-    init_ratios = [f"{len(init_ids)} / {len(new_id)}" for new_id, init_ids in zip(new_ids, initializer_token_ids)]
-
-    print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}")
-
-    if args.use_ema:
-        ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            inv_gamma=args.ema_inv_gamma,
-            power=args.ema_power,
-            max_value=args.ema_max_decay,
-        )
-
-    vae.requires_grad_(False)
-    unet.requires_grad_(False)
-
-    text_encoder.text_model.encoder.requires_grad_(False)
-    text_encoder.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
-    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
-
-    if args.scale_lr:
-        args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps *
-            args.train_batch_size * accelerator.num_processes
-        )
-
-    if args.find_lr:
-        args.learning_rate = 1e-5
-
-    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
-    if args.use_8bit_adam:
-        try:
-            import bitsandbytes as bnb
-        except ImportError:
-            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
-
-        optimizer_class = bnb.optim.AdamW8bit
-    else:
-        optimizer_class = torch.optim.AdamW
-
-    # Initialize the optimizer
-    optimizer = optimizer_class(
-        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),  # only optimize the embeddings
-        lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
-    )
-
-    weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
-    def keyword_filter(item: VlpnDataItem):
+    def data_filter(item: VlpnDataItem):
         cond1 = any(
             keyword in part
             for keyword in args.placeholder_token
@@ -710,198 +491,78 @@ def main():
         )
         return cond1 and cond3 and cond4
 
-    datamodule = VlpnDataModule(
+    setup = train_setup(
+        output_dir=args.output_dir,
+        project=args.project,
+        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
+        learning_rate=args.learning_rate,
         data_file=args.train_data_file,
-        batch_size=args.train_batch_size,
-        tokenizer=tokenizer,
-        class_subdir=args.class_image_dir,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision,
+        seed=args.seed,
+        vector_shuffle=args.vector_shuffle,
+        vector_dropout=args.vector_dropout,
+        gradient_checkpointing=args.gradient_checkpointing,
+        embeddings_dir=args.embeddings_dir,
+        placeholder_token=args.placeholder_token,
+        initializer_token=args.initializer_token,
+        num_vectors=args.num_vectors,
+        scale_lr=args.scale_lr,
+        use_8bit_adam=args.use_8bit_adam,
+        train_batch_size=args.train_batch_size,
+        class_image_dir=args.class_image_dir,
         num_class_images=args.num_class_images,
-        size=args.resolution,
+        resolution=args.resolution,
         num_buckets=args.num_buckets,
         progressive_buckets=args.progressive_buckets,
         bucket_step_size=args.bucket_step_size,
         bucket_max_pixels=args.bucket_max_pixels,
-        dropout=args.tag_dropout,
-        shuffle=not args.no_tag_shuffle,
-        template_key=args.train_data_template,
+        tag_dropout=args.tag_dropout,
+        tag_shuffle=not args.no_tag_shuffle,
+        data_template=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
-        num_workers=args.dataloader_num_workers,
-        seed=args.seed,
-        filter=keyword_filter,
-        dtype=weight_dtype
-    )
-    datamodule.setup()
-
-    train_dataloader = datamodule.train_dataloader
-    val_dataloader = datamodule.val_dataloader
-
-    if args.num_class_images != 0:
-        generate_class_images(
-            accelerator,
-            text_encoder,
-            vae,
-            unet,
-            tokenizer,
-            checkpoint_scheduler,
-            datamodule.data_train,
-            args.sample_batch_size,
-            args.sample_image_size,
-            args.sample_steps
-        )
-
-    if args.find_lr:
-        lr_scheduler = None
-    else:
-        lr_scheduler = get_scheduler(
-            args.lr_scheduler,
-            optimizer=optimizer,
-            min_lr=args.lr_min_lr,
-            lr=args.learning_rate,
-            warmup_func=args.lr_warmup_func,
-            annealing_func=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            cycles=args.lr_cycles,
-            train_epochs=args.num_train_epochs,
-            warmup_epochs=args.lr_warmup_epochs,
-            num_training_steps_per_epoch=len(train_dataloader),
-            gradient_accumulation_steps=args.gradient_accumulation_steps
-        )
-
-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+        data_filter=data_filter,
+        sample_image_size=args.sample_image_size,
+        sample_batch_size=args.sample_batch_size,
+        sample_steps=args.sample_steps,
     )
 
-    # Move vae and unet to device
-    vae.to(accelerator.device, dtype=weight_dtype)
-    unet.to(accelerator.device, dtype=weight_dtype)
-
-    if args.use_ema:
-        ema_embeddings.to(accelerator.device)
+    save_args(setup.output_dir, args)
 
-    # Keep vae and unet in eval mode as we don't train these
-    vae.eval()
-
-    if args.gradient_checkpointing:
-        unet.train()
-    else:
-        unet.eval()
-
-    @contextmanager
-    def on_train():
-        try:
-            tokenizer.train()
-            yield
-        finally:
-            pass
-
-    @contextmanager
-    def on_eval():
-        try:
-            tokenizer.eval()
-
-            ema_context = ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext()
-
-            with ema_context:
-                yield
-        finally:
-            pass
-
-    @torch.no_grad()
-    def on_after_optimize(lr: float):
-        text_encoder.text_model.embeddings.normalize(
-            args.decay_target,
-            min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start))))
-        )
-
-        if args.use_ema:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-    def on_log():
-        if args.use_ema:
-            return {"ema_decay": ema_embeddings.decay}
-        return {}
-
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        args.num_class_images != 0,
-        args.prior_loss_weight,
-        args.seed,
-    )
-
-    checkpointer = Checkpointer(
-        weight_dtype=weight_dtype,
-        datamodule=datamodule,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        ema_embeddings=ema_embeddings,
-        scheduler=checkpoint_scheduler,
-        placeholder_token=args.placeholder_token,
-        new_ids=new_ids,
-        output_dir=basepath,
+    train_ti(
+        setup=setup,
+        num_train_epochs=args.num_train_epochs,
+        num_class_images=args.num_class_images,
+        prior_loss_weight=args.prior_loss_weight,
+        use_ema=args.use_ema,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
+        adam_beta1=args.adam_beta1,
+        adam_beta2=args.adam_beta2,
+        adam_weight_decay=args.adam_weight_decay,
+        adam_epsilon=args.adam_epsilon,
+        adam_amsgrad=args.adam_amsgrad,
+        lr_scheduler=args.lr_scheduler,
+        lr_min_lr=args.lr_min_lr,
+        lr_warmup_func=args.lr_warmup_func,
+        lr_annealing_func=args.lr_annealing_func,
+        lr_warmup_exp=args.lr_warmup_exp,
+        lr_annealing_exp=args.lr_annealing_exp,
+        lr_cycles=args.lr_cycles,
+        lr_warmup_epochs=args.lr_warmup_epochs,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay_factor=args.emb_decay_factor,
+        emb_decay_start=args.emb_decay_start,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
-        seed=args.seed
-    )
-
-    if accelerator.is_main_process:
-        config = vars(args).copy()
-        config["initializer_token"] = " ".join(config["initializer_token"])
-        config["placeholder_token"] = " ".join(config["placeholder_token"])
-        config["num_vectors"] = " ".join([str(n) for n in config["num_vectors"]])
-        if config["collection"] is not None:
-            config["collection"] = " ".join(config["collection"])
-        if config["exclude_collections"] is not None:
-            config["exclude_collections"] = " ".join(config["exclude_collections"])
-        accelerator.init_trackers("textual_inversion", config=config)
-
-    if args.find_lr:
-        lr_finder = LRFinder(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            model=text_encoder,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            on_train=on_train,
-            on_eval=on_eval,
-            on_after_optimize=on_after_optimize,
-        )
-        lr_finder.run(num_epochs=100, end_lr=1e3)
-
-        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
-        plt.close()
-    else:
-        train_loop(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            model=text_encoder,
-            checkpointer=checkpointer,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            sample_frequency=args.sample_frequency,
-            sample_steps=args.sample_steps,
-            checkpoint_frequency=args.checkpoint_frequency,
-            global_step_offset=global_step_offset,
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            num_epochs=args.num_train_epochs,
-            on_log=on_log,
-            on_train=on_train,
-            on_after_optimize=on_after_optimize,
-            on_eval=on_eval
-        )
+        sample_frequency=args.sample_frequency,
+        sample_steps=args.sample_steps,
+        checkpoint_frequency=args.checkpoint_frequency,
+        global_step_offset=args.global_step,
+    )
 
 
 if __name__ == "__main__":
diff --git a/training/common.py b/training/common.py
index 180396e..73ce814 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,46 +1,77 @@
 import math
+from pathlib import Path
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union
+from typing import Callable, Any, Tuple, Union, Literal, Optional, NamedTuple
+import datetime
+import logging
 
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
-from transformers import CLIPTokenizer, CLIPTextModel
-from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from accelerate.utils import LoggerType, set_seed
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
 from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
 
 from tqdm.auto import tqdm
+from slugify import slugify
 
+from data.csv import VlpnDataModule, VlpnDataItem
+from util import load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from models.clip.embeddings import patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
 from training.optimization import get_one_cycle_schedule
 from training.util import AverageMeter, CheckpointerBase
 
 
+class TrainingSetup(NamedTuple):
+    accelerator: Accelerator
+    tokenizer: MultiCLIPTokenizer
+    text_encoder: CLIPTextModel
+    vae: AutoencoderKL
+    unet: UNet2DConditionModel
+    noise_scheduler: DDPMScheduler
+    checkpoint_scheduler: DPMSolverMultistepScheduler
+    optimizer_class: Callable
+    learning_rate: float
+    weight_dtype: torch.dtype
+    output_dir: Path
+    seed: int
+    train_dataloader: DataLoader
+    val_dataloader: DataLoader
+    placeholder_token: list[str]
+    placeholder_token_ids: list[list[int]]
+
+
 def noop(*args, **kwards):
     pass
 
 
+def noop_ctx(*args, **kwards):
+    return nullcontext()
+
+
 def noop_on_log():
     return {}
 
 
 def get_scheduler(
     id: str,
-    min_lr: float,
-    lr: float,
-    warmup_func: str,
-    annealing_func: str,
-    warmup_exp: int,
-    annealing_exp: int,
-    cycles: int,
-    train_epochs: int,
-    warmup_epochs: int,
     optimizer: torch.optim.Optimizer,
     num_training_steps_per_epoch: int,
     gradient_accumulation_steps: int,
+    min_lr: float = 0.04,
+    warmup_func: str = "cos",
+    annealing_func: str = "cos",
+    warmup_exp: int = 1,
+    annealing_exp: int = 1,
+    cycles: int = 1,
+    train_epochs: int = 100,
+    warmup_epochs: int = 10,
 ):
     num_training_steps_per_epoch = math.ceil(
         num_training_steps_per_epoch / gradient_accumulation_steps
@@ -49,8 +80,6 @@ def get_scheduler(
     num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
 
     if id == "one_cycle":
-        min_lr = 0.04 if min_lr is None else min_lr / lr
-
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=num_training_steps,
@@ -133,6 +162,196 @@ def generate_class_images(
         torch.cuda.empty_cache()
 
 
+def train_setup(
+    output_dir: str,
+    project: str,
+    pretrained_model_name_or_path: str,
+    learning_rate: float,
+    data_file: str,
+    gradient_accumulation_steps: int = 1,
+    mixed_precision: Literal["no", "fp16", "bf16"] = "no",
+    seed: Optional[int] = None,
+    vector_shuffle: Union[bool, Literal["all", "trailing", "leading", "between", "off"]] = "auto",
+    vector_dropout: float = 0.1,
+    gradient_checkpointing: bool = True,
+    embeddings_dir: Optional[str] = None,
+    placeholder_token: list[str] = [],
+    initializer_token: list[str] = [],
+    num_vectors: int = 1,
+    scale_lr: bool = False,
+    use_8bit_adam: bool = False,
+    train_batch_size: int = 1,
+    class_image_dir: Optional[str] = None,
+    num_class_images: int = 0,
+    resolution: int = 768,
+    num_buckets: int = 0,
+    progressive_buckets: bool = False,
+    bucket_step_size: int = 64,
+    bucket_max_pixels: Optional[int] = None,
+    tag_dropout: float = 0.1,
+    tag_shuffle: bool = True,
+    data_template: str = "template",
+    valid_set_size: Optional[int] = None,
+    valid_set_repeat: int = 1,
+    data_filter: Optional[Callable[[VlpnDataItem], bool]] = None,
+    sample_batch_size: int = 1,
+    sample_image_size: int = 768,
+    sample_steps: int = 20,
+) -> TrainingSetup:
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    output_dir = Path(output_dir).joinpath(slugify(project), now)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    accelerator = Accelerator(
+        log_with=LoggerType.TENSORBOARD,
+        logging_dir=f"{output_dir}",
+        gradient_accumulation_steps=gradient_accumulation_steps,
+        mixed_precision=mixed_precision
+    )
+
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+
+    seed = seed or (torch.random.seed() >> 32)
+    set_seed(seed)
+
+    # Load the tokenizer and add the placeholder token as a additional special token
+    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
+    tokenizer.set_use_vector_shuffle(vector_shuffle)
+    tokenizer.set_dropout(vector_dropout)
+
+    # Load models and create wrapper for stable diffusion
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder='scheduler')
+
+    vae.enable_slicing()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.set_use_memory_efficient_attention_xformers(True)
+
+    if gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
+
+    embeddings = patch_managed_embeddings(text_encoder)
+
+    if embeddings_dir is not None:
+        embeddings_dir = Path(embeddings_dir)
+        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+            raise ValueError("--embeddings_dir must point to an existing directory")
+
+        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
+
+    # Convert the initializer_token, placeholder_token to ids
+    initializer_token_ids = [
+        tokenizer.encode(token, add_special_tokens=False)
+        for token in initializer_token
+    ]
+
+    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_token, num_vectors)
+    embeddings.resize(len(tokenizer))
+
+    for (new_id, init_ids) in zip(placeholder_token_ids, initializer_token_ids):
+        embeddings.add_embed(new_id, init_ids)
+
+    init_ratios = [
+        f"{len(init_ids)} / {len(new_id)}"
+        for new_id, init_ids in zip(placeholder_token_ids, initializer_token_ids)
+    ]
+
+    print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(placeholder_token, placeholder_token_ids, init_ratios))}")
+
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+
+    if scale_lr:
+        learning_rate = (
+            learning_rate * gradient_accumulation_steps *
+            train_batch_size * accelerator.num_processes
+        )
+
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+    if use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    weight_dtype = torch.float32
+    if mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    datamodule = VlpnDataModule(
+        data_file=data_file,
+        batch_size=train_batch_size,
+        tokenizer=tokenizer,
+        class_subdir=class_image_dir,
+        num_class_images=num_class_images,
+        size=resolution,
+        num_buckets=num_buckets,
+        progressive_buckets=progressive_buckets,
+        bucket_step_size=bucket_step_size,
+        bucket_max_pixels=bucket_max_pixels,
+        dropout=tag_dropout,
+        shuffle=tag_shuffle,
+        template_key=data_template,
+        valid_set_size=valid_set_size,
+        valid_set_repeat=valid_set_repeat,
+        seed=seed,
+        filter=data_filter,
+        dtype=weight_dtype
+    )
+    datamodule.setup()
+
+    train_dataloader = datamodule.train_dataloader
+    val_dataloader = datamodule.val_dataloader
+
+    train_dataloader, val_dataloader = accelerator.prepare(train_dataloader, val_dataloader)
+
+    if num_class_images != 0:
+        generate_class_images(
+            accelerator,
+            text_encoder,
+            vae,
+            unet,
+            tokenizer,
+            checkpoint_scheduler,
+            datamodule.data_train,
+            sample_batch_size,
+            sample_image_size,
+            sample_steps
+        )
+
+    return TrainingSetup(
+        accelerator=accelerator,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        noise_scheduler=noise_scheduler,
+        checkpoint_scheduler=checkpoint_scheduler,
+        optimizer_class=optimizer_class,
+        learning_rate=learning_rate,
+        output_dir=output_dir,
+        weight_dtype=weight_dtype,
+        seed=seed,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        placeholder_token=placeholder_token,
+        placeholder_token_ids=placeholder_token_ids
+    )
+
+
 def loss_step(
     vae: AutoencoderKL,
     noise_scheduler: DDPMScheduler,
@@ -221,15 +440,14 @@ def train_loop(
     sample_steps: int = 20,
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
-    gradient_accumulation_steps: int = 1,
     num_epochs: int = 100,
     on_log: Callable[[], dict[str, Any]] = noop_on_log,
-    on_train: Callable[[], _GeneratorContextManager] = nullcontext,
-    on_before_optimize: Callable[[], None] = noop,
+    on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
+    on_before_optimize: Callable[[int], None] = noop,
     on_after_optimize: Callable[[float], None] = noop,
-    on_eval: Callable[[], _GeneratorContextManager] = nullcontext
+    on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
 ):
-    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
+    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
@@ -273,14 +491,14 @@ def train_loop(
 
         model.train()
 
-        with on_train():
+        with on_train(epoch):
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(model):
                     loss, acc, bsz = loss_step(step, batch)
 
                     accelerator.backward(loss)
 
-                    on_before_optimize()
+                    on_before_optimize(epoch)
 
                     optimizer.step()
                     lr_scheduler.step()
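
Note (not part of the patch): train_loop now derives the accumulation factor from accelerator.gradient_accumulation_steps and expects epoch-aware callbacks: on_train(epoch) and on_eval() must return context managers (the new noop_ctx default), while on_before_optimize(epoch), on_after_optimize(lr) and on_log() stay plain callables. A minimal conforming set of callbacks, with placeholder bodies (only the signatures are dictated by the diff):

    from contextlib import contextmanager

    @contextmanager
    def on_train(epoch: int):
        try:
            yield  # e.g. switch tokenizer/text encoder into training mode here
        finally:
            pass

    @contextmanager
    def on_eval():
        try:
            yield  # e.g. switch modules into eval mode around validation
        finally:
            pass

    def on_before_optimize(epoch: int):
        pass  # e.g. gradient clipping

    def on_after_optimize(lr: float):
        pass  # e.g. embedding re-normalization at the current learning rate

    def on_log():
        return {}  # extra values merged into the logged metrics
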
diff --git a/training/lr.py b/training/lr.py
index 84e30a0..7584ba2 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -16,6 +16,10 @@ def noop(*args, **kwards):
     pass
 
 
+def noop_ctx(*args, **kwards):
+    return nullcontext()
+
+
 class LRFinder():
     def __init__(
         self,
@@ -25,10 +29,10 @@ class LRFinder():
         train_dataloader,
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
-        on_train: Callable[[], _GeneratorContextManager] = nullcontext,
-        on_before_optimize: Callable[[], None] = noop,
+        on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
+        on_before_optimize: Callable[[int], None] = noop,
         on_after_optimize: Callable[[float], None] = noop,
-        on_eval: Callable[[], _GeneratorContextManager] = nullcontext
+        on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
     ):
         self.accelerator = accelerator
         self.model = model
@@ -86,7 +90,7 @@ class LRFinder():
 
             self.model.train()
 
-            with self.on_train():
+            with self.on_train(epoch):
                 for step, batch in enumerate(self.train_dataloader):
                     if step >= num_train_batches:
                         break
@@ -96,7 +100,7 @@ class LRFinder():
 
                     self.accelerator.backward(loss)
 
-                    self.on_before_optimize()
+                    self.on_before_optimize(epoch)
 
                     self.optimizer.step()
                     lr_scheduler.step()
diff --git a/training/modules/dreambooth.py b/training/modules/dreambooth.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/training/modules/dreambooth.py
diff --git a/training/modules/lora.py b/training/modules/lora.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/training/modules/lora.py
diff --git a/training/modules/ti.py b/training/modules/ti.py
new file mode 100644
index 0000000..2db6f88
--- /dev/null
+++ b/training/modules/ti.py
@@ -0,0 +1,284 @@
+from typing import Literal
+from functools import partial
+from contextlib import contextmanager, nullcontext
+
+import torch
+
+from slugify import slugify
+
+from accelerate import Accelerator
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, UNet2DConditionModel
+
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from models.clip.tokenizer import MultiCLIPTokenizer
+
+from training.common import TrainingSetup, get_scheduler, train_loop, loss_step
+from training.util import EMAModel, CheckpointerBase
+
+
+class Checkpointer(CheckpointerBase):
+    def __init__(
+        self,
+        accelerator: Accelerator,
+        vae: AutoencoderKL,
+        unet: UNet2DConditionModel,
+        tokenizer: MultiCLIPTokenizer,
+        text_encoder: CLIPTextModel,
+        ema_embeddings: EMAModel,
+        weight_dtype: torch.dtype,
+        scheduler,
+        placeholder_token,
+        placeholder_token_ids,
+        *args,
+        **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+
+        self.weight_dtype = weight_dtype
+        self.accelerator = accelerator
+        self.vae = vae
+        self.unet = unet
+        self.tokenizer = tokenizer
+        self.text_encoder = text_encoder
+        self.ema_embeddings = ema_embeddings
+        self.scheduler = scheduler
+        self.placeholder_token = placeholder_token
+        self.placeholder_token_ids = placeholder_token_ids
+
+    @torch.no_grad()
+    def checkpoint(self, step, postfix):
+        print("Saving checkpoint for step %d..." % step)
+
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        ema_context = nullcontext()
+        if self.ema_embeddings is not None:
+            ema_context = self.ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+        with ema_context:
+            for (token, ids) in zip(self.placeholder_token, self.placeholder_token_ids):
+                text_encoder.text_model.embeddings.save_embed(
+                    ids,
+                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                )
+
+        del text_encoder
+
+    @torch.no_grad()
+    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        ema_context = nullcontext()
+        if self.ema_embeddings is not None:
+            ema_context = self.ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+        with ema_context:
+            orig_dtype = text_encoder.dtype
+            text_encoder.to(dtype=self.weight_dtype)
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=self.vae,
+                unet=self.unet,
+                tokenizer=self.tokenizer,
+                scheduler=self.scheduler,
+            ).to(self.accelerator.device)
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+            super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
+
+            text_encoder.to(dtype=orig_dtype)
+
+        del text_encoder
+        del pipeline
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+
+def train_ti(
+    setup: TrainingSetup,
+    num_train_epochs: int = 100,
+    num_class_images: int = 0,
+    prior_loss_weight: float = 1.0,
+    use_ema: bool = False,
+    ema_inv_gamma: float = 1.0,
+    ema_power: float = 4/5,
+    ema_max_decay: float = .9999,
+    adam_beta1: float = 0.9,
+    adam_beta2: float = 0.999,
+    adam_weight_decay: float = 0,
+    adam_epsilon: float = 1e-08,
+    adam_amsgrad: bool = False,
+    lr_scheduler: Literal[
+        "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "one_cycle"
+    ] = "one_cycle",
+    lr_min_lr: float = 0.04,
+    lr_warmup_func: Literal["linear", "cos"] = "cos",
+    lr_annealing_func: Literal["linear", "half_cos", "cos"] = "cos",
+    lr_warmup_exp: int = 1,
+    lr_annealing_exp: int = 1,
+    lr_cycles: int = 1,
+    lr_warmup_epochs: int = 10,
+    emb_decay_target: float = 0.4,
+    emb_decay_factor: float = 1,
+    emb_decay_start: float = 1e-4,
+    sample_image_size: int = 768,
+    sample_batch_size: int = 1,
+    sample_batches: int = 1,
+    sample_frequency: int = 10,
+    sample_steps: int = 20,
+    checkpoint_frequency: int = 50,
+    global_step_offset: int = 0,
+):
+    if use_ema:
+        ema_embeddings = EMAModel(
+            setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            inv_gamma=ema_inv_gamma,
+            power=ema_power,
+            max_value=ema_max_decay,
+        )
+    else:
+        ema_embeddings = None
+
+    setup.text_encoder.requires_grad_(True)
+    setup.text_encoder.text_model.encoder.requires_grad_(False)
+    setup.text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    setup.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    setup.text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
+
+    # Initialize the optimizer
+    optimizer = setup.optimizer_class(
+        setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        lr=setup.learning_rate,
+        betas=(adam_beta1, adam_beta2),
+        weight_decay=adam_weight_decay,
+        eps=adam_epsilon,
+        amsgrad=adam_amsgrad,
+    )
+
+    lr_scheduler = get_scheduler(
+        lr_scheduler,
+        optimizer=optimizer,
+        min_lr=lr_min_lr,
+        warmup_func=lr_warmup_func,
+        annealing_func=lr_annealing_func,
+        warmup_exp=lr_warmup_exp,
+        annealing_exp=lr_annealing_exp,
+        cycles=lr_cycles,
+        train_epochs=num_train_epochs,
+        warmup_epochs=lr_warmup_epochs,
+        num_training_steps_per_epoch=len(setup.train_dataloader),
+        gradient_accumulation_steps=setup.accelerator.gradient_accumulation_steps
+    )
+
+    text_encoder, optimizer, lr_scheduler = setup.accelerator.prepare(
+        setup.text_encoder, optimizer, lr_scheduler
+    )
+
+    # Move vae and unet to device
+    setup.vae.to(setup.accelerator.device, dtype=setup.weight_dtype)
+    setup.unet.to(setup.accelerator.device, dtype=setup.weight_dtype)
+
+    if use_ema:
+        ema_embeddings.to(setup.accelerator.device)
+
+    setup.unet.train()
+
+    @contextmanager
+    def on_train(epoch: int):
+        try:
+            setup.tokenizer.train()
+            yield
+        finally:
+            pass
+
+    @contextmanager
+    def on_eval():
+        try:
+            setup.tokenizer.eval()
+
+            ema_context = nullcontext()
+            if use_ema:
+                ema_context = ema_embeddings.apply_temporary(
+                    text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+            with ema_context:
+                yield
+        finally:
+            pass
+
+    @torch.no_grad()
+    def on_after_optimize(lr: float):
+        text_encoder.text_model.embeddings.normalize(
+            emb_decay_target,
+            min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (setup.learning_rate - emb_decay_start))))
+        )
+
+        if use_ema:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+    def on_log():
+        if use_ema:
+            return {"ema_decay": ema_embeddings.decay}
+        return {}
+
+    loss_step_ = partial(
+        loss_step,
+        setup.vae,
+        setup.noise_scheduler,
+        setup.unet,
+        text_encoder,
+        num_class_images != 0,
+        prior_loss_weight,
+        setup.seed,
+    )
+
+    checkpointer = Checkpointer(
+        accelerator=setup.accelerator,
+        vae=setup.vae,
+        unet=setup.unet,
+        tokenizer=setup.tokenizer,
+        text_encoder=text_encoder,
+        ema_embeddings=ema_embeddings,
+        weight_dtype=setup.weight_dtype,
+        scheduler=setup.checkpoint_scheduler,
+        placeholder_token=setup.placeholder_token,
+        placeholder_token_ids=setup.placeholder_token_ids,
+        train_dataloader=setup.train_dataloader,
+        val_dataloader=setup.val_dataloader,
+        output_dir=setup.output_dir,
+        seed=setup.seed,
+        sample_image_size=sample_image_size,
+        sample_batch_size=sample_batch_size,
+        sample_batches=sample_batches
+    )
+
+    if setup.accelerator.is_main_process:
+        setup.accelerator.init_trackers("textual_inversion")
+
+    train_loop(
+        accelerator=setup.accelerator,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        model=text_encoder,
+        checkpointer=checkpointer,
+        train_dataloader=setup.train_dataloader,
+        val_dataloader=setup.val_dataloader,
+        loss_step=loss_step_,
+        sample_frequency=sample_frequency,
+        sample_steps=sample_steps,
+        checkpoint_frequency=checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        num_epochs=num_train_epochs,
+        on_log=on_log,
+        on_train=on_train,
+        on_after_optimize=on_after_optimize,
+        on_eval=on_eval
+    )
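
Note (not part of the patch): the normalize() call in on_after_optimize above scales the embedding decay with the current learning rate. Written out as a helper (the helper name is illustrative; the module inlines the expression), the second argument is:

    def emb_decay_lambda(lr: float, learning_rate: float,
                         emb_decay_factor: float = 1.0,
                         emb_decay_start: float = 1e-4) -> float:
        # Clamped to [0, 1]: no decay until lr passes emb_decay_start,
        # full emb_decay_factor when lr reaches the configured base learning rate.
        return min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start))))
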
diff --git a/training/util.py b/training/util.py
index 0ec2032..cc4cdee 100644
--- a/training/util.py
+++ b/training/util.py
@@ -41,14 +41,16 @@ class AverageMeter:
 class CheckpointerBase:
     def __init__(
         self,
-        datamodule,
+        train_dataloader,
+        val_dataloader,
         output_dir: Path,
         sample_image_size: int,
         sample_batches: int,
         sample_batch_size: int,
         seed: Optional[int] = None
     ):
-        self.datamodule = datamodule
+        self.train_dataloader = train_dataloader
+        self.val_dataloader = val_dataloader
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
         self.seed = seed if seed is not None else torch.random.seed()
@@ -70,15 +72,16 @@ class CheckpointerBase:
     ):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        train_data = self.datamodule.train_dataloader
-        val_data = self.datamodule.val_dataloader
-
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
 
         grid_cols = min(self.sample_batch_size, 4)
         grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
 
-        for pool, data, gen in [("stable", val_data, generator), ("val", val_data, None), ("train", train_data, None)]:
+        for pool, data, gen in [
+            ("stable", self.val_dataloader, generator),
+            ("val", self.val_dataloader, None),
+            ("train", self.train_dataloader, None)
+        ]:
             all_samples = []
             file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
             file_path.parent.mkdir(parents=True, exist_ok=True)