-rw-r--r--   infer.py                    |   1
-rw-r--r--   models/clip/embeddings.py   |   7
-rw-r--r--   models/sparse.py            |  13
-rw-r--r--   train_dreambooth.py         |   1
-rw-r--r--   train_lora.py               | 140
-rw-r--r--   train_ti.py                 |  66
-rw-r--r--   training/functional.py      |   4
7 files changed, 72 insertions, 160 deletions
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -235,6 +235,7 @@ def load_embeddings(pipeline, embeddings_dir):
         pipeline.text_encoder.text_model.embeddings,
         Path(embeddings_dir)
     )
+    pipeline.text_encoder.text_model.embeddings.persist()
     print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
 
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 6fda33c..dc4708a 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -37,7 +37,7 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0):
         super().__init__(config)
 
         self.token_embedding = embeddings.token_embedding
@@ -46,6 +46,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         self.token_override_embedding = PseudoSparseEmbedding(
             self.token_embedding.embedding_dim,
+            dropout_p=dropout_p,
             device=self.token_embedding.weight.device,
             dtype=self.token_embedding.weight.dtype,
         )
@@ -134,7 +135,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         return embeddings
 
 
-def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
-    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
+def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings:
+    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p)
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
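Note: with the extended signature above, callers can now opt into embedding dropout when patching the text encoder. A minimal usage sketch, not part of this commit; the model path and the 0.1 value are placeholders:

from transformers import CLIPTextModel

from models.clip.embeddings import patch_managed_embeddings

# Placeholder model path; any CLIP text encoder with a text_model.embeddings
# attribute is patched the same way.
text_encoder = CLIPTextModel.from_pretrained("path/to/model", subfolder="text_encoder")

# dropout_p defaults to 0.0 (no dropout); 0.1 here is only illustrative.
embeddings = patch_managed_embeddings(text_encoder, dropout_p=0.1)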
diff --git a/models/sparse.py b/models/sparse.py
index d706db5..bcb2897 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -5,22 +5,29 @@ import torch.nn as nn
 
 
 class PseudoSparseEmbedding(nn.Module):
-    def __init__(self, embedding_dim: int, device=None, dtype=torch.float32):
+    def __init__(self, embedding_dim: int, dropout_p: float = 0.0, device=None, dtype=torch.float32):
         super().__init__()
 
         self.embedding_dim = embedding_dim
         self.dtype = dtype
         self.params = nn.ParameterList()
+
+        if dropout_p > 0.0:
+            self.dropout = nn.Dropout(p=dropout_p)
+        else:
+            self.dropout = lambda x: x
+
         self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long))
 
     def forward(self, input_ids: torch.LongTensor):
-        ids = self.mapping[input_ids.to(self.mapping.device)]
+        input_ids = input_ids.to(self.mapping.device)
+        ids = self.mapping[input_ids]
         mask = ~(ids == -1)
 
         if torch.all(~mask):
             embs = None
         else:
-            embs = torch.stack([self.params[id] for id in ids[mask]])
+            embs = self.dropout(torch.stack([self.params[id] for id in ids[mask]]))
 
         return embs, mask
 
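A minimal, self-contained sketch of the dropout pattern introduced above (an illustrative stand-in, not the repository's class): dropout is applied only to the stacked override vectors, and because nn.Dropout is a registered submodule it is disabled automatically in eval mode, whereas the lambda fallback is a plain identity that never drops anything.

import torch
import torch.nn as nn

class TinySparseOverride(nn.Module):
    """Illustrative stand-in for PseudoSparseEmbedding's dropout handling."""
    def __init__(self, embedding_dim: int, dropout_p: float = 0.0):
        super().__init__()
        self.params = nn.ParameterList()
        # Same pattern as the commit: a real Dropout module only when p > 0.
        self.dropout = nn.Dropout(p=dropout_p) if dropout_p > 0.0 else (lambda x: x)

    def forward(self, ids: torch.LongTensor):
        if len(self.params) == 0:
            return None
        return self.dropout(torch.stack([self.params[i] for i in ids]))

m = TinySparseOverride(4, dropout_p=0.5)
m.params.append(nn.Parameter(torch.ones(4)))
m.eval()  # nn.Dropout becomes a no-op here; the identity fallback always is
print(m(torch.tensor([0])))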
diff --git a/train_dreambooth.py b/train_dreambooth.py
index f4d4cbb..2aca1e7 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -513,6 +513,7 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        embeddings.persist()
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     if args.scale_lr:
diff --git a/train_lora.py b/train_lora.py
index 8dbe45b..6e21634 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -159,12 +159,6 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
-        "--pti_tag_dropout",
-        type=float,
-        default=0,
-        help="Tag dropout probability.",
-    )
-    parser.add_argument(
         "--no_tag_shuffle",
         action="store_true",
         help="Shuffle tags.",
@@ -236,28 +230,12 @@ def parse_args():
         default=2000
     )
     parser.add_argument(
-        "--num_pti_epochs",
-        type=int,
-        default=None
-    )
-    parser.add_argument(
-        "--num_pti_steps",
-        type=int,
-        default=500
-    )
-    parser.add_argument(
         "--gradient_accumulation_steps",
         type=int,
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument(
-        "--pti_gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass.",
-    )
-    parser.add_argument(
         "--lora_r",
         type=int,
         default=8,
@@ -323,12 +301,6 @@ def parse_args():
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
-        "--learning_rate_pti",
-        type=float,
-        default=1e-4,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
-    parser.add_argument(
         "--learning_rate_emb",
         type=float,
         default=1e-5,
@@ -467,12 +439,6 @@ def parse_args():
         help="How often to save a checkpoint and sample image",
     )
     parser.add_argument(
-        "--pti_sample_frequency",
-        type=int,
-        default=1,
-        help="How often to save a checkpoint and sample image",
-    )
-    parser.add_argument(
         "--sample_image_size",
         type=int,
         default=768,
@@ -509,12 +475,6 @@ def parse_args():
         help="Batch size (per device) for the training dataloader."
     )
     parser.add_argument(
-        "--pti_batch_size",
-        type=int,
-        default=1,
-        help="Batch size (per device) for the training dataloader."
-    )
-    parser.add_argument(
         "--sample_steps",
         type=int,
         default=10,
@@ -527,6 +487,12 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--emb_dropout",
+        type=float,
+        default=0,
+        help="Embedding dropout probability.",
+    )
+    parser.add_argument(
         "--use_emb_decay",
         action="store_true",
         help="Whether to use embedding decay."
@@ -674,7 +640,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, args.emb_dropout)
 
     unet_config = LoraConfig(
         r=args.lora_r,
@@ -720,6 +686,7 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        embeddings.persist()
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
@@ -744,19 +711,14 @@ def main():
             args.learning_rate_text * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
-        args.learning_rate_pti = (
-            args.learning_rate_pti * args.pti_gradient_accumulation_steps *
-            args.pti_batch_size * accelerator.num_processes
-        )
         args.learning_rate_emb = (
-            args.learning_rate_emb * args.pti_gradient_accumulation_steps *
-            args.pti_batch_size * accelerator.num_processes
+            args.learning_rate_emb * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
         )
 
     if args.find_lr:
         args.learning_rate_unet = 1e-6
         args.learning_rate_text = 1e-6
-        args.learning_rate_pti = 1e-6
         args.learning_rate_emb = 1e-6
         args.lr_scheduler = "exponential_growth"
 
@@ -817,7 +779,6 @@ def main():
         args.lr_min_lr = args.learning_rate_unet
         args.learning_rate_unet = None
         args.learning_rate_text = None
-        args.learning_rate_pti = None
         args.learning_rate_emb = None
     elif args.optimizer == 'dadam':
         try:
@@ -836,7 +797,6 @@ def main():
 
         args.learning_rate_unet = 1.0
         args.learning_rate_text = 1.0
-        args.learning_rate_pti = 1.0
         args.learning_rate_emb = 1.0
     elif args.optimizer == 'dadan':
         try:
@@ -853,7 +813,6 @@ def main():
 
         args.learning_rate_unet = 1.0
         args.learning_rate_text = 1.0
-        args.learning_rate_pti = 1.0
         args.learning_rate_emb = 1.0
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
@@ -920,80 +879,6 @@ def main():
         mid_point=args.lr_mid_point,
     )
 
-    # PTI
-    # --------------------------------------------------------------------------------
-
-    if len(args.placeholder_tokens) != 0:
-        pti_datamodule = create_datamodule(
-            batch_size=args.pti_batch_size,
-            dropout=args.pti_tag_dropout,
-            filter=partial(keyword_filter, args.filter_tokens, args.collection, args.exclude_collections),
-        )
-        pti_datamodule.setup()
-
-        num_pti_epochs = args.num_pti_epochs
-        pti_sample_frequency = args.pti_sample_frequency
-        if num_pti_epochs is None:
-            num_pti_epochs = math.ceil(
-                args.num_pti_steps / len(pti_datamodule.train_dataset)
-            ) * args.pti_gradient_accumulation_steps
-            pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps))
-
-        if num_pti_epochs > 0:
-            pti_optimizer = create_optimizer(
-                [
-                    {
-                        "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-                        "lr": args.learning_rate_pti,
-                        "weight_decay": 0,
-                    },
-                ]
-            )
-
-            pti_lr_scheduler = create_lr_scheduler(
-                gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-                optimizer=pti_optimizer,
-                num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
-                train_epochs=num_pti_epochs,
-            )
-
-            continue_training = True
-            training_iter = 1
-
-            while continue_training:
-                print("")
-                print(f"============ PTI cycle {training_iter} ============")
-                print("")
-
-                pti_project = f"pti_{training_iter}"
-                pti_output_dir = output_dir / pti_project
-                pti_checkpoint_output_dir = pti_output_dir / "model"
-                pti_sample_output_dir = pti_output_dir / "samples"
-
-                trainer(
-                    strategy=lora_strategy,
-                    pti_mode=True,
-                    project=pti_project,
-                    train_dataloader=pti_datamodule.train_dataloader,
-                    val_dataloader=pti_datamodule.val_dataloader,
-                    optimizer=pti_optimizer,
-                    lr_scheduler=pti_lr_scheduler,
-                    num_train_epochs=num_pti_epochs,
-                    gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-                    # --
-                    group_labels=["emb"],
-                    sample_output_dir=pti_sample_output_dir,
-                    checkpoint_output_dir=pti_checkpoint_output_dir,
-                    sample_frequency=pti_sample_frequency,
-                )
-
-                response = input("Run another cycle? [y/n] ")
-                continue_training = response.lower().strip() != "n"
-                training_iter += 1
-
-        if not args.train_emb:
-            embeddings.persist()
-
     # LORA
     # --------------------------------------------------------------------------------
 
@@ -1062,9 +947,8 @@ def main():
         print("")
 
         lora_project = f"lora_{training_iter}"
-        lora_output_dir = output_dir / lora_project
-        lora_checkpoint_output_dir = lora_output_dir / "model"
-        lora_sample_output_dir = lora_output_dir / "samples"
+        lora_checkpoint_output_dir = output_dir / lora_project / "model"
+        lora_sample_output_dir = output_dir / lora_project / "samples"
 
         trainer(
             strategy=lora_strategy,
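With the PTI-specific hyperparameters removed, the embedding learning rate is scaled with the same accumulation and batch settings as the other parameter groups. A worked example of the consolidated --scale_lr arithmetic; only the 1e-5 default comes from the script, the other values are placeholders:

# Mirrors the consolidated --scale_lr behaviour in train_lora.py.
learning_rate_emb = 1e-5          # script default for --learning_rate_emb
gradient_accumulation_steps = 4   # placeholder value
train_batch_size = 2              # placeholder value
num_processes = 1                 # accelerator.num_processes on a single GPU

learning_rate_emb = (
    learning_rate_emb * gradient_accumulation_steps *
    train_batch_size * num_processes
)
print(learning_rate_emb)  # 8e-05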
diff --git a/train_ti.py b/train_ti.py
index daf8bc5..2d51800 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -458,6 +458,12 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--emb_dropout",
+        type=float,
+        default=0,
+        help="Embedding dropout probability.",
+    )
+    parser.add_argument(
         "--use_emb_decay",
         action="store_true",
         help="Whether to use embedding decay."
@@ -624,7 +630,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, args.emb_dropout)
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -755,8 +761,6 @@ def main():
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
-    checkpoint_output_dir = output_dir / "checkpoints"
-
     trainer = partial(
         train,
         accelerator=accelerator,
@@ -777,7 +781,6 @@ def main():
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
         # --
-        checkpoint_output_dir=checkpoint_output_dir,
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
         emb_decay=args.emb_decay,
@@ -793,11 +796,6 @@ def main():
     )
 
     def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
-        if len(placeholder_tokens) == 1:
-            sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}"
-        else:
-            sample_output_dir = output_dir / "samples"
-
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
@@ -809,7 +807,11 @@ def main():
 
         stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids))
 
-        print(f"{i + 1}: {stats}")
+        print("")
+        print(f"============ TI batch {i + 1} ============")
+        print("")
+        print(stats)
+        print("")
 
         filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
 
@@ -868,20 +870,36 @@ def main():
             mid_point=args.lr_mid_point,
         )
 
-        trainer(
-            project="textual_inversion",
-            train_dataloader=datamodule.train_dataloader,
-            val_dataloader=datamodule.val_dataloader,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            num_train_epochs=num_train_epochs,
-            # --
-            group_labels=["emb"],
-            sample_output_dir=sample_output_dir,
-            sample_frequency=sample_frequency,
-            placeholder_tokens=placeholder_tokens,
-            placeholder_token_ids=placeholder_token_ids,
-        )
+        continue_training = True
+        training_iter = 1
+
+        while continue_training:
+            print(f"------------ TI cycle {training_iter} ------------")
+            print("")
+
+            project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}"
+            sample_output_dir = output_dir / project / "samples"
+            checkpoint_output_dir = output_dir / project / "checkpoints"
+
+            trainer(
+                project=project,
+                train_dataloader=datamodule.train_dataloader,
+                val_dataloader=datamodule.val_dataloader,
+                optimizer=optimizer,
+                lr_scheduler=lr_scheduler,
+                num_train_epochs=num_train_epochs,
+                # --
+                group_labels=["emb"],
+                checkpoint_output_dir=checkpoint_output_dir,
+                sample_output_dir=sample_output_dir,
+                sample_frequency=sample_frequency,
+                placeholder_tokens=placeholder_tokens,
+                placeholder_token_ids=placeholder_token_ids,
+            )
+
+            response = input("Run another cycle? [y/n] ")
+            continue_training = response.lower().strip() != "n"
+            training_iter += 1
 
     if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
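train_ti.py's run() helper now follows the same interactive cycle pattern as the LoRA script. A stripped-down sketch of that control flow; train_one_cycle is a hypothetical stand-in for the real trainer(...) call, not a function in the repository:

from pathlib import Path
from typing import Callable

def run_cycles(output_dir: Path, placeholder_tokens: list[str], train_one_cycle: Callable) -> None:
    # Repeat training cycles until the user declines; each cycle writes its
    # samples and checkpoints into its own per-project subdirectory.
    continue_training = True
    training_iter = 1

    while continue_training:
        # Project name: "<token>_<n>" for a single placeholder token, else "<n>".
        if len(placeholder_tokens) == 1:
            project = f"{placeholder_tokens[0]}_{training_iter}"
        else:
            project = f"{training_iter}"

        train_one_cycle(
            project=project,
            sample_output_dir=output_dir / project / "samples",
            checkpoint_output_dir=output_dir / project / "checkpoints",
        )

        response = input("Run another cycle? [y/n] ")
        continue_training = response.lower().strip() != "n"
        training_iter += 1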
diff --git a/training/functional.py b/training/functional.py
index 7d49782..e14aeea 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -72,7 +72,7 @@ def make_grid(images, rows, cols):
     return grid
 
 
-def get_models(pretrained_model_name_or_path: str):
+def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -81,7 +81,7 @@ def get_models(pretrained_model_name_or_path: str):
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
-    embeddings = patch_managed_embeddings(text_encoder)
+    embeddings = patch_managed_embeddings(text_encoder, emb_dropout)
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
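Taken together, emb_dropout now flows from the training scripts through get_models into patch_managed_embeddings and on to PseudoSparseEmbedding. A hedged usage sketch; the model identifier and the 0.1 value are placeholders, not values from this commit:

from training.functional import get_models

# Placeholder path for whatever base model the scripts pass via
# --pretrained_model_name_or_path.
(tokenizer, text_encoder, vae, unet,
 noise_scheduler, sample_scheduler, embeddings) = get_models(
    "path/to/pretrained-model", emb_dropout=0.1)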
