author     Volpeon <git@volpeon.ink>  2023-04-09 11:29:31 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-09 11:29:31 +0200
commit     ba9fd1a10746d85d2502c8a79ac49db63d346b04 (patch)
tree       568bf65a0a4dcea2c34de4006b5761d0d6564307
parent     Fix (diff)
Update
-rw-r--r--  infer.py                   |   1
-rw-r--r--  models/clip/embeddings.py  |   7
-rw-r--r--  models/sparse.py           |  13
-rw-r--r--  train_dreambooth.py        |   1
-rw-r--r--  train_lora.py              | 140
-rw-r--r--  train_ti.py                |  66
-rw-r--r--  training/functional.py     |   4

7 files changed, 72 insertions(+), 160 deletions(-)
diff --git a/infer.py b/infer.py
index 8fdf63d..4648c0a 100644
--- a/infer.py
+++ b/infer.py
@@ -235,6 +235,7 @@ def load_embeddings(pipeline, embeddings_dir):
         pipeline.text_encoder.text_model.embeddings,
         Path(embeddings_dir)
     )
+    pipeline.text_encoder.text_model.embeddings.persist()
     print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
 
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 6fda33c..dc4708a 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -37,7 +37,7 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0):
         super().__init__(config)
 
         self.token_embedding = embeddings.token_embedding
@@ -46,6 +46,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         self.token_override_embedding = PseudoSparseEmbedding(
             self.token_embedding.embedding_dim,
+            dropout_p=dropout_p,
             device=self.token_embedding.weight.device,
             dtype=self.token_embedding.weight.dtype,
         )
@@ -134,7 +135,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         return embeddings
 
 
-def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
-    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
+def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings:
+    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p)
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
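
Note: a minimal usage sketch of the extended signature, based only on what this hunk shows; the CLIP checkpoint id is an illustrative placeholder, not something this commit pins.

    from transformers import CLIPTextModel

    from models.clip.embeddings import patch_managed_embeddings

    # Placeholder checkpoint id, for illustration only.
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # Swap in the managed embeddings; a non-zero dropout_p is forwarded to the
    # PseudoSparseEmbedding that holds the per-token overrides.
    embeddings = patch_managed_embeddings(text_encoder, dropout_p=0.1)
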
diff --git a/models/sparse.py b/models/sparse.py
index d706db5..bcb2897 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -5,22 +5,29 @@ import torch.nn as nn
 
 
 class PseudoSparseEmbedding(nn.Module):
-    def __init__(self, embedding_dim: int, device=None, dtype=torch.float32):
+    def __init__(self, embedding_dim: int, dropout_p: float = 0.0, device=None, dtype=torch.float32):
         super().__init__()
 
         self.embedding_dim = embedding_dim
         self.dtype = dtype
         self.params = nn.ParameterList()
+
+        if dropout_p > 0.0:
+            self.dropout = nn.Dropout(p=dropout_p)
+        else:
+            self.dropout = lambda x: x
+
         self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long))
 
     def forward(self, input_ids: torch.LongTensor):
-        ids = self.mapping[input_ids.to(self.mapping.device)]
+        input_ids = input_ids.to(self.mapping.device)
+        ids = self.mapping[input_ids]
         mask = ~(ids == -1)
 
         if torch.all(~mask):
             embs = None
         else:
-            embs = torch.stack([self.params[id] for id in ids[mask]])
+            embs = self.dropout(torch.stack([self.params[id] for id in ids[mask]]))
 
         return embs, mask
 
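
Note: a rough sketch of the changed forward path, assuming the module is importable as models.sparse.PseudoSparseEmbedding; the mapping/params wiring is done by hand here (the real code manages it through the surrounding embeddings machinery) and all values are made up.

    import torch
    import torch.nn as nn

    from models.sparse import PseudoSparseEmbedding

    emb = PseudoSparseEmbedding(embedding_dim=4, dropout_p=0.1)

    # Fake a 5-token vocabulary with a single override for token id 2;
    # -1 in `mapping` means "no override for this token".
    emb.mapping = torch.full((5,), -1, dtype=torch.long)
    emb.mapping[2] = 0
    emb.params.append(nn.Parameter(torch.randn(4)))

    embs, mask = emb(torch.tensor([1, 2, 3]))
    # mask -> tensor([False, True, False]); embs holds the override vector for
    # token 2, passed through nn.Dropout while the module is in training mode.
    # If no requested id had an override, embs would be None.
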
diff --git a/train_dreambooth.py b/train_dreambooth.py
index f4d4cbb..2aca1e7 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -513,6 +513,7 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        embeddings.persist()
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     if args.scale_lr:
diff --git a/train_lora.py b/train_lora.py
index 8dbe45b..6e21634 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -159,12 +159,6 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
-        "--pti_tag_dropout",
-        type=float,
-        default=0,
-        help="Tag dropout probability.",
-    )
-    parser.add_argument(
         "--no_tag_shuffle",
         action="store_true",
         help="Shuffle tags.",
@@ -236,28 +230,12 @@ def parse_args():
         default=2000
     )
     parser.add_argument(
-        "--num_pti_epochs",
-        type=int,
-        default=None
-    )
-    parser.add_argument(
-        "--num_pti_steps",
-        type=int,
-        default=500
-    )
-    parser.add_argument(
         "--gradient_accumulation_steps",
         type=int,
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument(
-        "--pti_gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass.",
-    )
-    parser.add_argument(
         "--lora_r",
         type=int,
         default=8,
@@ -323,12 +301,6 @@ def parse_args():
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
-        "--learning_rate_pti",
-        type=float,
-        default=1e-4,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
-    parser.add_argument(
         "--learning_rate_emb",
         type=float,
         default=1e-5,
@@ -467,12 +439,6 @@ def parse_args():
         help="How often to save a checkpoint and sample image",
     )
     parser.add_argument(
-        "--pti_sample_frequency",
-        type=int,
-        default=1,
-        help="How often to save a checkpoint and sample image",
-    )
-    parser.add_argument(
         "--sample_image_size",
         type=int,
         default=768,
@@ -509,12 +475,6 @@ def parse_args():
         help="Batch size (per device) for the training dataloader."
     )
     parser.add_argument(
-        "--pti_batch_size",
-        type=int,
-        default=1,
-        help="Batch size (per device) for the training dataloader."
-    )
-    parser.add_argument(
         "--sample_steps",
         type=int,
         default=10,
@@ -527,6 +487,12 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--emb_dropout",
+        type=float,
+        default=0,
+        help="Embedding dropout probability.",
+    )
+    parser.add_argument(
         "--use_emb_decay",
         action="store_true",
         help="Whether to use embedding decay."
@@ -674,7 +640,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, args.emb_dropout)
 
     unet_config = LoraConfig(
         r=args.lora_r,
@@ -720,6 +686,7 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        embeddings.persist()
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
@@ -744,19 +711,14 @@ def main():
             args.learning_rate_text * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
-        args.learning_rate_pti = (
-            args.learning_rate_pti * args.pti_gradient_accumulation_steps *
-            args.pti_batch_size * accelerator.num_processes
-        )
         args.learning_rate_emb = (
-            args.learning_rate_emb * args.pti_gradient_accumulation_steps *
-            args.pti_batch_size * accelerator.num_processes
+            args.learning_rate_emb * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
         )
 
     if args.find_lr:
         args.learning_rate_unet = 1e-6
         args.learning_rate_text = 1e-6
-        args.learning_rate_pti = 1e-6
         args.learning_rate_emb = 1e-6
         args.lr_scheduler = "exponential_growth"
 
@@ -817,7 +779,6 @@ def main():
         args.lr_min_lr = args.learning_rate_unet
         args.learning_rate_unet = None
         args.learning_rate_text = None
-        args.learning_rate_pti = None
         args.learning_rate_emb = None
     elif args.optimizer == 'dadam':
         try:
@@ -836,7 +797,6 @@ def main():
 
         args.learning_rate_unet = 1.0
         args.learning_rate_text = 1.0
-        args.learning_rate_pti = 1.0
         args.learning_rate_emb = 1.0
     elif args.optimizer == 'dadan':
         try:
@@ -853,7 +813,6 @@ def main():
 
         args.learning_rate_unet = 1.0
         args.learning_rate_text = 1.0
-        args.learning_rate_pti = 1.0
         args.learning_rate_emb = 1.0
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
@@ -920,80 +879,6 @@ def main():
         mid_point=args.lr_mid_point,
     )
 
-    # PTI
-    # --------------------------------------------------------------------------------
-
-    if len(args.placeholder_tokens) != 0:
-        pti_datamodule = create_datamodule(
-            batch_size=args.pti_batch_size,
-            dropout=args.pti_tag_dropout,
-            filter=partial(keyword_filter, args.filter_tokens, args.collection, args.exclude_collections),
-        )
-        pti_datamodule.setup()
-
-        num_pti_epochs = args.num_pti_epochs
-        pti_sample_frequency = args.pti_sample_frequency
-        if num_pti_epochs is None:
-            num_pti_epochs = math.ceil(
-                args.num_pti_steps / len(pti_datamodule.train_dataset)
-            ) * args.pti_gradient_accumulation_steps
-            pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps))
-
-        if num_pti_epochs > 0:
-            pti_optimizer = create_optimizer(
-                [
-                    {
-                        "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-                        "lr": args.learning_rate_pti,
-                        "weight_decay": 0,
-                    },
-                ]
-            )
-
-            pti_lr_scheduler = create_lr_scheduler(
-                gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-                optimizer=pti_optimizer,
-                num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
-                train_epochs=num_pti_epochs,
-            )
-
-            continue_training = True
-            training_iter = 1
-
-            while continue_training:
-                print("")
-                print(f"============ PTI cycle {training_iter} ============")
-                print("")
-
-                pti_project = f"pti_{training_iter}"
-                pti_output_dir = output_dir / pti_project
-                pti_checkpoint_output_dir = pti_output_dir / "model"
-                pti_sample_output_dir = pti_output_dir / "samples"
-
-                trainer(
-                    strategy=lora_strategy,
-                    pti_mode=True,
-                    project=pti_project,
-                    train_dataloader=pti_datamodule.train_dataloader,
-                    val_dataloader=pti_datamodule.val_dataloader,
-                    optimizer=pti_optimizer,
-                    lr_scheduler=pti_lr_scheduler,
-                    num_train_epochs=num_pti_epochs,
-                    gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-                    # --
-                    group_labels=["emb"],
-                    sample_output_dir=pti_sample_output_dir,
-                    checkpoint_output_dir=pti_checkpoint_output_dir,
-                    sample_frequency=pti_sample_frequency,
-                )
-
-                response = input("Run another cycle? [y/n] ")
-                continue_training = response.lower().strip() != "n"
-                training_iter += 1
-
-        if not args.train_emb:
-            embeddings.persist()
-
     # LORA
     # --------------------------------------------------------------------------------
 
@@ -1062,9 +947,8 @@ def main():
         print("")
 
         lora_project = f"lora_{training_iter}"
-        lora_output_dir = output_dir / lora_project
-        lora_checkpoint_output_dir = lora_output_dir / "model"
-        lora_sample_output_dir = lora_output_dir / "samples"
+        lora_checkpoint_output_dir = output_dir / lora_project / "model"
+        lora_sample_output_dir = output_dir / lora_project / "samples"
 
         trainer(
             strategy=lora_strategy,
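
Note: with the PTI-specific options gone, --learning_rate_emb is scaled by the same factors as the other parameter groups when --scale_lr is set; a small arithmetic sketch with made-up values (only the 1e-5 base matches the --learning_rate_emb default above).

    # Illustrative values, not the repository's configuration.
    learning_rate_emb = 1e-5
    gradient_accumulation_steps = 4
    train_batch_size = 2
    num_processes = 1

    # Same scaling rule the diff applies to args.learning_rate_emb.
    effective_lr_emb = (
        learning_rate_emb * gradient_accumulation_steps *
        train_batch_size * num_processes
    )
    # effective_lr_emb == 8e-05
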
diff --git a/train_ti.py b/train_ti.py
index daf8bc5..2d51800 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -458,6 +458,12 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--emb_dropout",
+        type=float,
+        default=0,
+        help="Embedding dropout probability.",
+    )
+    parser.add_argument(
         "--use_emb_decay",
         action="store_true",
         help="Whether to use embedding decay."
@@ -624,7 +630,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, args.emb_dropout)
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -755,8 +761,6 @@ def main():
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
-    checkpoint_output_dir = output_dir / "checkpoints"
-
     trainer = partial(
         train,
         accelerator=accelerator,
762 accelerator=accelerator, 766 accelerator=accelerator,
@@ -777,7 +781,6 @@ def main():
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
         # --
-        checkpoint_output_dir=checkpoint_output_dir,
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
         emb_decay=args.emb_decay,
@@ -793,11 +796,6 @@ def main():
     )
 
     def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
-        if len(placeholder_tokens) == 1:
-            sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}"
-        else:
-            sample_output_dir = output_dir / "samples"
-
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
@@ -809,7 +807,11 @@ def main():
 
         stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids))
 
-        print(f"{i + 1}: {stats}")
+        print("")
+        print(f"============ TI batch {i + 1} ============")
+        print("")
+        print(stats)
+        print("")
 
         filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
 
@@ -868,20 +870,36 @@ def main():
             mid_point=args.lr_mid_point,
         )
 
-        trainer(
-            project="textual_inversion",
-            train_dataloader=datamodule.train_dataloader,
-            val_dataloader=datamodule.val_dataloader,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            num_train_epochs=num_train_epochs,
-            # --
-            group_labels=["emb"],
-            sample_output_dir=sample_output_dir,
-            sample_frequency=sample_frequency,
-            placeholder_tokens=placeholder_tokens,
-            placeholder_token_ids=placeholder_token_ids,
-        )
+        continue_training = True
+        training_iter = 1
+
+        while continue_training:
+            print(f"------------ TI cycle {training_iter} ------------")
+            print("")
+
+            project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}"
+            sample_output_dir = output_dir / project / "samples"
+            checkpoint_output_dir = output_dir / project / "checkpoints"
+
+            trainer(
+                project=project,
+                train_dataloader=datamodule.train_dataloader,
+                val_dataloader=datamodule.val_dataloader,
+                optimizer=optimizer,
+                lr_scheduler=lr_scheduler,
+                num_train_epochs=num_train_epochs,
+                # --
+                group_labels=["emb"],
+                checkpoint_output_dir=checkpoint_output_dir,
+                sample_output_dir=sample_output_dir,
+                sample_frequency=sample_frequency,
+                placeholder_tokens=placeholder_tokens,
+                placeholder_token_ids=placeholder_token_ids,
+            )
+
+            response = input("Run another cycle? [y/n] ")
+            continue_training = response.lower().strip() != "n"
+            training_iter += 1
 
     if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
diff --git a/training/functional.py b/training/functional.py
index 7d49782..e14aeea 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -72,7 +72,7 @@ def make_grid(images, rows, cols):
     return grid
 
 
-def get_models(pretrained_model_name_or_path: str):
+def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -81,7 +81,7 @@ def get_models(pretrained_model_name_or_path: str):
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
-    embeddings = patch_managed_embeddings(text_encoder)
+    embeddings = patch_managed_embeddings(text_encoder, emb_dropout)
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
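
Note: a rough calling sketch for the extended get_models; the model id is a placeholder and passing emb_dropout by keyword is an assumption (the training scripts above pass it positionally).

    from training.functional import get_models

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        "runwayml/stable-diffusion-v1-5",  # placeholder model id
        emb_dropout=0.1,                   # forwarded to patch_managed_embeddings
    )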