author	Volpeon <git@volpeon.ink>	2022-12-22 16:37:47 +0100
committer	Volpeon <git@volpeon.ink>	2022-12-22 16:37:47 +0100
commit	fd691d762820863c5236a189a752ba4f985a961b (patch)
tree	1f8db6c6629cdf7df552d7f24e0e7dd16c593b7f /train_ti.py
parent	Some LoRA fixes (still broken) (diff)
Improved Textual Inversion: Completely exclude untrained embeddings from training
Diffstat (limited to 'train_ti.py')
-rw-r--r--	train_ti.py	24
1 file changed, 6 insertions(+), 18 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 198cf37..bb51dc2 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -24,7 +24,8 @@ from common import load_text_embeddings, load_text_embedding
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
-from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
+from training.ti import patch_trainable_embeddings
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -512,24 +513,14 @@ def main():
         for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
             load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin"))
 
-    original_token_embeds = token_embeds.clone().to(accelerator.device)
-
     initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
     for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
         token_embeds[token_id] = embeddings
 
-    index_fixed_tokens = torch.arange(len(tokenizer))
-    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
 
-    # Freeze vae and unet
-    freeze_params(vae.parameters())
-    freeze_params(unet.parameters())
-    # Freeze all parameters except for the token embeddings in text encoder
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    ))
+    text_embeddings = patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -843,10 +834,7 @@ def main():
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
-                # Let's make sure we don't update any embedding weights besides the newly added token
-                with torch.no_grad():
-                    text_encoder.get_input_embeddings(
-                    ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
+                text_embeddings.save()
 
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
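
Note: training/ti.py itself is not touched by this commit, so the new patch_trainable_embeddings() helper is not shown above. The following is only a minimal sketch of what such a patch might do, assuming it swaps the CLIP token embedding for a wrapper in which only the placeholder rows carry gradients; the class and attribute names (TrainableEmbeddings, fixed_weight, trainable) are hypothetical and not taken from the repository.

# Hypothetical sketch; the real implementation in training/ti.py may differ.
import torch
import torch.nn as nn
import torch.nn.functional as F


class TrainableEmbeddings(nn.Module):
    def __init__(self, embedding: nn.Embedding, placeholder_token_id):
        super().__init__()
        # Frozen copy of the full embedding table; never receives gradients.
        self.register_buffer("fixed_weight", embedding.weight.detach().clone())
        self.placeholder_token_id = list(placeholder_token_id)
        # Trainable parameter holding only the placeholder rows.
        self.trainable = nn.Parameter(
            embedding.weight[self.placeholder_token_id].detach().clone()
        )

    def forward(self, input_ids):
        # Overlay the trainable rows on the frozen table for the lookup;
        # gradients flow only into self.trainable.
        weight = self.fixed_weight.clone()
        weight[self.placeholder_token_id] = self.trainable
        return F.embedding(input_ids, weight)

    def save(self):
        # Write the trained rows back into the frozen table (called once per step).
        with torch.no_grad():
            self.fixed_weight[self.placeholder_token_id] = self.trainable


def patch_trainable_embeddings(text_encoder, placeholder_token_id):
    embeddings = TrainableEmbeddings(
        text_encoder.text_model.embeddings.token_embedding, placeholder_token_id
    )
    text_encoder.text_model.embeddings.token_embedding = embeddings
    return embeddings

Compared with the removed approach (freezing parameter groups with freeze_params() and copying original_token_embeds back over the fixed rows after every optimizer step), a wrapper along these lines keeps the untrained embeddings out of the autograd graph entirely, which is what the per-step text_embeddings.save() call in the new training loop relies on.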