Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  49
1 file changed, 1 insertion, 48 deletions
diff --git a/train_ti.py b/train_ti.py
index 2fd325b..3c9810f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,7 +3,6 @@ import datetime
 import logging
 from functools import partial
 from pathlib import Path
-from contextlib import contextmanager, nullcontext
 
 import torch
 import torch.utils.checkpoint
@@ -16,7 +15,6 @@ from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
-from trainer_old.base import Checkpointer
 from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
@@ -483,51 +481,6 @@ def parse_args():
     return args
 
 
-class TextualInversionCheckpointer(Checkpointer):
-    def __init__(
-        self,
-        ema_embeddings: EMAModel,
-        placeholder_tokens: list[str],
-        placeholder_token_ids: list[list[int]],
-        *args,
-        **kwargs,
-    ):
-        super().__init__(*args, **kwargs)
-
-        self.ema_embeddings = ema_embeddings
-        self.placeholder_tokens = placeholder_tokens
-        self.placeholder_token_ids = placeholder_token_ids
-
-    @torch.no_grad()
-    def checkpoint(self, step, postfix):
-        print(f"Saving checkpoint for step {step}...")
-
-        checkpoints_path = self.output_dir.joinpath("checkpoints")
-        checkpoints_path.mkdir(parents=True, exist_ok=True)
-
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        ema_context = self.ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-        ) if self.ema_embeddings is not None else nullcontext()
-
-        with ema_context:
-            for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
-                text_encoder.text_model.embeddings.save_embed(
-                    ids,
-                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
-                )
-
-    @torch.no_grad()
-    def save_samples(self, step):
-        ema_context = self.ema_embeddings.apply_temporary(
-            self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-        ) if self.ema_embeddings is not None else nullcontext()
-
-        with ema_context:
-            super().save_samples(step)
-
-
 def main():
     args = parse_args()
 
@@ -769,7 +722,7 @@ def main():
         checkpoint_frequency=args.checkpoint_frequency,
         global_step_offset=global_step_offset,
         prior_loss_weight=args.prior_loss_weight,
-        **strategy,
+        callbacks=strategy,
    )
 
 
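Apart from dropping the now-unused TextualInversionCheckpointer, the only behavioral change is in the last hunk: the strategy is no longer unpacked into individual keyword arguments of the training call but handed over as a single callbacks argument. A minimal sketch of that calling-convention difference follows; the train_before/train_after signatures are made up for illustration and are not the actual training.functional.train API.

    # Hypothetical signatures, only to illustrate **strategy vs. callbacks=strategy.

    def train_before(*, on_checkpoint=None, on_sample=None):
        """Old style: each callback arrives as its own keyword argument."""
        ...

    def train_after(*, callbacks=None):
        """New style: the whole strategy object is passed through as one value."""
        ...

    strategy = {
        "on_checkpoint": lambda step, postfix: None,
        "on_sample": lambda step: None,
    }

    train_before(**strategy)          # before: dict unpacked into keyword args
    train_after(callbacks=strategy)   # after: dict passed as a single object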
