From 59bf501198d7ff6c0c03c45e92adef14069d5ac6 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 15 Jan 2023 12:33:52 +0100
Subject: Update

---
 training/functional.py | 100 ++++++++++++-------------------------------------
 1 file changed, 23 insertions(+), 77 deletions(-)

(limited to 'training/functional.py')

diff --git a/training/functional.py b/training/functional.py
index 4ca7470..c01595a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -33,6 +33,7 @@ def const(result=None):
 @dataclass
 class TrainingCallbacks():
     on_prepare: Callable[[float], None] = const()
+    on_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[int], None] = const()
@@ -267,6 +268,7 @@ def loss_step(
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
+    with_prior_preservation: bool,
     prior_loss_weight: float,
     seed: int,
     step: int,
@@ -322,7 +324,7 @@
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if batch["with_prior"].all():
+    if with_prior_preservation:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
@@ -347,7 +349,6 @@ def train_loop(
     accelerator: Accelerator,
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    model: torch.nn.Module,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
@@ -387,28 +388,37 @@
     )
     global_progress_bar.set_description("Total progress")
 
+    model = callbacks.on_model()
+    on_log = callbacks.on_log
+    on_train = callbacks.on_train
+    on_before_optimize = callbacks.on_before_optimize
+    on_after_optimize = callbacks.on_after_optimize
+    on_eval = callbacks.on_eval
+    on_sample = callbacks.on_sample
+    on_checkpoint = callbacks.on_checkpoint
+
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
                 if epoch % sample_frequency == 0:
-                    callbacks.on_sample(global_step + global_step_offset)
+                    on_sample(global_step + global_step_offset)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
-                    callbacks.on_checkpoint(global_step + global_step_offset, "training")
+                    on_checkpoint(global_step + global_step_offset, "training")
 
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
 
             model.train()
 
-            with callbacks.on_train(epoch):
+            with on_train(epoch):
                 for step, batch in enumerate(train_dataloader):
                     with accelerator.accumulate(model):
                         loss, acc, bsz = loss_step(step, batch)
 
                         accelerator.backward(loss)
 
-                        callbacks.on_before_optimize(epoch)
+                        on_before_optimize(epoch)
 
                         optimizer.step()
                         lr_scheduler.step()
@@ -419,7 +429,7 @@
 
                     # Checks if the accelerator has performed an optimization step behind the scenes
                     if accelerator.sync_gradients:
-                        callbacks.on_after_optimize(lr_scheduler.get_last_lr()[0])
+                        on_after_optimize(lr_scheduler.get_last_lr()[0])
 
                         local_progress_bar.update(1)
                         global_progress_bar.update(1)
@@ -433,7 +443,7 @@
                             "train/cur_acc": acc.item(),
                             "lr": lr_scheduler.get_last_lr()[0],
                         }
-                        logs.update(callbacks.on_log())
+                        logs.update(on_log())
 
                         accelerator.log(logs, step=global_step)
 
@@ -449,7 +459,7 @@
             cur_loss_val = AverageMeter()
             cur_acc_val = AverageMeter()
 
-            with torch.inference_mode(), callbacks.on_eval():
+            with torch.inference_mode(), on_eval():
                 for step, batch in enumerate(val_dataloader):
                     loss, acc, bsz = loss_step(step, batch, True)
 
@@ -485,80 +495,16 @@
                 if avg_acc_val.avg.item() > max_acc_val:
                     accelerator.print(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                    callbacks.on_checkpoint(global_step + global_step_offset, "milestone")
+                    on_checkpoint(global_step + global_step_offset, "milestone")
                     max_acc_val = avg_acc_val.avg.item()
 
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished!")
-            callbacks.on_checkpoint(global_step + global_step_offset, "end")
-            callbacks.on_sample(global_step + global_step_offset)
-            accelerator.end_training()
+            on_checkpoint(global_step + global_step_offset, "end")
+            on_sample(global_step + global_step_offset)
 
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted")
-            callbacks.on_checkpoint(global_step + global_step_offset, "end")
-            accelerator.end_training()
-
-
-def train(
-    accelerator: Accelerator,
-    unet: UNet2DConditionModel,
-    text_encoder: CLIPTextModel,
-    vae: AutoencoderKL,
-    noise_scheduler: DDPMScheduler,
-    train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
-    dtype: torch.dtype,
-    seed: int,
-    optimizer: torch.optim.Optimizer,
-    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    num_train_epochs: int = 100,
-    sample_frequency: int = 20,
-    checkpoint_frequency: int = 50,
-    global_step_offset: int = 0,
-    prior_loss_weight: float = 0,
-    callbacks: TrainingCallbacks = TrainingCallbacks(),
-):
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
-
-    vae.to(accelerator.device, dtype=dtype)
-
-    for model in (unet, text_encoder, vae):
-        model.requires_grad_(False)
-        model.eval()
-
-    callbacks.on_prepare()
-
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        prior_loss_weight,
-        seed,
-    )
-
-    if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion")
-
-    train_loop(
-        accelerator=accelerator,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-        model=text_encoder,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        loss_step=loss_step_,
-        sample_frequency=sample_frequency,
-        checkpoint_frequency=checkpoint_frequency,
-        global_step_offset=global_step_offset,
-        num_epochs=num_train_epochs,
-        callbacks=callbacks,
-    )
-
-    accelerator.free_memory()
--
cgit v1.2.3-54-g00ecf
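
Since this commit removes the train() wrapper, callers now prepare the models themselves, hand the module to be trained to train_loop through the new on_model callback, and tell loss_step explicitly whether prior preservation is active instead of it reading batch["with_prior"]. The following is a minimal sketch of how a caller might wire this up after the change; the surrounding objects (accelerator, unet, text_encoder, vae, noise_scheduler, optimizer, lr_scheduler, the dataloaders, seed) and the save_checkpoint helper are assumed to exist and are not part of this commit.

from functools import partial

from training.functional import TrainingCallbacks, loss_step, train_loop

# Assumed: unet, text_encoder, optimizer, the dataloaders and lr_scheduler
# were already wrapped with accelerator.prepare(), as the removed train()
# wrapper used to do.
callbacks = TrainingCallbacks(
    on_model=lambda: text_encoder,  # mirrors the old model=text_encoder argument to train_loop
    on_checkpoint=lambda step, kind: save_checkpoint(step, kind),  # save_checkpoint is hypothetical
)

loss_fn = partial(
    loss_step,
    vae,
    noise_scheduler,
    unet,
    text_encoder,
    True,  # with_prior_preservation, the new positional argument
    1.0,   # prior_loss_weight
    seed,
)

train_loop(
    accelerator=accelerator,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_step=loss_fn,
    sample_frequency=20,
    checkpoint_frequency=50,
    global_step_offset=0,
    num_epochs=100,
    callbacks=callbacks,
)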