-rw-r--r--  train_dreambooth.py | 205
-rw-r--r--  train_ti.py         |   9
2 files changed, 84 insertions, 130 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 2145e2b..a1802a0 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -12,20 +12,18 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
+from diffusers import AutoencoderKL, UNet2DConditionModel
 import matplotlib.pyplot as plt
-from diffusers.training_utils import EMAModel
 from transformers import CLIPTextModel
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, train_loop, generate_class_images
+from training.common import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
 from training.optimization import get_scheduler
 from training.lr import LRFinder
-from training.util import CheckpointerBase, save_args
-from models.clip.embeddings import patch_managed_embeddings
+from training.util import CheckpointerBase, EMAModel, save_args
 from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
@@ -69,7 +67,7 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
-        "--placeholder_token",
+        "--placeholder_tokens",
         type=str,
         nargs='*',
        default=[],
@@ -446,20 +444,20 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")
 
-    if isinstance(args.placeholder_token, str):
-        args.placeholder_token = [args.placeholder_token]
+    if isinstance(args.placeholder_tokens, str):
+        args.placeholder_tokens = [args.placeholder_tokens]
 
-    if len(args.placeholder_token) == 0:
-        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+    if len(args.placeholder_tokens) == 0:
+        args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_token)]
 
     if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
+        args.initializer_token = [args.initializer_token] * len(args.placeholder_tokens)
 
     if len(args.initializer_token) == 0:
         raise ValueError("You must specify --initializer_token")
 
-    if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+    if len(args.placeholder_tokens) != len(args.initializer_token):
+        raise ValueError("--placeholder_tokens and --initializer_token must have the same number of items")
 
     if args.num_vectors is None:
         args.num_vectors = 1
@@ -467,8 +465,8 @@ def parse_args():
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.initializer_token)
 
-    if len(args.placeholder_token) != len(args.num_vectors):
-        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
+    if len(args.placeholder_tokens) != len(args.num_vectors):
+        raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -485,33 +483,18 @@
 class Checkpointer(CheckpointerBase):
     def __init__(
         self,
-        weight_dtype,
-        datamodule,
-        accelerator,
-        vae,
-        unet,
-        ema_unet,
-        tokenizer,
-        text_encoder,
+        weight_dtype: torch.dtype,
+        accelerator: Accelerator,
+        vae: AutoencoderKL,
+        unet: UNet2DConditionModel,
+        ema_unet: EMAModel,
+        tokenizer: MultiCLIPTokenizer,
+        text_encoder: CLIPTextModel,
         scheduler,
-        output_dir: Path,
-        placeholder_token,
-        placeholder_token_id,
-        sample_image_size,
-        sample_batches,
-        sample_batch_size,
-        seed,
+        *args,
+        **kwargs
     ):
-        super().__init__(
-            datamodule=datamodule,
-            output_dir=output_dir,
-            placeholder_token=placeholder_token,
-            placeholder_token_id=placeholder_token_id,
-            sample_image_size=sample_image_size,
-            seed=seed or torch.random.seed(),
-            sample_batches=sample_batches,
-            sample_batch_size=sample_batch_size
-        )
+        super().__init__(*args, **kwargs)
 
         self.weight_dtype = weight_dtype
         self.accelerator = accelerator
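Note: `CheckpointerBase` itself is not touched by this commit, so the exact set of arguments absorbed by `*args, **kwargs` is not visible here. Judging from the keywords forwarded at the `Checkpointer(...)` call site further down (train_dataloader, val_dataloader, output_dir, placeholder_tokens, placeholder_token_ids, sample_image_size, sample_batch_size, sample_batches, seed), a plausible sketch of the base class, offered purely as a reading aid:

    # Hypothetical sketch of training.util.CheckpointerBase, inferred from the
    # keyword arguments forwarded to it in this file. Not shown in this diff.
    class CheckpointerBase:
        def __init__(
            self,
            train_dataloader,
            val_dataloader,
            output_dir: Path,
            placeholder_tokens,
            placeholder_token_ids,
            sample_image_size: int,
            sample_batch_size: int,
            sample_batches: int,
            seed=None,
        ):
            self.train_dataloader = train_dataloader
            self.val_dataloader = val_dataloader
            self.output_dir = output_dir
            self.placeholder_tokens = placeholder_tokens
            self.placeholder_token_ids = placeholder_token_ids
            self.sample_image_size = sample_image_size
            self.sample_batch_size = sample_batch_size
            self.sample_batches = sample_batches
            # The old subclass defaulted a missing seed to torch.random.seed()
            self.seed = seed if seed is not None else torch.random.seed()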
@@ -606,28 +589,19 @@ def main():
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
-    args.seed = args.seed or (torch.random.seed() >> 32)
+    if args.seed is None:
+        args.seed = torch.random.seed() >> 32
+
     set_seed(args.seed)
 
     save_args(basepath, args)
 
-    # Load the tokenizer and add the placeholder token as a additional special token
-    if args.tokenizer_name:
-        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
-    elif args.pretrained_model_name_or_path:
-        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
+        args.pretrained_model_name_or_path)
+
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
 
-    # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
-    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder='scheduler')
-    ema_unet = None
-
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
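Note: the new `get_models` helper lives in `training/common.py`, which is not part of this diff. Reconstructed from the loader calls it replaces above (with `checkpoint_scheduler` renamed to `sample_scheduler` at its call sites) and from `patch_managed_embeddings`, which disappears from this file below, a minimal sketch, assuming the helper simply bundles those calls:

    # Hypothetical sketch of training.common.get_models, rebuilt from the
    # from_pretrained calls this commit deletes from train_dreambooth.py.
    from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
    from transformers import CLIPTextModel

    from models.clip.embeddings import patch_managed_embeddings
    from models.clip.tokenizer import MultiCLIPTokenizer

    def get_models(pretrained_model_name_or_path: str):
        tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
        text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
        unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
        noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
        sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
            pretrained_model_name_or_path, subfolder='scheduler')
        embeddings = patch_managed_embeddings(text_encoder)
        return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings

One side effect visible in the hunk above: the old `--tokenizer_name` override for loading the tokenizer separately is dropped, since `get_models` is called with only the pretrained model path.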
@@ -636,16 +610,6 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
-    if args.use_ema:
-        ema_unet = EMAModel(
-            unet.parameters(),
-            inv_gamma=args.ema_inv_gamma,
-            power=args.ema_power,
-            max_value=args.ema_max_decay,
-        )
-
-    embeddings = patch_managed_embeddings(text_encoder)
-
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
@@ -654,24 +618,26 @@
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
-    if len(args.placeholder_token) != 0:
-        # Convert the initializer_token, placeholder_token to ids
-        initializer_token_ids = [
-            tokenizer.encode(token, add_special_tokens=False)
-            for token in args.initializer_token
-        ]
-
-        new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
-        embeddings.resize(len(tokenizer))
-
-        for (new_id, init_ids) in zip(new_ids, initializer_token_ids):
-            embeddings.add_embed(new_id, init_ids)
-
-        init_ratios = [f"{len(init_ids)} / {len(new_id)}" for new_id, init_ids in zip(new_ids, initializer_token_ids)]
-
-        print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}")
+    placeholder_token_ids = add_placeholder_tokens(
+        tokenizer=tokenizer,
+        embeddings=embeddings,
+        placeholder_tokens=args.placeholder_tokens,
+        initializer_tokens=args.initializer_tokens,
+        num_vectors=args.num_vectors
+    )
+
+    if len(placeholder_token_ids) != 0:
+        print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
+
+    if args.use_ema:
+        ema_unet = EMAModel(
+            unet.parameters(),
+            inv_gamma=args.ema_inv_gamma,
+            power=args.ema_power,
+            max_value=args.ema_max_decay,
+        )
     else:
-        placeholder_token_id = []
+        ema_unet = None
 
     vae.requires_grad_(False)
 
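Note: `add_placeholder_tokens` is likewise defined in `training/common.py` and not shown here, but the inline code it replaces spells out its job. A sketch, assuming the helper is a straight extraction of the deleted block:

    # Hypothetical sketch of training.common.add_placeholder_tokens, based on
    # the inline token-setup code this commit removes. Not part of this diff.
    def add_placeholder_tokens(tokenizer, embeddings, placeholder_tokens, initializer_tokens, num_vectors):
        initializer_token_ids = [
            tokenizer.encode(token, add_special_tokens=False)
            for token in initializer_tokens
        ]
        # One placeholder may span several embedding vectors, so each entry of
        # placeholder_token_ids is itself a list of token ids.
        placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
        embeddings.resize(len(tokenizer))

        for (placeholder_id, initializer_ids) in zip(placeholder_token_ids, initializer_token_ids):
            embeddings.add_embed(placeholder_id, initializer_ids)

        return placeholder_token_ids

One wrinkle: the call site passes `args.initializer_tokens`, while the `parse_args` hunks above still normalize `args.initializer_token` (singular). Unless the argument is renamed in a part of the file this diff does not show, that attribute will not exist at runtime.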
@@ -765,8 +731,6 @@ def main():
         filter=keyword_filter,
         dtype=weight_dtype
     )
-
-    datamodule.prepare_data()
     datamodule.setup()
 
     train_dataloader = datamodule.train_dataloader
@@ -779,7 +743,7 @@ def main():
         vae,
         unet,
         tokenizer,
-        checkpoint_scheduler,
+        sample_scheduler,
         datamodule.data_train,
         args.sample_batch_size,
         args.sample_image_size,
@@ -808,12 +772,8 @@ def main():
         unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
-    # Move text_encoder and vae to device
     vae.to(accelerator.device, dtype=weight_dtype)
 
-    # Keep text_encoder and vae in eval mode as we don't train these
-    vae.eval()
-
     if args.use_ema:
         ema_unet.to(accelerator.device)
 
@@ -877,17 +837,18 @@ def main():
 
     checkpointer = Checkpointer(
         weight_dtype=weight_dtype,
-        datamodule=datamodule,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
         accelerator=accelerator,
         vae=vae,
         unet=unet,
         ema_unet=ema_unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
-        scheduler=checkpoint_scheduler,
+        scheduler=sample_scheduler,
+        placeholder_tokens=args.placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
         output_dir=basepath,
-        placeholder_token=args.placeholder_token,
-        placeholder_token_id=placeholder_token_id,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
@@ -895,23 +856,16 @@ def main():
     )
 
     if accelerator.is_main_process:
-        config = vars(args).copy()
-        config["initializer_token"] = " ".join(config["initializer_token"])
-        config["placeholder_token"] = " ".join(config["placeholder_token"])
-        if config["collection"] is not None:
-            config["collection"] = " ".join(config["collection"])
-        if config["exclude_collections"] is not None:
-            config["exclude_collections"] = " ".join(config["exclude_collections"])
         accelerator.init_trackers("dreambooth", config=config)
 
     if args.find_lr:
         lr_finder = LRFinder(
-            accelerator,
-            text_encoder,
-            optimizer,
-            train_dataloader,
-            val_dataloader,
-            loss_step_,
+            accelerator=accelerator,
+            optimizer=optimizer,
+            model=unet,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loss_step_,
             on_train=on_train,
             on_eval=on_eval,
             on_before_optimize=on_before_optimize,
@@ -921,29 +875,26 @@ def main():
 
         plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()
-
-        return
-
-    train_loop(
-        accelerator=accelerator,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-        model=unet,
-        checkpointer=checkpointer,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        loss_step=loss_step_,
-        sample_frequency=args.sample_frequency,
-        sample_steps=args.sample_steps,
-        checkpoint_frequency=args.checkpoint_frequency,
-        global_step_offset=0,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        num_epochs=args.num_train_epochs,
-        on_log=on_log,
-        on_train=on_train,
-        on_after_optimize=on_after_optimize,
-        on_eval=on_eval
-    )
+    else:
+        train_loop(
+            accelerator=accelerator,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            model=unet,
+            checkpointer=checkpointer,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loss_step_,
+            sample_frequency=args.sample_frequency,
+            sample_steps=args.sample_steps,
+            checkpoint_frequency=args.checkpoint_frequency,
+            global_step_offset=0,
+            num_epochs=args.num_train_epochs,
+            on_log=on_log,
+            on_train=on_train,
+            on_after_optimize=on_after_optimize,
+            on_eval=on_eval
+        )
 
 
 if __name__ == "__main__":
diff --git a/train_ti.py b/train_ti.py
index 61195f6..d2ca7eb 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -492,7 +492,7 @@ def parse_args():
 class Checkpointer(CheckpointerBase):
     def __init__(
         self,
-        weight_dtype,
+        weight_dtype: torch.dtype,
         accelerator: Accelerator,
         vae: AutoencoderKL,
         unet: UNet2DConditionModel,
@@ -587,7 +587,9 @@ def main():
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
-    args.seed = args.seed or (torch.random.seed() >> 32)
+    if args.seed is None:
+        args.seed = torch.random.seed() >> 32
+
     set_seed(args.seed)
 
     save_args(basepath, args)
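Both scripts replace the `args.seed or ...` idiom with an explicit `None` check; besides reading better, this stops a deliberately chosen seed of 0 (which is falsy) from being silently replaced with a random one. The `>> 32` is there because `torch.random.seed()` returns a 64-bit value, while accelerate's `set_seed` also seeds NumPy, which only accepts 32-bit seeds. A small illustration:

    # Why the right-shift: NumPy rejects seeds at or above 2**32.
    import numpy as np
    import torch

    seed64 = torch.random.seed()    # 64-bit value, e.g. 13380434573329468280
    np.random.seed(seed64 >> 32)    # the top 32 bits always fit NumPy's range
    # np.random.seed(seed64) would raise ValueError for most 64-bit seeds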
@@ -622,7 +624,8 @@ def main():
         num_vectors=args.num_vectors
     )
 
-    print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
+    if len(placeholder_token_ids) != 0:
+        print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
 
     if args.use_ema:
         ema_embeddings = EMAModel(