From fc11c86142915d6c3935d28a3321b3ae91b613ef Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 14 Jan 2023 22:03:01 +0100
Subject: Update

---
 train_ti.py            |  7 ++++---
 trainer/base.py        |  2 +-
 trainer/ti.py          |  2 +-
 training/functional.py | 19 +++++++++----------
 4 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index deed84c..a4e2dde 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -512,7 +512,7 @@ class TextualInversionCheckpointer(Checkpointer):
             checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
         )
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def save_samples(self, step):
         ema_context = self.ema_embeddings.apply_temporary(
             self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
@@ -808,7 +808,6 @@ def main():
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         model=text_encoder,
-        checkpointer=checkpointer,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
         loss_step=loss_step_,
@@ -819,7 +818,9 @@ def main():
         on_log=on_log,
         on_train=on_train,
         on_after_optimize=on_after_optimize,
-        on_eval=on_eval
+        on_eval=on_eval,
+        on_sample=checkpointer.save_samples,
+        on_checkpoint=checkpointer.checkpoint,
     )
 
 
diff --git a/trainer/base.py b/trainer/base.py
index e700dd6..1f85e71 100644
--- a/trainer/base.py
+++ b/trainer/base.py
@@ -74,7 +74,7 @@ class Checkpointer():
     def checkpoint(self, step: int, postfix: str):
         pass
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def save_samples(self, step: int):
         print(f"Saving samples for step {step}...")
 
diff --git a/trainer/ti.py b/trainer/ti.py
index 15cf747..388acd3 100644
--- a/trainer/ti.py
+++ b/trainer/ti.py
@@ -42,7 +42,7 @@ class TextualInversionCheckpointer(Checkpointer):
             checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
         )
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def save_samples(self, step):
         ema_context = self.ema_embeddings.apply_temporary(
             self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
diff --git a/training/functional.py b/training/functional.py
index 2d81eca..c100ea2 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -17,7 +17,6 @@ from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embe
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import AverageMeter
-from trainer.base import Checkpointer
 
 
 def const(result=None):
@@ -205,7 +204,6 @@ def train_loop(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     model: torch.nn.Module,
-    checkpointer: Checkpointer,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
@@ -217,7 +215,9 @@ def train_loop(
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
     on_before_optimize: Callable[[int], None] = const(),
     on_after_optimize: Callable[[float], None] = const(),
-    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
+    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()),
+    on_sample: Callable[[int], None] = const(),
+    on_checkpoint: Callable[[int, str], None] = const(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
@@ -253,10 +253,10 @@
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
                 if epoch % sample_frequency == 0:
-                    checkpointer.save_samples(global_step + global_step_offset)
+                    on_sample(global_step + global_step_offset)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
-                    checkpointer.checkpoint(global_step + global_step_offset, "training")
+                    on_checkpoint(global_step + global_step_offset, "training")
 
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
@@ -347,19 +347,18 @@
                 if avg_acc_val.avg.item() > max_acc_val:
                     accelerator.print(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
+                    on_checkpoint(global_step + global_step_offset, "milestone")
                     max_acc_val = avg_acc_val.avg.item()
 
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished!")
-            checkpointer.checkpoint(global_step + global_step_offset, "end")
-            checkpointer.save_samples(global_step + global_step_offset)
+            on_checkpoint(global_step + global_step_offset, "end")
+            on_sample(global_step + global_step_offset)
             accelerator.end_training()
 
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted")
-            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            on_checkpoint(global_step + global_step_offset, "end")
         accelerator.end_training()
-        quit()
-- 
cgit v1.2.3-54-g00ecf
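
For reference, a minimal sketch (not the repository's code) of the callback contract this patch introduces: train_loop() no longer receives a Checkpointer object and instead calls two hooks, on_sample(step) and on_checkpoint(step, postfix), with const() providing no-op defaults. The run_epochs() helper and the lambda callbacks below are hypothetical stand-ins, used only to show that any callables with matching signatures, such as checkpointer.save_samples and checkpointer.checkpoint in train_ti.py, can be plugged in.

from typing import Callable


def const(result=None):
    # Mirrors the const() helper in training/functional.py: builds a no-op
    # default callback that ignores its arguments and returns `result`.
    def fn(*args, **kwargs):
        return result
    return fn


def run_epochs(
    num_epochs: int,
    sample_frequency: int = 1,
    checkpoint_frequency: int = 1,
    on_sample: Callable[[int], None] = const(),
    on_checkpoint: Callable[[int, str], None] = const(),
):
    # Hypothetical, stripped-down stand-in for the epoch loop in train_loop():
    # the hooks are invoked the same way the patch does, with a global step
    # and, for checkpoints, a postfix string.
    global_step = 0
    for epoch in range(num_epochs):
        if epoch % sample_frequency == 0:
            on_sample(global_step)
        if epoch % checkpoint_frequency == 0 and epoch != 0:
            on_checkpoint(global_step, "training")
        global_step += 10  # pretend each epoch ran 10 optimizer steps
    on_checkpoint(global_step, "end")
    on_sample(global_step)


if __name__ == "__main__":
    # Any callables with matching signatures work, e.g. the bound methods
    # checkpointer.save_samples / checkpointer.checkpoint as in train_ti.py,
    # or plain functions/lambdas like these:
    run_epochs(
        num_epochs=3,
        on_sample=lambda step: print(f"sample at step {step}"),
        on_checkpoint=lambda step, postfix: print(f"checkpoint at step {step} ({postfix})"),
    )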