diff options
author | Volpeon <git@volpeon.ink> | 2022-12-24 10:25:58 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-24 10:25:58 +0100 |
commit | e09aaedd0e74f2fc6e2a53f233914803c65e127c (patch) | |
tree | 186a6442cb4de3210837ca459aad81a22a3f37ee /train_dreambooth.py | |
parent | Update (diff) | |
download | textual-inversion-diff-e09aaedd0e74f2fc6e2a53f233914803c65e127c.tar.gz textual-inversion-diff-e09aaedd0e74f2fc6e2a53f233914803c65e127c.tar.bz2 textual-inversion-diff-e09aaedd0e74f2fc6e2a53f233914803c65e127c.zip |
Training update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 20 |
1 files changed, 4 insertions, 16 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index c7899a0..51e881a 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -24,6 +24,7 @@ from common import load_text_embeddings | |||
24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
25 | from data.csv import CSVDataModule | 25 | from data.csv import CSVDataModule |
26 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
27 | from training.ti import patch_trainable_embeddings | ||
27 | from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args | 28 | from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args |
28 | from models.clip.prompt import PromptProcessor | 29 | from models.clip.prompt import PromptProcessor |
29 | 30 | ||
@@ -567,15 +568,8 @@ def main(): | |||
567 | print(f"Training entire text encoder.") | 568 | print(f"Training entire text encoder.") |
568 | else: | 569 | else: |
569 | print(f"Training added text embeddings") | 570 | print(f"Training added text embeddings") |
570 | 571 | text_encoder.requires_grad_(False) | |
571 | freeze_params(itertools.chain( | 572 | patch_trainable_embeddings(text_encoder, placeholder_token_id) |
572 | text_encoder.text_model.encoder.parameters(), | ||
573 | text_encoder.text_model.final_layer_norm.parameters(), | ||
574 | text_encoder.text_model.embeddings.position_embedding.parameters(), | ||
575 | )) | ||
576 | |||
577 | index_fixed_tokens = torch.arange(len(tokenizer)) | ||
578 | index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] | ||
579 | 573 | ||
580 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 574 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
581 | 575 | ||
@@ -603,7 +597,7 @@ def main(): | |||
603 | if args.train_text_encoder: | 597 | if args.train_text_encoder: |
604 | text_encoder_params_to_optimize = text_encoder.parameters() | 598 | text_encoder_params_to_optimize = text_encoder.parameters() |
605 | else: | 599 | else: |
606 | text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters() | 600 | text_encoder_params_to_optimize = text_encoder.text_model.embeddings.trainable_embedding.parameters() |
607 | 601 | ||
608 | # Initialize the optimizer | 602 | # Initialize the optimizer |
609 | optimizer = optimizer_class( | 603 | optimizer = optimizer_class( |
@@ -914,12 +908,6 @@ def main(): | |||
914 | ema_unet.step(unet) | 908 | ema_unet.step(unet) |
915 | optimizer.zero_grad(set_to_none=True) | 909 | optimizer.zero_grad(set_to_none=True) |
916 | 910 | ||
917 | if not args.train_text_encoder: | ||
918 | # Let's make sure we don't update any embedding weights besides the newly added token | ||
919 | with torch.no_grad(): | ||
920 | text_encoder.get_input_embeddings( | ||
921 | ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens] | ||
922 | |||
923 | avg_loss.update(loss.detach_(), bsz) | 911 | avg_loss.update(loss.detach_(), bsz) |
924 | avg_acc.update(acc.detach_(), bsz) | 912 | avg_acc.update(acc.detach_(), bsz) |
925 | 913 | ||