summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
-rw-r--r--train_dreambooth.py9
-rw-r--r--train_ti.py10
-rw-r--r--training/ti.py15
3 files changed, 21 insertions(+), 13 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 51e881a..8cb6414 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -568,9 +568,16 @@ def main():
568 print(f"Training entire text encoder.") 568 print(f"Training entire text encoder.")
569 else: 569 else:
570 print(f"Training added text embeddings") 570 print(f"Training added text embeddings")
571 text_encoder.requires_grad_(False) 571
572 patch_trainable_embeddings(text_encoder, placeholder_token_id) 572 patch_trainable_embeddings(text_encoder, placeholder_token_id)
573 573
574 freeze_params(itertools.chain(
575 text_encoder.text_model.encoder.parameters(),
576 text_encoder.text_model.final_layer_norm.parameters(),
577 text_encoder.text_model.embeddings.position_embedding.parameters(),
578 text_encoder.text_model.embeddings.token_embedding.parameters(),
579 ))
580
574 prompt_processor = PromptProcessor(tokenizer, text_encoder) 581 prompt_processor = PromptProcessor(tokenizer, text_encoder)
575 582
576 if args.scale_lr: 583 if args.scale_lr:
diff --git a/train_ti.py b/train_ti.py
index a12b889..5f37d54 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import CSVDataModule, CSVDataItem 25from data.csv import CSVDataModule, CSVDataItem
26from training.optimization import get_one_cycle_schedule 26from training.optimization import get_one_cycle_schedule
27from training.ti import patch_trainable_embeddings 27from training.ti import patch_trainable_embeddings
28from training.util import AverageMeter, CheckpointerBase, save_args 28from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params
29from models.clip.prompt import PromptProcessor 29from models.clip.prompt import PromptProcessor
30 30
31logger = get_logger(__name__) 31logger = get_logger(__name__)
@@ -515,10 +515,16 @@ def main():
515 515
516 vae.requires_grad_(False) 516 vae.requires_grad_(False)
517 unet.requires_grad_(False) 517 unet.requires_grad_(False)
518 text_encoder.requires_grad_(False)
519 518
520 patch_trainable_embeddings(text_encoder, placeholder_token_id) 519 patch_trainable_embeddings(text_encoder, placeholder_token_id)
521 520
521 freeze_params(itertools.chain(
522 text_encoder.text_model.encoder.parameters(),
523 text_encoder.text_model.final_layer_norm.parameters(),
524 text_encoder.text_model.embeddings.position_embedding.parameters(),
525 text_encoder.text_model.embeddings.token_embedding.parameters(),
526 ))
527
522 prompt_processor = PromptProcessor(tokenizer, text_encoder) 528 prompt_processor = PromptProcessor(tokenizer, text_encoder)
523 529
524 if args.scale_lr: 530 if args.scale_lr:
diff --git a/training/ti.py b/training/ti.py
index 8b2fdd6..dc33e5e 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -8,26 +8,21 @@ from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
8 8
9 9
10def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]): 10def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
11 text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids) 11 text_embeddings = TrainableEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, new_ids)
12
13 text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding
14 text_embeddings.token_embedding.weight.requires_grad = False
15
16 text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding
17 text_embeddings.position_embedding.weight.requires_grad = False
18
19 text_encoder.text_model.embeddings = text_embeddings 12 text_encoder.text_model.embeddings = text_embeddings
20 13
21 14
22class TrainableEmbeddings(CLIPTextEmbeddings): 15class TrainableEmbeddings(CLIPTextEmbeddings):
23 def __init__(self, config: CLIPTextConfig, new_ids: list[int]): 16 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]):
24 super().__init__(config) 17 super().__init__(config)
25 18
26 self.train_indices = torch.tensor(new_ids) 19 self.train_indices = torch.tensor(new_ids)
27 20
28 self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) 21 self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
22
23 self.token_embedding = embeddings.token_embedding
24 self.position_embedding = embeddings.position_embedding
29 self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone() 25 self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone()
30 self.trainable_embedding.weight.requires_grad = True
31 26
32 def forward( 27 def forward(
33 self, 28 self,