From 37baa3aa254af721728aa33befdc383858cb8ea2 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 15 Jan 2023 10:38:49 +0100
Subject: Removed unused code, put training callbacks in dataclass

---
 train_ti.py | 49 +------------------------------------------------
 1 file changed, 1 insertion(+), 48 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 2fd325b..3c9810f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,7 +3,6 @@ import datetime
 import logging
 from functools import partial
 from pathlib import Path
-from contextlib import contextmanager, nullcontext
 
 import torch
 import torch.utils.checkpoint
@@ -16,7 +15,6 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 
 from data.csv import VlpnDataModule, VlpnDataItem
-from trainer_old.base import Checkpointer
 from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
@@ -483,51 +481,6 @@ def parse_args():
     return args
 
 
-class TextualInversionCheckpointer(Checkpointer):
-    def __init__(
-        self,
-        ema_embeddings: EMAModel,
-        placeholder_tokens: list[str],
-        placeholder_token_ids: list[list[int]],
-        *args,
-        **kwargs,
-    ):
-        super().__init__(*args, **kwargs)
-
-        self.ema_embeddings = ema_embeddings
-        self.placeholder_tokens = placeholder_tokens
-        self.placeholder_token_ids = placeholder_token_ids
-
-    @torch.no_grad()
-    def checkpoint(self, step, postfix):
-        print(f"Saving checkpoint for step {step}...")
-
-        checkpoints_path = self.output_dir.joinpath("checkpoints")
-        checkpoints_path.mkdir(parents=True, exist_ok=True)
-
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        ema_context = self.ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-        ) if self.ema_embeddings is not None else nullcontext()
-
-        with ema_context:
-            for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
-                text_encoder.text_model.embeddings.save_embed(
-                    ids,
-                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
-                )
-
-    @torch.no_grad()
-    def save_samples(self, step):
-        ema_context = self.ema_embeddings.apply_temporary(
-            self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-        ) if self.ema_embeddings is not None else nullcontext()
-
-        with ema_context:
-            super().save_samples(step)
-
-
 def main():
     args = parse_args()
 
@@ -769,7 +722,7 @@ def main():
         checkpoint_frequency=args.checkpoint_frequency,
         global_step_offset=global_step_offset,
         prior_loss_weight=args.prior_loss_weight,
-        **strategy,
+        callbacks=strategy,
     )
 
 
-- 
cgit v1.2.3-54-g00ecf
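
For context, a minimal sketch of the pattern this commit moves to: the strategy's hooks are grouped in a dataclass and handed to train() as a single callbacks argument instead of being splatted in with **strategy. The actual dataclass and hook names live in training/functional.py and training/strategy/ti.py, which are not shown in this patch, so every name below (TrainingCallbacks, on_checkpoint, on_sample, on_log, run_training) is an assumption used only for illustration.

# Hypothetical sketch -- the real callbacks dataclass is defined elsewhere in the
# repository and is not part of this patch; names here are illustrative only.
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class TrainingCallbacks:
    # Each hook is optional, so a strategy only fills in the ones it needs.
    on_checkpoint: Optional[Callable[[int, str], None]] = None  # save embeddings at a step
    on_sample: Optional[Callable[[int], None]] = None           # render sample images
    on_log: Optional[Callable[[], dict]] = None                 # extra values for the progress log


def run_training(callbacks: TrainingCallbacks) -> None:
    # The training loop calls well-known attributes on one object instead of
    # unpacking an open-ended **strategy mapping into keyword arguments.
    if callbacks.on_checkpoint is not None:
        callbacks.on_checkpoint(0, "init")


if __name__ == "__main__":
    strategy = TrainingCallbacks(
        on_checkpoint=lambda step, postfix: print(f"checkpoint {step} ({postfix})"),
    )
    run_training(callbacks=strategy)  # mirrors the train(..., callbacks=strategy) call in the diff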