diff options
author | Volpeon <git@volpeon.ink> | 2022-12-24 10:25:58 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-24 10:25:58 +0100 |
commit | e09aaedd0e74f2fc6e2a53f233914803c65e127c (patch) | |
tree | 186a6442cb4de3210837ca459aad81a22a3f37ee /train_ti.py | |
parent | Update (diff) | |
download | textual-inversion-diff-e09aaedd0e74f2fc6e2a53f233914803c65e127c.tar.gz textual-inversion-diff-e09aaedd0e74f2fc6e2a53f233914803c65e127c.tar.bz2 textual-inversion-diff-e09aaedd0e74f2fc6e2a53f233914803c65e127c.zip |
Training update
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 10 |
1 files changed, 4 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py index 52bd675..a12b889 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -368,7 +368,6 @@ class Checkpointer(CheckpointerBase): | |||
368 | tokenizer, | 368 | tokenizer, |
369 | text_encoder, | 369 | text_encoder, |
370 | scheduler, | 370 | scheduler, |
371 | text_embeddings, | ||
372 | placeholder_token, | 371 | placeholder_token, |
373 | placeholder_token_id, | 372 | placeholder_token_id, |
374 | output_dir: Path, | 373 | output_dir: Path, |
@@ -394,7 +393,6 @@ class Checkpointer(CheckpointerBase): | |||
394 | self.tokenizer = tokenizer | 393 | self.tokenizer = tokenizer |
395 | self.text_encoder = text_encoder | 394 | self.text_encoder = text_encoder |
396 | self.scheduler = scheduler | 395 | self.scheduler = scheduler |
397 | self.text_embeddings = text_embeddings | ||
398 | 396 | ||
399 | @torch.no_grad() | 397 | @torch.no_grad() |
400 | def checkpoint(self, step, postfix): | 398 | def checkpoint(self, step, postfix): |
@@ -407,7 +405,7 @@ class Checkpointer(CheckpointerBase): | |||
407 | 405 | ||
408 | for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): | 406 | for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): |
409 | # Save a checkpoint | 407 | # Save a checkpoint |
410 | learned_embeds = self.text_embeddings.trainable_embedding.weight[placeholder_token_id] | 408 | learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight[placeholder_token_id] |
411 | learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} | 409 | learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} |
412 | 410 | ||
413 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) | 411 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) |
@@ -517,8 +515,9 @@ def main(): | |||
517 | 515 | ||
518 | vae.requires_grad_(False) | 516 | vae.requires_grad_(False) |
519 | unet.requires_grad_(False) | 517 | unet.requires_grad_(False) |
518 | text_encoder.requires_grad_(False) | ||
520 | 519 | ||
521 | text_embeddings = patch_trainable_embeddings(text_encoder, placeholder_token_id) | 520 | patch_trainable_embeddings(text_encoder, placeholder_token_id) |
522 | 521 | ||
523 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 522 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
524 | 523 | ||
@@ -541,7 +540,7 @@ def main(): | |||
541 | 540 | ||
542 | # Initialize the optimizer | 541 | # Initialize the optimizer |
543 | optimizer = optimizer_class( | 542 | optimizer = optimizer_class( |
544 | text_embeddings.trainable_embedding.parameters(), # only optimize the embeddings | 543 | text_encoder.text_model.embeddings.trainable_embedding.parameters(), # only optimize the embeddings |
545 | lr=args.learning_rate, | 544 | lr=args.learning_rate, |
546 | betas=(args.adam_beta1, args.adam_beta2), | 545 | betas=(args.adam_beta1, args.adam_beta2), |
547 | weight_decay=args.adam_weight_decay, | 546 | weight_decay=args.adam_weight_decay, |
@@ -741,7 +740,6 @@ def main(): | |||
741 | tokenizer=tokenizer, | 740 | tokenizer=tokenizer, |
742 | text_encoder=text_encoder, | 741 | text_encoder=text_encoder, |
743 | scheduler=checkpoint_scheduler, | 742 | scheduler=checkpoint_scheduler, |
744 | text_embeddings=text_embeddings, | ||
745 | placeholder_token=args.placeholder_token, | 743 | placeholder_token=args.placeholder_token, |
746 | placeholder_token_id=placeholder_token_id, | 744 | placeholder_token_id=placeholder_token_id, |
747 | output_dir=basepath, | 745 | output_dir=basepath, |