| author | Volpeon <git@volpeon.ink> | 2022-12-24 10:25:58 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-24 10:25:58 +0100 |
| commit | e09aaedd0e74f2fc6e2a53f233914803c65e127c | |
| tree | 186a6442cb4de3210837ca459aad81a22a3f37ee | |
| parent | Update | |
Training update
| -rw-r--r-- | train_dreambooth.py | 20 |
| -rw-r--r-- | train_ti.py | 10 |
| -rw-r--r-- | training/ti.py | 2 |
3 files changed, 8 insertions, 24 deletions
```diff
diff --git a/train_dreambooth.py b/train_dreambooth.py
index c7899a0..51e881a 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -24,6 +24,7 @@ from common import load_text_embeddings
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
+from training.ti import patch_trainable_embeddings
 from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
 from models.clip.prompt import PromptProcessor
 
@@ -567,15 +568,8 @@ def main():
         print(f"Training entire text encoder.")
     else:
         print(f"Training added text embeddings")
-
-        freeze_params(itertools.chain(
-            text_encoder.text_model.encoder.parameters(),
-            text_encoder.text_model.final_layer_norm.parameters(),
-            text_encoder.text_model.embeddings.position_embedding.parameters(),
-        ))
-
-        index_fixed_tokens = torch.arange(len(tokenizer))
-        index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
+        text_encoder.requires_grad_(False)
+        patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -603,7 +597,7 @@ def main():
     if args.train_text_encoder:
         text_encoder_params_to_optimize = text_encoder.parameters()
     else:
-        text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters()
+        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.trainable_embedding.parameters()
 
     # Initialize the optimizer
     optimizer = optimizer_class(
@@ -914,12 +908,6 @@ def main():
                     ema_unet.step(unet)
                 optimizer.zero_grad(set_to_none=True)
 
-                if not args.train_text_encoder:
-                    # Let's make sure we don't update any embedding weights besides the newly added token
-                    with torch.no_grad():
-                        text_encoder.get_input_embeddings(
-                        ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
-
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
```
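Net effect of the train_dreambooth.py hunks, restated as a sketch (only names that appear above are taken from the commit; the surrounding control flow is assumed): the whole text encoder is frozen, the embeddings module is patched in place, and the optimizer is handed nothing but the patched `trainable_embedding` parameters, which is why the per-step restore from `original_token_embeds` can be dropped.

```python
# Sketch of the new wiring, assuming patch_trainable_embeddings installs a module
# that exposes .trainable_embedding (as the call sites above suggest).
text_encoder.requires_grad_(False)
patch_trainable_embeddings(text_encoder, placeholder_token_id)

if args.train_text_encoder:
    text_encoder_params_to_optimize = text_encoder.parameters()
else:
    # Only the patched embedding copy ever reaches the optimizer, so no
    # post-step restore of the frozen rows is needed anymore.
    text_encoder_params_to_optimize = (
        text_encoder.text_model.embeddings.trainable_embedding.parameters()
    )
```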
```diff
diff --git a/train_ti.py b/train_ti.py
index 52bd675..a12b889 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -368,7 +368,6 @@ class Checkpointer(CheckpointerBase):
         tokenizer,
         text_encoder,
         scheduler,
-        text_embeddings,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -394,7 +393,6 @@ class Checkpointer(CheckpointerBase):
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
-        self.text_embeddings = text_embeddings
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
@@ -407,7 +405,7 @@ class Checkpointer(CheckpointerBase):
 
         for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
             # Save a checkpoint
-            learned_embeds = self.text_embeddings.trainable_embedding.weight[placeholder_token_id]
+            learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight[placeholder_token_id]
             learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
             filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
@@ -517,8 +515,9 @@ def main():
 
     vae.requires_grad_(False)
     unet.requires_grad_(False)
+    text_encoder.requires_grad_(False)
 
-    text_embeddings = patch_trainable_embeddings(text_encoder, placeholder_token_id)
+    patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -541,7 +540,7 @@ def main():
 
     # Initialize the optimizer
    optimizer = optimizer_class(
-        text_embeddings.trainable_embedding.parameters(),  # only optimize the embeddings
+        text_encoder.text_model.embeddings.trainable_embedding.parameters(),  # only optimize the embeddings
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -741,7 +740,6 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
-        text_embeddings=text_embeddings,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
```
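In train_ti.py the Checkpointer no longer carries a separate `text_embeddings` handle; it reads the learned vectors straight off the patched text encoder. A hedged sketch of the save path follows: only the loop, the two `learned_embeds` lines, and the filename pattern appear in the diff, while the `torch.save` call and the `checkpoints_path` destination are assumptions about the surrounding, unchanged code.

```python
# Sketch of the checkpoint loop after this change. The bare text_encoder name
# mirrors the diff (presumably the unwrapped self.text_encoder); torch.save and
# checkpoints_path are assumed, not shown in the commit.
for placeholder_token, placeholder_token_id in zip(self.placeholder_token, self.placeholder_token_id):
    learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight[placeholder_token_id]
    learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}

    filename = "%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
    torch.save(learned_embeds_dict, checkpoints_path / filename)
```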
```diff
diff --git a/training/ti.py b/training/ti.py
index a5e407b..8b2fdd6 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -18,8 +18,6 @@ def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
 
     text_encoder.text_model.embeddings = text_embeddings
 
-    return text_embeddings
-
 
 class TrainableEmbeddings(CLIPTextEmbeddings):
     def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
```
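training/ti.py itself appears here only in part: the patch function stops returning the module, and `TrainableEmbeddings(CLIPTextEmbeddings)` is visible by its signature alone. Below is a minimal sketch of how such a patched embeddings module could look, built only from the names used across these diffs; the forward/masking logic and the weight copying are assumptions, not the repository's code.

```python
# Hypothetical sketch, not the actual training/ti.py: a CLIPTextEmbeddings
# subclass that keeps the original token table frozen and routes only the
# newly added ids through a trainable copy.
import torch
import torch.nn as nn
from transformers import CLIPTextConfig, CLIPTextModel
from transformers.models.clip.modeling_clip import CLIPTextEmbeddings


class TrainableEmbeddings(CLIPTextEmbeddings):
    def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
        super().__init__(config)

        self.new_ids = new_ids
        # Trainable copy of the token table; the optimizer only ever sees this.
        self.trainable_embedding = nn.Embedding(config.vocab_size, config.hidden_size)

    def forward(self, input_ids=None, position_ids=None, inputs_embeds=None):
        if position_ids is None:
            position_ids = self.position_ids[:, : input_ids.shape[-1]]

        if inputs_embeds is None:
            frozen = self.token_embedding(input_ids)
            trainable = self.trainable_embedding(input_ids)
            # Take the trainable rows only for the placeholder ids.
            mask = torch.isin(input_ids, torch.tensor(self.new_ids, device=input_ids.device))
            inputs_embeds = torch.where(mask.unsqueeze(-1), trainable, frozen)

        return inputs_embeds + self.position_embedding(position_ids)


def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
    old_embeddings = text_encoder.text_model.embeddings
    text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids)

    # Carry the trained weights over before swapping the module in.
    text_embeddings.token_embedding = old_embeddings.token_embedding
    text_embeddings.position_embedding = old_embeddings.position_embedding
    text_embeddings.trainable_embedding.weight.data.copy_(old_embeddings.token_embedding.weight.data)

    # As of this commit nothing is returned; callers reach the module through
    # text_encoder.text_model.embeddings.trainable_embedding instead.
    text_encoder.text_model.embeddings = text_embeddings
```

Under a sketch like this, the call order used in both training scripts matters: `text_encoder.requires_grad_(False)` first, then `patch_trainable_embeddings(...)`, so that only the freshly created trainable copy still carries `requires_grad=True` when the optimizer collects its parameters.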
