path: root/train_ti.py
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  479
1 file changed, 70 insertions(+), 409 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 3f4e739..3a55f40 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,31 +1,15 @@
 import argparse
-import math
-import datetime
-import logging
-from functools import partial
-from pathlib import Path
-from contextlib import contextmanager, nullcontext
 
 import torch
 import torch.utils.checkpoint
 
-from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
-import matplotlib.pyplot as plt
-from tqdm.auto import tqdm
-from transformers import CLIPTextModel
-from slugify import slugify
-
-from util import load_config, load_embeddings_from_dir
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, train_loop, generate_class_images, get_scheduler
-from training.lr import LRFinder
-from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
-from models.clip.embeddings import patch_managed_embeddings
-from models.clip.tokenizer import MultiCLIPTokenizer
+
+from util import load_config
+from data.csv import VlpnDataItem
+from training.common import train_setup
+from training.modules.ti import train_ti
+from training.util import save_args
 
 logger = get_logger(__name__)
 
@@ -271,7 +255,7 @@ def parse_args():
     parser.add_argument(
         "--lr_min_lr",
         type=float,
-        default=None,
+        default=0.04,
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
@@ -401,19 +385,19 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--decay_target",
-        default=None,
+        "--emb_decay_target",
+        default=0.4,
         type=float,
         help="Embedding decay target."
     )
     parser.add_argument(
-        "--decay_factor",
+        "--emb_decay_factor",
         default=1,
         type=float,
         help="Embedding decay factor."
     )
     parser.add_argument(
-        "--decay_start",
+        "--emb_decay_start",
         default=1e-4,
         type=float,
         help="Embedding decay start offset."
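Note: the renamed --emb_decay_* options feed the embedding-decay step that this refactor moves out of the script (the old on_after_optimize() callback, removed in the last hunk below). A minimal sketch of how they combine there; the helper name is ours, the formula is copied from the removed code:

def emb_decay_weight(
    lr: float,                # current learning rate from the scheduler
    learning_rate: float,     # base --learning_rate
    emb_decay_factor: float,  # --emb_decay_factor (default 1)
    emb_decay_start: float,   # --emb_decay_start (default 1e-4)
) -> float:
    # Scale the factor by how far lr has climbed from emb_decay_start
    # toward the base learning rate, clamped to [0, 1].
    ramp = (lr - emb_decay_start) / (learning_rate - emb_decay_start)
    return min(1.0, max(0.0, emb_decay_factor * ramp))

The result is passed to embeddings.normalize() together with --emb_decay_target (default 0.4), which presumably pulls the trained embedding norms toward that target.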
@@ -491,213 +475,10 @@ def parse_args():
     return args


-class Checkpointer(CheckpointerBase):
-    def __init__(
-        self,
-        weight_dtype,
-        accelerator: Accelerator,
-        vae: AutoencoderKL,
-        unet: UNet2DConditionModel,
-        tokenizer: MultiCLIPTokenizer,
-        text_encoder: CLIPTextModel,
-        ema_embeddings: EMAModel,
-        scheduler,
-        placeholder_token,
-        new_ids,
-        *args,
-        **kwargs
-    ):
-        super().__init__(*args, **kwargs)
-
-        self.weight_dtype = weight_dtype
-        self.accelerator = accelerator
-        self.vae = vae
-        self.unet = unet
-        self.tokenizer = tokenizer
-        self.text_encoder = text_encoder
-        self.ema_embeddings = ema_embeddings
-        self.scheduler = scheduler
-        self.placeholder_token = placeholder_token
-        self.new_ids = new_ids
-
-    @torch.no_grad()
-    def checkpoint(self, step, postfix):
-        print("Saving checkpoint for step %d..." % step)
-
-        checkpoints_path = self.output_dir.joinpath("checkpoints")
-        checkpoints_path.mkdir(parents=True, exist_ok=True)
-
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        ema_context = self.ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
-
-        with ema_context:
-            for (token, ids) in zip(self.placeholder_token, self.new_ids):
-                text_encoder.text_model.embeddings.save_embed(
-                    ids,
-                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
-                )
-
-        del text_encoder
-
-    @torch.no_grad()
-    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        ema_context = self.ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
-
-        with ema_context:
-            orig_dtype = text_encoder.dtype
-            text_encoder.to(dtype=self.weight_dtype)
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=self.vae,
-                unet=self.unet,
-                tokenizer=self.tokenizer,
-                scheduler=self.scheduler,
-            ).to(self.accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
-
-            text_encoder.to(dtype=orig_dtype)
-
-        del text_encoder
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-
 def main():
     args = parse_args()
 
-    global_step_offset = args.global_step
-    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
-    basepath.mkdir(parents=True, exist_ok=True)
-
-    accelerator = Accelerator(
-        log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{basepath}",
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        mixed_precision=args.mixed_precision
-    )
-
-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
-
-    args.seed = args.seed or (torch.random.seed() >> 32)
-    set_seed(args.seed)
-
-    save_args(basepath, args)
-
-    # Load the tokenizer and add the placeholder token as a additional special token
-    if args.tokenizer_name:
-        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
-    elif args.pretrained_model_name_or_path:
-        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
-    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
-    tokenizer.set_dropout(args.vector_dropout)
-
-    # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
-    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder='scheduler')
-
-    vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.set_use_memory_efficient_attention_xformers(True)
-
-    if args.gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
-        text_encoder.gradient_checkpointing_enable()
-
-    embeddings = patch_managed_embeddings(text_encoder)
-    ema_embeddings = None
-
-    if args.embeddings_dir is not None:
-        embeddings_dir = Path(args.embeddings_dir)
-        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
-            raise ValueError("--embeddings_dir must point to an existing directory")
-
-        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
-
-    # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = [
-        tokenizer.encode(token, add_special_tokens=False)
-        for token in args.initializer_token
-    ]
-
-    new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
-    embeddings.resize(len(tokenizer))
-
-    for (new_id, init_ids) in zip(new_ids, initializer_token_ids):
-        embeddings.add_embed(new_id, init_ids)
-
-    init_ratios = [f"{len(init_ids)} / {len(new_id)}" for new_id, init_ids in zip(new_ids, initializer_token_ids)]
-
-    print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}")
-
-    if args.use_ema:
-        ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            inv_gamma=args.ema_inv_gamma,
-            power=args.ema_power,
-            max_value=args.ema_max_decay,
-        )
-
-    vae.requires_grad_(False)
-    unet.requires_grad_(False)
-
-    text_encoder.text_model.encoder.requires_grad_(False)
-    text_encoder.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
-    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
-
-    if args.scale_lr:
-        args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps *
-            args.train_batch_size * accelerator.num_processes
-        )
-
-    if args.find_lr:
-        args.learning_rate = 1e-5
-
-    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
-    if args.use_8bit_adam:
-        try:
-            import bitsandbytes as bnb
-        except ImportError:
-            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
-
-        optimizer_class = bnb.optim.AdamW8bit
-    else:
-        optimizer_class = torch.optim.AdamW
-
-    # Initialize the optimizer
-    optimizer = optimizer_class(
-        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),  # only optimize the embeddings
-        lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
-    )
-
-    weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
-    def keyword_filter(item: VlpnDataItem):
+    def data_filter(item: VlpnDataItem):
         cond1 = any(
             keyword in part
             for keyword in args.placeholder_token
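Note: the setup removed above built the embedding EMA as EMAModel(..., inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay). As a rough guide to what those three knobs do, EMA helpers with this parameterization (e.g. the diffusers implementation this repo's training.util.EMAModel appears modeled on) typically warm the decay up as in the sketch below; this is an assumption, not code read from training/util.py:

def ema_decay(step: int, inv_gamma: float, power: float, max_value: float) -> float:
    # Assumed inv_gamma/power/max_value warm-up schedule: the decay grows
    # from 0 toward 1 as steps accumulate and is capped at max_value.
    value = 1.0 - (1.0 + step / inv_gamma) ** -power
    return max(0.0, min(value, max_value))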
@@ -710,198 +491,78 @@ def main():
         )
         return cond1 and cond3 and cond4
 
-    datamodule = VlpnDataModule(
+    setup = train_setup(
+        output_dir=args.output_dir,
+        project=args.project,
+        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
+        learning_rate=args.learning_rate,
         data_file=args.train_data_file,
-        batch_size=args.train_batch_size,
-        tokenizer=tokenizer,
-        class_subdir=args.class_image_dir,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision,
+        seed=args.seed,
+        vector_shuffle=args.vector_shuffle,
+        vector_dropout=args.vector_dropout,
+        gradient_checkpointing=args.gradient_checkpointing,
+        embeddings_dir=args.embeddings_dir,
+        placeholder_token=args.placeholder_token,
+        initializer_token=args.initializer_token,
+        num_vectors=args.num_vectors,
+        scale_lr=args.scale_lr,
+        use_8bit_adam=args.use_8bit_adam,
+        train_batch_size=args.train_batch_size,
+        class_image_dir=args.class_image_dir,
         num_class_images=args.num_class_images,
-        size=args.resolution,
+        resolution=args.resolution,
         num_buckets=args.num_buckets,
         progressive_buckets=args.progressive_buckets,
         bucket_step_size=args.bucket_step_size,
         bucket_max_pixels=args.bucket_max_pixels,
-        dropout=args.tag_dropout,
-        shuffle=not args.no_tag_shuffle,
-        template_key=args.train_data_template,
+        tag_dropout=args.tag_dropout,
+        tag_shuffle=not args.no_tag_shuffle,
+        data_template=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
-        num_workers=args.dataloader_num_workers,
-        seed=args.seed,
-        filter=keyword_filter,
-        dtype=weight_dtype
-    )
-    datamodule.setup()
-
-    train_dataloader = datamodule.train_dataloader
-    val_dataloader = datamodule.val_dataloader
-
-    if args.num_class_images != 0:
-        generate_class_images(
-            accelerator,
-            text_encoder,
-            vae,
-            unet,
-            tokenizer,
-            checkpoint_scheduler,
-            datamodule.data_train,
-            args.sample_batch_size,
-            args.sample_image_size,
-            args.sample_steps
-        )
-
-    if args.find_lr:
-        lr_scheduler = None
-    else:
-        lr_scheduler = get_scheduler(
-            args.lr_scheduler,
-            optimizer=optimizer,
-            min_lr=args.lr_min_lr,
-            lr=args.learning_rate,
-            warmup_func=args.lr_warmup_func,
-            annealing_func=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            cycles=args.lr_cycles,
-            train_epochs=args.num_train_epochs,
-            warmup_epochs=args.lr_warmup_epochs,
-            num_training_steps_per_epoch=len(train_dataloader),
-            gradient_accumulation_steps=args.gradient_accumulation_steps
-        )
-
-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+        data_filter=data_filter,
+        sample_image_size=args.sample_image_size,
+        sample_batch_size=args.sample_batch_size,
+        sample_steps=args.sample_steps,
     )
 
-    # Move vae and unet to device
-    vae.to(accelerator.device, dtype=weight_dtype)
-    unet.to(accelerator.device, dtype=weight_dtype)
-
-    if args.use_ema:
-        ema_embeddings.to(accelerator.device)
+    save_args(setup.output_dir, args)
 
-    # Keep vae and unet in eval mode as we don't train these
-    vae.eval()
-
-    if args.gradient_checkpointing:
-        unet.train()
-    else:
-        unet.eval()
-
-    @contextmanager
-    def on_train():
-        try:
-            tokenizer.train()
-            yield
-        finally:
-            pass
-
-    @contextmanager
-    def on_eval():
-        try:
-            tokenizer.eval()
-
-            ema_context = ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext()
-
-            with ema_context:
-                yield
-        finally:
-            pass
-
-    @torch.no_grad()
-    def on_after_optimize(lr: float):
-        text_encoder.text_model.embeddings.normalize(
-            args.decay_target,
-            min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start))))
-        )
-
-        if args.use_ema:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-    def on_log():
-        if args.use_ema:
-            return {"ema_decay": ema_embeddings.decay}
-        return {}
-
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        args.num_class_images != 0,
-        args.prior_loss_weight,
-        args.seed,
-    )
-
-    checkpointer = Checkpointer(
-        weight_dtype=weight_dtype,
-        datamodule=datamodule,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        ema_embeddings=ema_embeddings,
-        scheduler=checkpoint_scheduler,
-        placeholder_token=args.placeholder_token,
-        new_ids=new_ids,
-        output_dir=basepath,
-        sample_image_size=args.sample_image_size,
+    train_ti(
+        setup=setup,
+        num_train_epochs=args.num_train_epochs,
+        num_class_images=args.num_class_images,
+        prior_loss_weight=args.prior_loss_weight,
+        use_ema=args.use_ema,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
+        adam_beta1=args.adam_beta1,
+        adam_beta2=args.adam_beta2,
+        adam_weight_decay=args.adam_weight_decay,
+        adam_epsilon=args.adam_epsilon,
+        adam_amsgrad=args.adam_amsgrad,
+        lr_scheduler=args.lr_scheduler,
+        lr_min_lr=args.lr_min_lr,
+        lr_warmup_func=args.lr_warmup_func,
+        lr_annealing_func=args.lr_annealing_func,
+        lr_warmup_exp=args.lr_warmup_exp,
+        lr_annealing_exp=args.lr_annealing_exp,
+        lr_cycles=args.lr_cycles,
+        lr_warmup_epochs=args.lr_warmup_epochs,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay_factor=args.emb_decay_factor,
+        emb_decay_start=args.emb_decay_start,
+        sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
-        seed=args.seed
-    )
-
-    if accelerator.is_main_process:
-        config = vars(args).copy()
-        config["initializer_token"] = " ".join(config["initializer_token"])
-        config["placeholder_token"] = " ".join(config["placeholder_token"])
-        config["num_vectors"] = " ".join([str(n) for n in config["num_vectors"]])
-        if config["collection"] is not None:
-            config["collection"] = " ".join(config["collection"])
-        if config["exclude_collections"] is not None:
-            config["exclude_collections"] = " ".join(config["exclude_collections"])
-        accelerator.init_trackers("textual_inversion", config=config)
-
-    if args.find_lr:
-        lr_finder = LRFinder(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            model=text_encoder,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            on_train=on_train,
-            on_eval=on_eval,
-            on_after_optimize=on_after_optimize,
-        )
-        lr_finder.run(num_epochs=100, end_lr=1e3)
-
-        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
-        plt.close()
-    else:
-        train_loop(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            model=text_encoder,
-            checkpointer=checkpointer,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            sample_frequency=args.sample_frequency,
-            sample_steps=args.sample_steps,
-            checkpoint_frequency=args.checkpoint_frequency,
-            global_step_offset=global_step_offset,
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            num_epochs=args.num_train_epochs,
-            on_log=on_log,
-            on_train=on_train,
-            on_after_optimize=on_after_optimize,
-            on_eval=on_eval
-        )
+        sample_frequency=args.sample_frequency,
+        sample_steps=args.sample_steps,
+        checkpoint_frequency=args.checkpoint_frequency,
+        global_step_offset=args.global_step,
    )
 
 
 if __name__ == "__main__":
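Note: two small behaviors of the deleted setup code are easy to lose track of in this refactor and are presumably now handled inside train_setup(); the sketches below restate them for reference and are not the module's actual code.

from typing import Optional

import torch


def default_seed(seed: Optional[int]) -> int:
    # The old script used `args.seed or (torch.random.seed() >> 32)`:
    # torch.random.seed() re-seeds torch and returns the fresh 64-bit seed,
    # and the shift keeps its upper 32 bits, presumably so the value stays
    # within the 32-bit range that downstream seeding (e.g. NumPy) accepts.
    return seed or (torch.random.seed() >> 32)


def weight_dtype_for(mixed_precision: Optional[str]) -> torch.dtype:
    # Mirrors the old --mixed_precision handling; float32 is the fallback.
    if mixed_precision == "fp16":
        return torch.float16
    if mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32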