diff options
author | Volpeon <git@volpeon.ink> | 2023-01-15 10:38:49 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-15 10:38:49 +0100 |
commit | 37baa3aa254af721728aa33befdc383858cb8ea2 (patch) | |
tree | ebf64291e052280eea661f8a8d96c486dd5c1cf6 /train_ti.py | |
parent | Added functional TI strategy (diff) | |
download | textual-inversion-diff-37baa3aa254af721728aa33befdc383858cb8ea2.tar.gz textual-inversion-diff-37baa3aa254af721728aa33befdc383858cb8ea2.tar.bz2 textual-inversion-diff-37baa3aa254af721728aa33befdc383858cb8ea2.zip |
Removed unused code, put training callbacks in dataclass
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 49 |
1 files changed, 1 insertions, 48 deletions
diff --git a/train_ti.py b/train_ti.py index 2fd325b..3c9810f 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -3,7 +3,6 @@ import datetime | |||
3 | import logging | 3 | import logging |
4 | from functools import partial | 4 | from functools import partial |
5 | from pathlib import Path | 5 | from pathlib import Path |
6 | from contextlib import contextmanager, nullcontext | ||
7 | 6 | ||
8 | import torch | 7 | import torch |
9 | import torch.utils.checkpoint | 8 | import torch.utils.checkpoint |
@@ -16,7 +15,6 @@ from slugify import slugify | |||
16 | 15 | ||
17 | from util import load_config, load_embeddings_from_dir | 16 | from util import load_config, load_embeddings_from_dir |
18 | from data.csv import VlpnDataModule, VlpnDataItem | 17 | from data.csv import VlpnDataModule, VlpnDataItem |
19 | from trainer_old.base import Checkpointer | ||
20 | from training.functional import train, generate_class_images, add_placeholder_tokens, get_models | 18 | from training.functional import train, generate_class_images, add_placeholder_tokens, get_models |
21 | from training.strategy.ti import textual_inversion_strategy | 19 | from training.strategy.ti import textual_inversion_strategy |
22 | from training.optimization import get_scheduler | 20 | from training.optimization import get_scheduler |
@@ -483,51 +481,6 @@ def parse_args(): | |||
483 | return args | 481 | return args |
484 | 482 | ||
485 | 483 | ||
486 | class TextualInversionCheckpointer(Checkpointer): | ||
487 | def __init__( | ||
488 | self, | ||
489 | ema_embeddings: EMAModel, | ||
490 | placeholder_tokens: list[str], | ||
491 | placeholder_token_ids: list[list[int]], | ||
492 | *args, | ||
493 | **kwargs, | ||
494 | ): | ||
495 | super().__init__(*args, **kwargs) | ||
496 | |||
497 | self.ema_embeddings = ema_embeddings | ||
498 | self.placeholder_tokens = placeholder_tokens | ||
499 | self.placeholder_token_ids = placeholder_token_ids | ||
500 | |||
501 | @torch.no_grad() | ||
502 | def checkpoint(self, step, postfix): | ||
503 | print(f"Saving checkpoint for step {step}...") | ||
504 | |||
505 | checkpoints_path = self.output_dir.joinpath("checkpoints") | ||
506 | checkpoints_path.mkdir(parents=True, exist_ok=True) | ||
507 | |||
508 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | ||
509 | |||
510 | ema_context = self.ema_embeddings.apply_temporary( | ||
511 | text_encoder.text_model.embeddings.temp_token_embedding.parameters() | ||
512 | ) if self.ema_embeddings is not None else nullcontext() | ||
513 | |||
514 | with ema_context: | ||
515 | for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids): | ||
516 | text_encoder.text_model.embeddings.save_embed( | ||
517 | ids, | ||
518 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") | ||
519 | ) | ||
520 | |||
521 | @torch.no_grad() | ||
522 | def save_samples(self, step): | ||
523 | ema_context = self.ema_embeddings.apply_temporary( | ||
524 | self.text_encoder.text_model.embeddings.temp_token_embedding.parameters() | ||
525 | ) if self.ema_embeddings is not None else nullcontext() | ||
526 | |||
527 | with ema_context: | ||
528 | super().save_samples(step) | ||
529 | |||
530 | |||
531 | def main(): | 484 | def main(): |
532 | args = parse_args() | 485 | args = parse_args() |
533 | 486 | ||
@@ -769,7 +722,7 @@ def main(): | |||
769 | checkpoint_frequency=args.checkpoint_frequency, | 722 | checkpoint_frequency=args.checkpoint_frequency, |
770 | global_step_offset=global_step_offset, | 723 | global_step_offset=global_step_offset, |
771 | prior_loss_weight=args.prior_loss_weight, | 724 | prior_loss_weight=args.prior_loss_weight, |
772 | **strategy, | 725 | callbacks=strategy, |
773 | ) | 726 | ) |
774 | 727 | ||
775 | 728 | ||