summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-13 23:09:14 +0100
committerVolpeon <git@volpeon.ink>2023-01-13 23:09:14 +0100
commit6ecfdb73d150c5a596722ec3234e53f4796fbc78 (patch)
tree797bc01768f71a74f944bf1bf18e9bf62665ee4e /train_dreambooth.py
parentReverted modularization mostly (diff)
downloadtextual-inversion-diff-6ecfdb73d150c5a596722ec3234e53f4796fbc78.tar.gz
textual-inversion-diff-6ecfdb73d150c5a596722ec3234e53f4796fbc78.tar.bz2
textual-inversion-diff-6ecfdb73d150c5a596722ec3234e53f4796fbc78.zip
Unified training script structure
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py205
1 files changed, 78 insertions, 127 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 2145e2b..a1802a0 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -12,20 +12,18 @@ import torch.utils.checkpoint
12from accelerate import Accelerator 12from accelerate import Accelerator
13from accelerate.logging import get_logger 13from accelerate.logging import get_logger
14from accelerate.utils import LoggerType, set_seed 14from accelerate.utils import LoggerType, set_seed
15from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel 15from diffusers import AutoencoderKL, UNet2DConditionModel
16import matplotlib.pyplot as plt 16import matplotlib.pyplot as plt
17from diffusers.training_utils import EMAModel
18from transformers import CLIPTextModel 17from transformers import CLIPTextModel
19from slugify import slugify 18from slugify import slugify
20 19
21from util import load_config, load_embeddings_from_dir 20from util import load_config, load_embeddings_from_dir
22from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 21from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
23from data.csv import VlpnDataModule, VlpnDataItem 22from data.csv import VlpnDataModule, VlpnDataItem
24from training.common import loss_step, train_loop, generate_class_images 23from training.common import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
25from training.optimization import get_scheduler 24from training.optimization import get_scheduler
26from training.lr import LRFinder 25from training.lr import LRFinder
27from training.util import CheckpointerBase, save_args 26from training.util import CheckpointerBase, EMAModel, save_args
28from models.clip.embeddings import patch_managed_embeddings
29from models.clip.tokenizer import MultiCLIPTokenizer 27from models.clip.tokenizer import MultiCLIPTokenizer
30 28
31logger = get_logger(__name__) 29logger = get_logger(__name__)
@@ -69,7 +67,7 @@ def parse_args():
69 help="The name of the current project.", 67 help="The name of the current project.",
70 ) 68 )
71 parser.add_argument( 69 parser.add_argument(
72 "--placeholder_token", 70 "--placeholder_tokens",
73 type=str, 71 type=str,
74 nargs='*', 72 nargs='*',
75 default=[], 73 default=[],
@@ -446,20 +444,20 @@ def parse_args():
446 if args.project is None: 444 if args.project is None:
447 raise ValueError("You must specify --project") 445 raise ValueError("You must specify --project")
448 446
449 if isinstance(args.placeholder_token, str): 447 if isinstance(args.placeholder_tokens, str):
450 args.placeholder_token = [args.placeholder_token] 448 args.placeholder_tokens = [args.placeholder_tokens]
451 449
452 if len(args.placeholder_token) == 0: 450 if len(args.placeholder_tokens) == 0:
453 args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)] 451 args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_token)]
454 452
455 if isinstance(args.initializer_token, str): 453 if isinstance(args.initializer_token, str):
456 args.initializer_token = [args.initializer_token] * len(args.placeholder_token) 454 args.initializer_token = [args.initializer_token] * len(args.placeholder_tokens)
457 455
458 if len(args.initializer_token) == 0: 456 if len(args.initializer_token) == 0:
459 raise ValueError("You must specify --initializer_token") 457 raise ValueError("You must specify --initializer_token")
460 458
461 if len(args.placeholder_token) != len(args.initializer_token): 459 if len(args.placeholder_tokens) != len(args.initializer_token):
462 raise ValueError("--placeholder_token and --initializer_token must have the same number of items") 460 raise ValueError("--placeholder_tokens and --initializer_token must have the same number of items")
463 461
464 if args.num_vectors is None: 462 if args.num_vectors is None:
465 args.num_vectors = 1 463 args.num_vectors = 1
@@ -467,8 +465,8 @@ def parse_args():
467 if isinstance(args.num_vectors, int): 465 if isinstance(args.num_vectors, int):
468 args.num_vectors = [args.num_vectors] * len(args.initializer_token) 466 args.num_vectors = [args.num_vectors] * len(args.initializer_token)
469 467
470 if len(args.placeholder_token) != len(args.num_vectors): 468 if len(args.placeholder_tokens) != len(args.num_vectors):
471 raise ValueError("--placeholder_token and --num_vectors must have the same number of items") 469 raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
472 470
473 if isinstance(args.collection, str): 471 if isinstance(args.collection, str):
474 args.collection = [args.collection] 472 args.collection = [args.collection]
@@ -485,33 +483,18 @@ def parse_args():
485class Checkpointer(CheckpointerBase): 483class Checkpointer(CheckpointerBase):
486 def __init__( 484 def __init__(
487 self, 485 self,
488 weight_dtype, 486 weight_dtype: torch.dtype,
489 datamodule, 487 accelerator: Accelerator,
490 accelerator, 488 vae: AutoencoderKL,
491 vae, 489 unet: UNet2DConditionModel,
492 unet, 490 ema_unet: EMAModel,
493 ema_unet, 491 tokenizer: MultiCLIPTokenizer,
494 tokenizer, 492 text_encoder: CLIPTextModel,
495 text_encoder,
496 scheduler, 493 scheduler,
497 output_dir: Path, 494 *args,
498 placeholder_token, 495 **kwargs
499 placeholder_token_id,
500 sample_image_size,
501 sample_batches,
502 sample_batch_size,
503 seed,
504 ): 496 ):
505 super().__init__( 497 super().__init__(*args, **kwargs)
506 datamodule=datamodule,
507 output_dir=output_dir,
508 placeholder_token=placeholder_token,
509 placeholder_token_id=placeholder_token_id,
510 sample_image_size=sample_image_size,
511 seed=seed or torch.random.seed(),
512 sample_batches=sample_batches,
513 sample_batch_size=sample_batch_size
514 )
515 498
516 self.weight_dtype = weight_dtype 499 self.weight_dtype = weight_dtype
517 self.accelerator = accelerator 500 self.accelerator = accelerator
@@ -606,28 +589,19 @@ def main():
606 589
607 logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) 590 logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
608 591
609 args.seed = args.seed or (torch.random.seed() >> 32) 592 if args.seed is None:
593 args.seed = torch.random.seed() >> 32
594
610 set_seed(args.seed) 595 set_seed(args.seed)
611 596
612 save_args(basepath, args) 597 save_args(basepath, args)
613 598
614 # Load the tokenizer and add the placeholder token as a additional special token 599 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
615 if args.tokenizer_name: 600 args.pretrained_model_name_or_path)
616 tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) 601
617 elif args.pretrained_model_name_or_path:
618 tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
619 tokenizer.set_use_vector_shuffle(args.vector_shuffle) 602 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
620 tokenizer.set_dropout(args.vector_dropout) 603 tokenizer.set_dropout(args.vector_dropout)
621 604
622 # Load models and create wrapper for stable diffusion
623 text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
624 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
625 unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
626 noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
627 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
628 args.pretrained_model_name_or_path, subfolder='scheduler')
629 ema_unet = None
630
631 vae.enable_slicing() 605 vae.enable_slicing()
632 vae.set_use_memory_efficient_attention_xformers(True) 606 vae.set_use_memory_efficient_attention_xformers(True)
633 unet.set_use_memory_efficient_attention_xformers(True) 607 unet.set_use_memory_efficient_attention_xformers(True)
@@ -636,16 +610,6 @@ def main():
636 unet.enable_gradient_checkpointing() 610 unet.enable_gradient_checkpointing()
637 text_encoder.gradient_checkpointing_enable() 611 text_encoder.gradient_checkpointing_enable()
638 612
639 if args.use_ema:
640 ema_unet = EMAModel(
641 unet.parameters(),
642 inv_gamma=args.ema_inv_gamma,
643 power=args.ema_power,
644 max_value=args.ema_max_decay,
645 )
646
647 embeddings = patch_managed_embeddings(text_encoder)
648
649 if args.embeddings_dir is not None: 613 if args.embeddings_dir is not None:
650 embeddings_dir = Path(args.embeddings_dir) 614 embeddings_dir = Path(args.embeddings_dir)
651 if not embeddings_dir.exists() or not embeddings_dir.is_dir(): 615 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
@@ -654,24 +618,26 @@ def main():
654 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) 618 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
655 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") 619 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
656 620
657 if len(args.placeholder_token) != 0: 621 placeholder_token_ids = add_placeholder_tokens(
658 # Convert the initializer_token, placeholder_token to ids 622 tokenizer=tokenizer,
659 initializer_token_ids = [ 623 embeddings=embeddings,
660 tokenizer.encode(token, add_special_tokens=False) 624 placeholder_tokens=args.placeholder_tokens,
661 for token in args.initializer_token 625 initializer_tokens=args.initializer_tokens,
662 ] 626 num_vectors=args.num_vectors
663 627 )
664 new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
665 embeddings.resize(len(tokenizer))
666
667 for (new_id, init_ids) in zip(new_ids, initializer_token_ids):
668 embeddings.add_embed(new_id, init_ids)
669 628
670 init_ratios = [f"{len(init_ids)} / {len(new_id)}" for new_id, init_ids in zip(new_ids, initializer_token_ids)] 629 if len(placeholder_token_ids) != 0:
630 print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
671 631
672 print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}") 632 if args.use_ema:
633 ema_unet = EMAModel(
634 unet.parameters(),
635 inv_gamma=args.ema_inv_gamma,
636 power=args.ema_power,
637 max_value=args.ema_max_decay,
638 )
673 else: 639 else:
674 placeholder_token_id = [] 640 ema_unet = None
675 641
676 vae.requires_grad_(False) 642 vae.requires_grad_(False)
677 643
@@ -765,8 +731,6 @@ def main():
765 filter=keyword_filter, 731 filter=keyword_filter,
766 dtype=weight_dtype 732 dtype=weight_dtype
767 ) 733 )
768
769 datamodule.prepare_data()
770 datamodule.setup() 734 datamodule.setup()
771 735
772 train_dataloader = datamodule.train_dataloader 736 train_dataloader = datamodule.train_dataloader
@@ -779,7 +743,7 @@ def main():
779 vae, 743 vae,
780 unet, 744 unet,
781 tokenizer, 745 tokenizer,
782 checkpoint_scheduler, 746 sample_scheduler,
783 datamodule.data_train, 747 datamodule.data_train,
784 args.sample_batch_size, 748 args.sample_batch_size,
785 args.sample_image_size, 749 args.sample_image_size,
@@ -808,12 +772,8 @@ def main():
808 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler 772 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
809 ) 773 )
810 774
811 # Move text_encoder and vae to device
812 vae.to(accelerator.device, dtype=weight_dtype) 775 vae.to(accelerator.device, dtype=weight_dtype)
813 776
814 # Keep text_encoder and vae in eval mode as we don't train these
815 vae.eval()
816
817 if args.use_ema: 777 if args.use_ema:
818 ema_unet.to(accelerator.device) 778 ema_unet.to(accelerator.device)
819 779
@@ -877,17 +837,18 @@ def main():
877 837
878 checkpointer = Checkpointer( 838 checkpointer = Checkpointer(
879 weight_dtype=weight_dtype, 839 weight_dtype=weight_dtype,
880 datamodule=datamodule, 840 train_dataloader=train_dataloader,
841 val_dataloader=val_dataloader,
881 accelerator=accelerator, 842 accelerator=accelerator,
882 vae=vae, 843 vae=vae,
883 unet=unet, 844 unet=unet,
884 ema_unet=ema_unet, 845 ema_unet=ema_unet,
885 tokenizer=tokenizer, 846 tokenizer=tokenizer,
886 text_encoder=text_encoder, 847 text_encoder=text_encoder,
887 scheduler=checkpoint_scheduler, 848 scheduler=sample_scheduler,
849 placeholder_tokens=args.placeholder_tokens,
850 placeholder_token_ids=placeholder_token_ids,
888 output_dir=basepath, 851 output_dir=basepath,
889 placeholder_token=args.placeholder_token,
890 placeholder_token_id=placeholder_token_id,
891 sample_image_size=args.sample_image_size, 852 sample_image_size=args.sample_image_size,
892 sample_batch_size=args.sample_batch_size, 853 sample_batch_size=args.sample_batch_size,
893 sample_batches=args.sample_batches, 854 sample_batches=args.sample_batches,
@@ -895,23 +856,16 @@ def main():
895 ) 856 )
896 857
897 if accelerator.is_main_process: 858 if accelerator.is_main_process:
898 config = vars(args).copy()
899 config["initializer_token"] = " ".join(config["initializer_token"])
900 config["placeholder_token"] = " ".join(config["placeholder_token"])
901 if config["collection"] is not None:
902 config["collection"] = " ".join(config["collection"])
903 if config["exclude_collections"] is not None:
904 config["exclude_collections"] = " ".join(config["exclude_collections"])
905 accelerator.init_trackers("dreambooth", config=config) 859 accelerator.init_trackers("dreambooth", config=config)
906 860
907 if args.find_lr: 861 if args.find_lr:
908 lr_finder = LRFinder( 862 lr_finder = LRFinder(
909 accelerator, 863 accelerator=accelerator,
910 text_encoder, 864 optimizer=optimizer,
911 optimizer, 865 model=unet,
912 train_dataloader, 866 train_dataloader=train_dataloader,
913 val_dataloader, 867 val_dataloader=val_dataloader,
914 loss_step_, 868 loss_step=loss_step_,
915 on_train=on_train, 869 on_train=on_train,
916 on_eval=on_eval, 870 on_eval=on_eval,
917 on_before_optimize=on_before_optimize, 871 on_before_optimize=on_before_optimize,
@@ -921,29 +875,26 @@ def main():
921 875
922 plt.savefig(basepath.joinpath("lr.png"), dpi=300) 876 plt.savefig(basepath.joinpath("lr.png"), dpi=300)
923 plt.close() 877 plt.close()
924 878 else:
925 return 879 train_loop(
926 880 accelerator=accelerator,
927 train_loop( 881 optimizer=optimizer,
928 accelerator=accelerator, 882 lr_scheduler=lr_scheduler,
929 optimizer=optimizer, 883 model=unet,
930 lr_scheduler=lr_scheduler, 884 checkpointer=checkpointer,
931 model=unet, 885 train_dataloader=train_dataloader,
932 checkpointer=checkpointer, 886 val_dataloader=val_dataloader,
933 train_dataloader=train_dataloader, 887 loss_step=loss_step_,
934 val_dataloader=val_dataloader, 888 sample_frequency=args.sample_frequency,
935 loss_step=loss_step_, 889 sample_steps=args.sample_steps,
936 sample_frequency=args.sample_frequency, 890 checkpoint_frequency=args.checkpoint_frequency,
937 sample_steps=args.sample_steps, 891 global_step_offset=0,
938 checkpoint_frequency=args.checkpoint_frequency, 892 num_epochs=args.num_train_epochs,
939 global_step_offset=0, 893 on_log=on_log,
940 gradient_accumulation_steps=args.gradient_accumulation_steps, 894 on_train=on_train,
941 num_epochs=args.num_train_epochs, 895 on_after_optimize=on_after_optimize,
942 on_log=on_log, 896 on_eval=on_eval
943 on_train=on_train, 897 )
944 on_after_optimize=on_after_optimize,
945 on_eval=on_eval
946 )
947 898
948 899
949if __name__ == "__main__": 900if __name__ == "__main__":