summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--train_dreambooth.py20
-rw-r--r--train_ti.py10
-rw-r--r--training/ti.py2
3 files changed, 8 insertions, 24 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index c7899a0..51e881a 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -24,6 +24,7 @@ from common import load_text_embeddings
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import CSVDataModule 25from data.csv import CSVDataModule
26from training.optimization import get_one_cycle_schedule 26from training.optimization import get_one_cycle_schedule
27from training.ti import patch_trainable_embeddings
27from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args 28from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
28from models.clip.prompt import PromptProcessor 29from models.clip.prompt import PromptProcessor
29 30
@@ -567,15 +568,8 @@ def main():
567 print(f"Training entire text encoder.") 568 print(f"Training entire text encoder.")
568 else: 569 else:
569 print(f"Training added text embeddings") 570 print(f"Training added text embeddings")
570 571 text_encoder.requires_grad_(False)
571 freeze_params(itertools.chain( 572 patch_trainable_embeddings(text_encoder, placeholder_token_id)
572 text_encoder.text_model.encoder.parameters(),
573 text_encoder.text_model.final_layer_norm.parameters(),
574 text_encoder.text_model.embeddings.position_embedding.parameters(),
575 ))
576
577 index_fixed_tokens = torch.arange(len(tokenizer))
578 index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
579 573
580 prompt_processor = PromptProcessor(tokenizer, text_encoder) 574 prompt_processor = PromptProcessor(tokenizer, text_encoder)
581 575
@@ -603,7 +597,7 @@ def main():
603 if args.train_text_encoder: 597 if args.train_text_encoder:
604 text_encoder_params_to_optimize = text_encoder.parameters() 598 text_encoder_params_to_optimize = text_encoder.parameters()
605 else: 599 else:
606 text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters() 600 text_encoder_params_to_optimize = text_encoder.text_model.embeddings.trainable_embedding.parameters()
607 601
608 # Initialize the optimizer 602 # Initialize the optimizer
609 optimizer = optimizer_class( 603 optimizer = optimizer_class(
@@ -914,12 +908,6 @@ def main():
914 ema_unet.step(unet) 908 ema_unet.step(unet)
915 optimizer.zero_grad(set_to_none=True) 909 optimizer.zero_grad(set_to_none=True)
916 910
917 if not args.train_text_encoder:
918 # Let's make sure we don't update any embedding weights besides the newly added token
919 with torch.no_grad():
920 text_encoder.get_input_embeddings(
921 ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
922
923 avg_loss.update(loss.detach_(), bsz) 911 avg_loss.update(loss.detach_(), bsz)
924 avg_acc.update(acc.detach_(), bsz) 912 avg_acc.update(acc.detach_(), bsz)
925 913
diff --git a/train_ti.py b/train_ti.py
index 52bd675..a12b889 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -368,7 +368,6 @@ class Checkpointer(CheckpointerBase):
368 tokenizer, 368 tokenizer,
369 text_encoder, 369 text_encoder,
370 scheduler, 370 scheduler,
371 text_embeddings,
372 placeholder_token, 371 placeholder_token,
373 placeholder_token_id, 372 placeholder_token_id,
374 output_dir: Path, 373 output_dir: Path,
@@ -394,7 +393,6 @@ class Checkpointer(CheckpointerBase):
394 self.tokenizer = tokenizer 393 self.tokenizer = tokenizer
395 self.text_encoder = text_encoder 394 self.text_encoder = text_encoder
396 self.scheduler = scheduler 395 self.scheduler = scheduler
397 self.text_embeddings = text_embeddings
398 396
399 @torch.no_grad() 397 @torch.no_grad()
400 def checkpoint(self, step, postfix): 398 def checkpoint(self, step, postfix):
@@ -407,7 +405,7 @@ class Checkpointer(CheckpointerBase):
407 405
408 for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): 406 for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
409 # Save a checkpoint 407 # Save a checkpoint
410 learned_embeds = self.text_embeddings.trainable_embedding.weight[placeholder_token_id] 408 learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight[placeholder_token_id]
411 learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} 409 learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
412 410
413 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) 411 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
@@ -517,8 +515,9 @@ def main():
517 515
518 vae.requires_grad_(False) 516 vae.requires_grad_(False)
519 unet.requires_grad_(False) 517 unet.requires_grad_(False)
518 text_encoder.requires_grad_(False)
520 519
521 text_embeddings = patch_trainable_embeddings(text_encoder, placeholder_token_id) 520 patch_trainable_embeddings(text_encoder, placeholder_token_id)
522 521
523 prompt_processor = PromptProcessor(tokenizer, text_encoder) 522 prompt_processor = PromptProcessor(tokenizer, text_encoder)
524 523
@@ -541,7 +540,7 @@ def main():
541 540
542 # Initialize the optimizer 541 # Initialize the optimizer
543 optimizer = optimizer_class( 542 optimizer = optimizer_class(
544 text_embeddings.trainable_embedding.parameters(), # only optimize the embeddings 543 text_encoder.text_model.embeddings.trainable_embedding.parameters(), # only optimize the embeddings
545 lr=args.learning_rate, 544 lr=args.learning_rate,
546 betas=(args.adam_beta1, args.adam_beta2), 545 betas=(args.adam_beta1, args.adam_beta2),
547 weight_decay=args.adam_weight_decay, 546 weight_decay=args.adam_weight_decay,
@@ -741,7 +740,6 @@ def main():
741 tokenizer=tokenizer, 740 tokenizer=tokenizer,
742 text_encoder=text_encoder, 741 text_encoder=text_encoder,
743 scheduler=checkpoint_scheduler, 742 scheduler=checkpoint_scheduler,
744 text_embeddings=text_embeddings,
745 placeholder_token=args.placeholder_token, 743 placeholder_token=args.placeholder_token,
746 placeholder_token_id=placeholder_token_id, 744 placeholder_token_id=placeholder_token_id,
747 output_dir=basepath, 745 output_dir=basepath,
diff --git a/training/ti.py b/training/ti.py
index a5e407b..8b2fdd6 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -18,8 +18,6 @@ def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
18 18
19 text_encoder.text_model.embeddings = text_embeddings 19 text_encoder.text_model.embeddings = text_embeddings
20 20
21 return text_embeddings
22
23 21
24class TrainableEmbeddings(CLIPTextEmbeddings): 22class TrainableEmbeddings(CLIPTextEmbeddings):
25 def __init__(self, config: CLIPTextConfig, new_ids: list[int]): 23 def __init__(self, config: CLIPTextConfig, new_ids: list[int]):