From fc11c86142915d6c3935d28a3321b3ae91b613ef Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 14 Jan 2023 22:03:01 +0100
Subject: Update

---
 training/functional.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

(limited to 'training/functional.py')

diff --git a/training/functional.py b/training/functional.py
index 2d81eca..c100ea2 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -17,7 +17,6 @@ from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embe
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import AverageMeter
-from trainer.base import Checkpointer
 
 
 def const(result=None):
@@ -205,7 +204,6 @@ def train_loop(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     model: torch.nn.Module,
-    checkpointer: Checkpointer,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
@@ -217,7 +215,9 @@ def train_loop(
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
     on_before_optimize: Callable[[int], None] = const(),
     on_after_optimize: Callable[[float], None] = const(),
-    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
+    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()),
+    on_sample: Callable[[int], None] = const(),
+    on_checkpoint: Callable[[int, str], None] = const(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
@@ -253,10 +253,10 @@
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
                 if epoch % sample_frequency == 0:
-                    checkpointer.save_samples(global_step + global_step_offset)
+                    on_sample(global_step + global_step_offset)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
-                    checkpointer.checkpoint(global_step + global_step_offset, "training")
+                    on_checkpoint(global_step + global_step_offset, "training")
 
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
@@ -347,19 +347,18 @@
                 if avg_acc_val.avg.item() > max_acc_val:
                     accelerator.print(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
+                    on_checkpoint(global_step + global_step_offset, "milestone")
                     max_acc_val = avg_acc_val.avg.item()
 
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished!")
-            checkpointer.checkpoint(global_step + global_step_offset, "end")
-            checkpointer.save_samples(global_step + global_step_offset)
+            on_checkpoint(global_step + global_step_offset, "end")
+            on_sample(global_step + global_step_offset)
             accelerator.end_training()
 
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted")
-            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            on_checkpoint(global_step + global_step_offset, "end")
             accelerator.end_training()
-        quit()
-- 
cgit v1.2.3-54-g00ecf
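
The change drops the Checkpointer dependency in favor of plain callbacks. Below is a minimal sketch of how a caller might supply the new on_sample and on_checkpoint parameters after this patch; sample_fn and checkpoint_fn are hypothetical placeholder names, and the remaining train_loop arguments (optimizer, dataloaders, etc.) are assumed to be set up elsewhere, as in the signature shown in the diff.

    def sample_fn(step: int) -> None:
        # Hypothetical hook: render/save preview images at the given global step.
        print(f"sampling at step {step}")

    def checkpoint_fn(step: int, postfix: str) -> None:
        # Hypothetical hook: save weights; train_loop passes postfix values
        # such as "training", "milestone" and "end".
        print(f"checkpoint at step {step} ({postfix})")

    train_loop(
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        loss_step=loss_step,
        on_sample=sample_fn,            # called every sample_frequency epochs and at the end
        on_checkpoint=checkpoint_fn,    # called on schedule, on new validation maxima, and on exit
        # remaining keyword arguments omitted; both hooks default to const(), a no-op placeholder
    )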