summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py467
1 files changed, 386 insertions, 81 deletions
diff --git a/train_ti.py b/train_ti.py
index 3a55f40..61195f6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,15 +1,29 @@
1import argparse 1import argparse
2import datetime
3import logging
4from functools import partial
5from pathlib import Path
6from contextlib import contextmanager, nullcontext
2 7
3import torch 8import torch
4import torch.utils.checkpoint 9import torch.utils.checkpoint
5 10
11from accelerate import Accelerator
6from accelerate.logging import get_logger 12from accelerate.logging import get_logger
7 13from accelerate.utils import LoggerType, set_seed
8from util import load_config 14from diffusers import AutoencoderKL, UNet2DConditionModel
9from data.csv import VlpnDataItem 15import matplotlib.pyplot as plt
10from training.common import train_setup 16from transformers import CLIPTextModel
11from training.modules.ti import train_ti 17from slugify import slugify
12from training.util import save_args 18
19from util import load_config, load_embeddings_from_dir
20from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
21from data.csv import VlpnDataModule, VlpnDataItem
22from training.common import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
23from training.optimization import get_scheduler
24from training.lr import LRFinder
25from training.util import CheckpointerBase, EMAModel, save_args
26from models.clip.tokenizer import MultiCLIPTokenizer
13 27
14logger = get_logger(__name__) 28logger = get_logger(__name__)
15 29
@@ -52,13 +66,13 @@ def parse_args():
52 help="The name of the current project.", 66 help="The name of the current project.",
53 ) 67 )
54 parser.add_argument( 68 parser.add_argument(
55 "--placeholder_token", 69 "--placeholder_tokens",
56 type=str, 70 type=str,
57 nargs='*', 71 nargs='*',
58 help="A token to use as a placeholder for the concept.", 72 help="A token to use as a placeholder for the concept.",
59 ) 73 )
60 parser.add_argument( 74 parser.add_argument(
61 "--initializer_token", 75 "--initializer_tokens",
62 type=str, 76 type=str,
63 nargs='*', 77 nargs='*',
64 help="A token to use as initializer word." 78 help="A token to use as initializer word."
@@ -439,29 +453,29 @@ def parse_args():
439 if args.project is None: 453 if args.project is None:
440 raise ValueError("You must specify --project") 454 raise ValueError("You must specify --project")
441 455
442 if isinstance(args.placeholder_token, str): 456 if isinstance(args.placeholder_tokens, str):
443 args.placeholder_token = [args.placeholder_token] 457 args.placeholder_tokens = [args.placeholder_tokens]
444 458
445 if len(args.placeholder_token) == 0: 459 if len(args.placeholder_tokens) == 0:
446 args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)] 460 args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_tokens)]
447 461
448 if isinstance(args.initializer_token, str): 462 if isinstance(args.initializer_tokens, str):
449 args.initializer_token = [args.initializer_token] * len(args.placeholder_token) 463 args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
450 464
451 if len(args.initializer_token) == 0: 465 if len(args.initializer_tokens) == 0:
452 raise ValueError("You must specify --initializer_token") 466 raise ValueError("You must specify --initializer_tokens")
453 467
454 if len(args.placeholder_token) != len(args.initializer_token): 468 if len(args.placeholder_tokens) != len(args.initializer_tokens):
455 raise ValueError("--placeholder_token and --initializer_token must have the same number of items") 469 raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
456 470
457 if args.num_vectors is None: 471 if args.num_vectors is None:
458 args.num_vectors = 1 472 args.num_vectors = 1
459 473
460 if isinstance(args.num_vectors, int): 474 if isinstance(args.num_vectors, int):
461 args.num_vectors = [args.num_vectors] * len(args.initializer_token) 475 args.num_vectors = [args.num_vectors] * len(args.initializer_tokens)
462 476
463 if len(args.placeholder_token) != len(args.num_vectors): 477 if len(args.placeholder_tokens) != len(args.num_vectors):
464 raise ValueError("--placeholder_token and --num_vectors must have the same number of items") 478 raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
465 479
466 if isinstance(args.collection, str): 480 if isinstance(args.collection, str):
467 args.collection = [args.collection] 481 args.collection = [args.collection]
@@ -475,13 +489,197 @@ def parse_args():
475 return args 489 return args
476 490
477 491
class Checkpointer(CheckpointerBase):
    """Checkpointer for textual-inversion training.

    Saves the embeddings of the trained placeholder tokens and renders sample
    images through a temporary ``VlpnStableDiffusion`` pipeline. If an EMA
    model is supplied, its ``apply_temporary`` context is entered around both
    operations; otherwise the text encoder is used as-is.

    Any extra positional/keyword arguments are forwarded unchanged to
    ``CheckpointerBase.__init__``.
    """

    def __init__(
        self,
        weight_dtype,
        accelerator: Accelerator,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        tokenizer: MultiCLIPTokenizer,
        text_encoder: CLIPTextModel,
        ema_embeddings: EMAModel,
        scheduler,
        placeholder_tokens,
        placeholder_token_ids,
        *args,
        **kwargs
    ):
        """Store the models and settings needed for checkpoints and samples.

        Args:
            weight_dtype: dtype the text encoder is cast to while sampling.
            accelerator: used to unwrap the text encoder and pick the device.
            vae: VAE handed to the sampling pipeline.
            unet: UNet handed to the sampling pipeline.
            tokenizer: tokenizer handed to the sampling pipeline.
            text_encoder: (possibly accelerator-wrapped) text encoder whose
                embeddings are saved.
            ema_embeddings: EMA model for the trainable embeddings, or
                ``None`` to disable EMA handling.
            scheduler: scheduler handed to the sampling pipeline.
            placeholder_tokens: token strings, paired by position with
                ``placeholder_token_ids``.
            placeholder_token_ids: ids passed to ``save_embed`` per token.
        """
        super().__init__(*args, **kwargs)

        self.weight_dtype = weight_dtype
        self.accelerator = accelerator
        self.vae = vae
        self.unet = unet
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.ema_embeddings = ema_embeddings
        self.scheduler = scheduler
        self.placeholder_tokens = placeholder_tokens
        self.placeholder_token_ids = placeholder_token_ids

    def _ema_context(self, text_encoder):
        """Return the EMA ``apply_temporary`` context, or a no-op context."""
        if self.ema_embeddings is None:
            return nullcontext()

        return self.ema_embeddings.apply_temporary(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        )

    @torch.no_grad()
    def checkpoint(self, step, postfix):
        """Write one embedding file per placeholder token.

        Files are written to ``<output_dir>/checkpoints`` and named
        ``<slugified token>_<step>_<postfix>.bin``.
        NOTE(review): ``self.output_dir`` is not set in this class; it is
        presumably provided by ``CheckpointerBase`` -- confirm.
        """
        print("Saving checkpoint for step %d..." % step)

        checkpoints_path = self.output_dir.joinpath("checkpoints")
        checkpoints_path.mkdir(parents=True, exist_ok=True)

        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
        embeddings = text_encoder.text_model.embeddings

        with self._ema_context(text_encoder):
            for token, ids in zip(self.placeholder_tokens, self.placeholder_token_ids):
                embeddings.save_embed(
                    ids,
                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
                )

    @torch.no_grad()
    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
        """Render sample images with a temporary pipeline.

        The unwrapped text encoder is cast to ``weight_dtype`` for sampling
        and cast back afterwards; the actual sampling is delegated to
        ``CheckpointerBase.save_samples``.
        """
        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
        pipeline = None

        with self._ema_context(text_encoder):
            orig_dtype = text_encoder.dtype
            text_encoder.to(dtype=self.weight_dtype)

            try:
                pipeline = VlpnStableDiffusion(
                    text_encoder=text_encoder,
                    vae=self.vae,
                    unet=self.unet,
                    tokenizer=self.tokenizer,
                    scheduler=self.scheduler,
                ).to(self.accelerator.device)
                pipeline.set_progress_bar_config(dynamic_ncols=True)

                super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
            finally:
                # The text encoder is the model being trained: always restore
                # its dtype, even if sampling fails, so training does not
                # silently continue in reduced precision.
                text_encoder.to(dtype=orig_dtype)

        # Drop local references before asking CUDA to release cached memory.
        del text_encoder
        del pipeline

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
571
572
478def main(): 573def main():
479 args = parse_args() 574 args = parse_args()
480 575
481 def data_filter(item: VlpnDataItem): 576 global_step_offset = args.global_step
577 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
578 basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
579 basepath.mkdir(parents=True, exist_ok=True)
580
581 accelerator = Accelerator(
582 log_with=LoggerType.TENSORBOARD,
583 logging_dir=f"{basepath}",
584 gradient_accumulation_steps=args.gradient_accumulation_steps,
585 mixed_precision=args.mixed_precision
586 )
587
588 logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
589
590 args.seed = args.seed or (torch.random.seed() >> 32)
591 set_seed(args.seed)
592
593 save_args(basepath, args)
594
595 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
596 args.pretrained_model_name_or_path)
597
598 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
599 tokenizer.set_dropout(args.vector_dropout)
600
601 vae.enable_slicing()
602 vae.set_use_memory_efficient_attention_xformers(True)
603 unet.set_use_memory_efficient_attention_xformers(True)
604
605 if args.gradient_checkpointing:
606 unet.enable_gradient_checkpointing()
607 text_encoder.gradient_checkpointing_enable()
608
609 if args.embeddings_dir is not None:
610 embeddings_dir = Path(args.embeddings_dir)
611 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
612 raise ValueError("--embeddings_dir must point to an existing directory")
613
614 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
615 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
616
617 placeholder_token_ids = add_placeholder_tokens(
618 tokenizer=tokenizer,
619 embeddings=embeddings,
620 placeholder_tokens=args.placeholder_tokens,
621 initializer_tokens=args.initializer_tokens,
622 num_vectors=args.num_vectors
623 )
624
625 print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
626
627 if args.use_ema:
628 ema_embeddings = EMAModel(
629 text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
630 inv_gamma=args.ema_inv_gamma,
631 power=args.ema_power,
632 max_value=args.ema_max_decay,
633 )
634 else:
635 ema_embeddings = None
636
637 vae.requires_grad_(False)
638 unet.requires_grad_(False)
639
640 text_encoder.text_model.encoder.requires_grad_(False)
641 text_encoder.text_model.final_layer_norm.requires_grad_(False)
642 text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
643 text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
644
645 if args.scale_lr:
646 args.learning_rate = (
647 args.learning_rate * args.gradient_accumulation_steps *
648 args.train_batch_size * accelerator.num_processes
649 )
650
651 if args.find_lr:
652 args.learning_rate = 1e-5
653
654 if args.use_8bit_adam:
655 try:
656 import bitsandbytes as bnb
657 except ImportError:
658 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
659
660 optimizer_class = bnb.optim.AdamW8bit
661 else:
662 optimizer_class = torch.optim.AdamW
663
664 optimizer = optimizer_class(
665 text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
666 lr=args.learning_rate,
667 betas=(args.adam_beta1, args.adam_beta2),
668 weight_decay=args.adam_weight_decay,
669 eps=args.adam_epsilon,
670 amsgrad=args.adam_amsgrad,
671 )
672
673 weight_dtype = torch.float32
674 if args.mixed_precision == "fp16":
675 weight_dtype = torch.float16
676 elif args.mixed_precision == "bf16":
677 weight_dtype = torch.bfloat16
678
679 def keyword_filter(item: VlpnDataItem):
482 cond1 = any( 680 cond1 = any(
483 keyword in part 681 keyword in part
484 for keyword in args.placeholder_token 682 for keyword in args.placeholder_tokens
485 for part in item.prompt 683 for part in item.prompt
486 ) 684 )
487 cond3 = args.collection is None or args.collection in item.collection 685 cond3 = args.collection is None or args.collection in item.collection
@@ -491,78 +689,185 @@ def main():
491 ) 689 )
492 return cond1 and cond3 and cond4 690 return cond1 and cond3 and cond4
493 691
494 setup = train_setup( 692 datamodule = VlpnDataModule(
495 output_dir=args.output_dir,
496 project=args.project,
497 pretrained_model_name_or_path=args.pretrained_model_name_or_path,
498 learning_rate=args.learning_rate,
499 data_file=args.train_data_file, 693 data_file=args.train_data_file,
500 gradient_accumulation_steps=args.gradient_accumulation_steps, 694 batch_size=args.train_batch_size,
501 mixed_precision=args.mixed_precision, 695 tokenizer=tokenizer,
502 seed=args.seed, 696 class_subdir=args.class_image_dir,
503 vector_shuffle=args.vector_shuffle,
504 vector_dropout=args.vector_dropout,
505 gradient_checkpointing=args.gradient_checkpointing,
506 embeddings_dir=args.embeddings_dir,
507 placeholder_token=args.placeholder_token,
508 initializer_token=args.initializer_token,
509 num_vectors=args.num_vectors,
510 scale_lr=args.scale_lr,
511 use_8bit_adam=args.use_8bit_adam,
512 train_batch_size=args.train_batch_size,
513 class_image_dir=args.class_image_dir,
514 num_class_images=args.num_class_images, 697 num_class_images=args.num_class_images,
515 resolution=args.resolution, 698 size=args.resolution,
516 num_buckets=args.num_buckets, 699 num_buckets=args.num_buckets,
517 progressive_buckets=args.progressive_buckets, 700 progressive_buckets=args.progressive_buckets,
518 bucket_step_size=args.bucket_step_size, 701 bucket_step_size=args.bucket_step_size,
519 bucket_max_pixels=args.bucket_max_pixels, 702 bucket_max_pixels=args.bucket_max_pixels,
520 tag_dropout=args.tag_dropout, 703 dropout=args.tag_dropout,
521 tag_shuffle=not args.no_tag_shuffle, 704 shuffle=not args.no_tag_shuffle,
522 data_template=args.train_data_template, 705 template_key=args.train_data_template,
523 valid_set_size=args.valid_set_size, 706 valid_set_size=args.valid_set_size,
524 valid_set_repeat=args.valid_set_repeat, 707 valid_set_repeat=args.valid_set_repeat,
525 data_filter=data_filter, 708 num_workers=args.dataloader_num_workers,
526 sample_image_size=args.sample_image_size, 709 seed=args.seed,
527 sample_batch_size=args.sample_batch_size, 710 filter=keyword_filter,
528 sample_steps=args.sample_steps, 711 dtype=weight_dtype
712 )
713 datamodule.setup()
714
715 train_dataloader = datamodule.train_dataloader
716 val_dataloader = datamodule.val_dataloader
717
718 if args.num_class_images != 0:
719 generate_class_images(
720 accelerator,
721 text_encoder,
722 vae,
723 unet,
724 tokenizer,
725 sample_scheduler,
726 datamodule.data_train,
727 args.sample_batch_size,
728 args.sample_image_size,
729 args.sample_steps
730 )
731
732 if args.find_lr:
733 lr_scheduler = None
734 else:
735 lr_scheduler = get_scheduler(
736 args.lr_scheduler,
737 optimizer=optimizer,
738 num_training_steps_per_epoch=len(train_dataloader),
739 gradient_accumulation_steps=args.gradient_accumulation_steps,
740 min_lr=args.lr_min_lr,
741 warmup_func=args.lr_warmup_func,
742 annealing_func=args.lr_annealing_func,
743 warmup_exp=args.lr_warmup_exp,
744 annealing_exp=args.lr_annealing_exp,
745 cycles=args.lr_cycles,
746 train_epochs=args.num_train_epochs,
747 warmup_epochs=args.lr_warmup_epochs,
748 )
749
750 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
751 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
529 ) 752 )
530 753
531 save_args(setup.output_dir, args) 754 vae.to(accelerator.device, dtype=weight_dtype)
755 unet.to(accelerator.device, dtype=weight_dtype)
532 756
533 train_ti( 757 if args.use_ema:
534 setup=setup, 758 ema_embeddings.to(accelerator.device)
535 num_train_epochs=args.num_train_epochs, 759
536 num_class_images=args.num_class_images, 760 if args.gradient_checkpointing:
537 prior_loss_weight=args.prior_loss_weight, 761 unet.train()
538 use_ema=args.use_ema, 762 else:
539 ema_inv_gamma=args.ema_inv_gamma, 763 unet.eval()
540 ema_power=args.ema_power, 764
541 ema_max_decay=args.ema_max_decay, 765 @contextmanager
542 adam_beta1=args.adam_beta1, 766 def on_train(epoch: int):
543 adam_beta2=args.adam_beta2, 767 try:
544 adam_weight_decay=args.adam_weight_decay, 768 tokenizer.train()
545 adam_epsilon=args.adam_epsilon, 769 yield
546 adam_amsgrad=args.adam_amsgrad, 770 finally:
547 lr_scheduler=args.lr_scheduler, 771 pass
548 lr_min_lr=args.lr_min_lr, 772
549 lr_warmup_func=args.lr_warmup_func, 773 @contextmanager
550 lr_annealing_func=args.lr_annealing_func, 774 def on_eval():
551 lr_warmup_exp=args.lr_warmup_exp, 775 try:
552 lr_annealing_exp=args.lr_annealing_exp, 776 tokenizer.eval()
553 lr_cycles=args.lr_cycles, 777
554 lr_warmup_epochs=args.lr_warmup_epochs, 778 ema_context = ema_embeddings.apply_temporary(
555 emb_decay_target=args.emb_decay_target, 779 text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext()
556 emb_decay_factor=args.emb_decay_factor, 780
557 emb_decay_start=args.emb_decay_start, 781 with ema_context:
782 yield
783 finally:
784 pass
785
786 @torch.no_grad()
787 def on_after_optimize(lr: float):
788 text_encoder.text_model.embeddings.normalize(
789 args.emb_decay_target,
790 min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start))))
791 )
792
793 if args.use_ema:
794 ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
795
796 def on_log():
797 if args.use_ema:
798 return {"ema_decay": ema_embeddings.decay}
799 return {}
800
801 loss_step_ = partial(
802 loss_step,
803 vae,
804 noise_scheduler,
805 unet,
806 text_encoder,
807 args.num_class_images != 0,
808 args.prior_loss_weight,
809 args.seed,
810 )
811
812 checkpointer = Checkpointer(
813 weight_dtype=weight_dtype,
814 train_dataloader=train_dataloader,
815 val_dataloader=val_dataloader,
816 accelerator=accelerator,
817 vae=vae,
818 unet=unet,
819 tokenizer=tokenizer,
820 text_encoder=text_encoder,
821 ema_embeddings=ema_embeddings,
822 scheduler=sample_scheduler,
823 placeholder_tokens=args.placeholder_tokens,
824 placeholder_token_ids=placeholder_token_ids,
825 output_dir=basepath,
558 sample_image_size=args.sample_image_size, 826 sample_image_size=args.sample_image_size,
559 sample_batch_size=args.sample_batch_size, 827 sample_batch_size=args.sample_batch_size,
560 sample_batches=args.sample_batches, 828 sample_batches=args.sample_batches,
561 sample_frequency=args.sample_frequency, 829 seed=args.seed
562 sample_steps=args.sample_steps, 830 )
563 checkpoint_frequency=args.checkpoint_frequency, 831
564 global_step_offset=args.global_step, 832 if accelerator.is_main_process:
565 ) 833 accelerator.init_trackers("textual_inversion")
834
835 if args.find_lr:
836 lr_finder = LRFinder(
837 accelerator=accelerator,
838 optimizer=optimizer,
839 model=text_encoder,
840 train_dataloader=train_dataloader,
841 val_dataloader=val_dataloader,
842 loss_step=loss_step_,
843 on_train=on_train,
844 on_eval=on_eval,
845 on_after_optimize=on_after_optimize,
846 )
847 lr_finder.run(num_epochs=100, end_lr=1e3)
848
849 plt.savefig(basepath.joinpath("lr.png"), dpi=300)
850 plt.close()
851 else:
852 train_loop(
853 accelerator=accelerator,
854 optimizer=optimizer,
855 lr_scheduler=lr_scheduler,
856 model=text_encoder,
857 checkpointer=checkpointer,
858 train_dataloader=train_dataloader,
859 val_dataloader=val_dataloader,
860 loss_step=loss_step_,
861 sample_frequency=args.sample_frequency,
862 sample_steps=args.sample_steps,
863 checkpoint_frequency=args.checkpoint_frequency,
864 global_step_offset=global_step_offset,
865 num_epochs=args.num_train_epochs,
866 on_log=on_log,
867 on_train=on_train,
868 on_after_optimize=on_after_optimize,
869 on_eval=on_eval
870 )
566 871
567 872
568if __name__ == "__main__": 873if __name__ == "__main__":