diff options
author | Volpeon <git@volpeon.ink> | 2022-12-22 16:37:47 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-22 16:37:47 +0100 |
commit | fd691d762820863c5236a189a752ba4f985a961b (patch) | |
tree | 1f8db6c6629cdf7df552d7f24e0e7dd16c593b7f /train_ti.py | |
parent | Some LoRA fixes (still broken) (diff) | |
download | textual-inversion-diff-fd691d762820863c5236a189a752ba4f985a961b.tar.gz textual-inversion-diff-fd691d762820863c5236a189a752ba4f985a961b.tar.bz2 textual-inversion-diff-fd691d762820863c5236a189a752ba4f985a961b.zip |
Improved Textual Inversion: Completely exclude untrained embeddings from training
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 24 |
1 file changed, 6 insertions, 18 deletions
diff --git a/train_ti.py b/train_ti.py index 198cf37..bb51dc2 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -24,7 +24,8 @@ from common import load_text_embeddings, load_text_embedding | |||
24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
25 | from data.csv import CSVDataModule, CSVDataItem | 25 | from data.csv import CSVDataModule, CSVDataItem |
26 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
27 | from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args | 27 | from training.ti import patch_trainable_embeddings |
28 | from training.util import AverageMeter, CheckpointerBase, save_args | ||
28 | from models.clip.prompt import PromptProcessor | 29 | from models.clip.prompt import PromptProcessor |
29 | 30 | ||
30 | logger = get_logger(__name__) | 31 | logger = get_logger(__name__) |
@@ -512,24 +513,14 @@ def main(): | |||
512 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): | 513 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): |
513 | load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin")) | 514 | load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin")) |
514 | 515 | ||
515 | original_token_embeds = token_embeds.clone().to(accelerator.device) | ||
516 | |||
517 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | 516 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) |
518 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): | 517 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): |
519 | token_embeds[token_id] = embeddings | 518 | token_embeds[token_id] = embeddings |
520 | 519 | ||
521 | index_fixed_tokens = torch.arange(len(tokenizer)) | 520 | vae.requires_grad_(False) |
522 | index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] | 521 | unet.requires_grad_(False) |
523 | 522 | ||
524 | # Freeze vae and unet | 523 | text_embeddings = patch_trainable_embeddings(text_encoder, placeholder_token_id) |
525 | freeze_params(vae.parameters()) | ||
526 | freeze_params(unet.parameters()) | ||
527 | # Freeze all parameters except for the token embeddings in text encoder | ||
528 | freeze_params(itertools.chain( | ||
529 | text_encoder.text_model.encoder.parameters(), | ||
530 | text_encoder.text_model.final_layer_norm.parameters(), | ||
531 | text_encoder.text_model.embeddings.position_embedding.parameters(), | ||
532 | )) | ||
533 | 524 | ||
534 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 525 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
535 | 526 | ||
@@ -843,10 +834,7 @@ def main(): | |||
843 | lr_scheduler.step() | 834 | lr_scheduler.step() |
844 | optimizer.zero_grad(set_to_none=True) | 835 | optimizer.zero_grad(set_to_none=True) |
845 | 836 | ||
846 | # Let's make sure we don't update any embedding weights besides the newly added token | 837 | text_embeddings.save() |
847 | with torch.no_grad(): | ||
848 | text_encoder.get_input_embeddings( | ||
849 | ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens] | ||
850 | 838 | ||
851 | avg_loss.update(loss.detach_(), bsz) | 839 | avg_loss.update(loss.detach_(), bsz) |
852 | avg_acc.update(acc.detach_(), bsz) | 840 | avg_acc.update(acc.detach_(), bsz) |