path: root/train_ti.py
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  49
1 file changed, 1 insertion, 48 deletions
diff --git a/train_ti.py b/train_ti.py
index 2fd325b..3c9810f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,7 +3,6 @@ import datetime
 import logging
 from functools import partial
 from pathlib import Path
-from contextlib import contextmanager, nullcontext
 
 import torch
 import torch.utils.checkpoint
@@ -16,7 +15,6 @@ from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
-from trainer_old.base import Checkpointer
 from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
@@ -483,51 +481,6 @@ def parse_args():
     return args
 
 
-class TextualInversionCheckpointer(Checkpointer):
-    def __init__(
-        self,
-        ema_embeddings: EMAModel,
-        placeholder_tokens: list[str],
-        placeholder_token_ids: list[list[int]],
-        *args,
-        **kwargs,
-    ):
-        super().__init__(*args, **kwargs)
-
-        self.ema_embeddings = ema_embeddings
-        self.placeholder_tokens = placeholder_tokens
-        self.placeholder_token_ids = placeholder_token_ids
-
-    @torch.no_grad()
-    def checkpoint(self, step, postfix):
-        print(f"Saving checkpoint for step {step}...")
-
-        checkpoints_path = self.output_dir.joinpath("checkpoints")
-        checkpoints_path.mkdir(parents=True, exist_ok=True)
-
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        ema_context = self.ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-        ) if self.ema_embeddings is not None else nullcontext()
-
-        with ema_context:
-            for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
-                text_encoder.text_model.embeddings.save_embed(
-                    ids,
-                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
-                )
-
-    @torch.no_grad()
-    def save_samples(self, step):
-        ema_context = self.ema_embeddings.apply_temporary(
-            self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-        ) if self.ema_embeddings is not None else nullcontext()
-
-        with ema_context:
-            super().save_samples(step)
-
-
 def main():
     args = parse_args()
 
@@ -769,7 +722,7 @@ def main():
         checkpoint_frequency=args.checkpoint_frequency,
         global_step_offset=global_step_offset,
         prior_loss_weight=args.prior_loss_weight,
-        **strategy,
+        callbacks=strategy,
     )
 
 