From 89afcfda3f824cc44221e877182348f9b09687d2 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 16 Jan 2023 10:31:55 +0100
Subject: Handle empty validation dataset

---
 training/functional.py          | 101 ++++++++++++++++++++++------------
 training/strategy/dreambooth.py |   2 +-
 training/strategy/ti.py        |   2 +-
 3 files changed, 58 insertions(+), 47 deletions(-)

(limited to 'training')

diff --git a/training/functional.py b/training/functional.py
index 1b6162b..c6b4dc3 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -73,7 +73,7 @@ def save_samples(
     vae: AutoencoderKL,
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
     dtype: torch.dtype,
     output_dir: Path,
     seed: int,
@@ -111,11 +111,13 @@ def save_samples(
 
     generator = torch.Generator(device=accelerator.device).manual_seed(seed)
 
-    for pool, data, gen in [
-        ("stable", val_dataloader, generator),
-        ("val", val_dataloader, None),
-        ("train", train_dataloader, None)
-    ]:
+    datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)]
+
+    if val_dataloader is not None:
+        datasets.append(("stable", val_dataloader, generator))
+        datasets.append(("val", val_dataloader, None))
+
+    for pool, data, gen in datasets:
         all_samples = []
         file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
         file_path.parent.mkdir(parents=True, exist_ok=True)
@@ -328,7 +330,7 @@ def train_loop(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
@@ -337,7 +339,7 @@ def train_loop(
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
-    num_val_steps_per_epoch = len(val_dataloader)
+    num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
@@ -350,6 +352,7 @@ def train_loop(
     avg_loss_val = AverageMeter()
     avg_acc_val = AverageMeter()
 
+    max_acc = 0.0
     max_acc_val = 0.0
 
     local_progress_bar = tqdm(
@@ -432,49 +435,57 @@ def train_loop(
 
         accelerator.wait_for_everyone()
 
-        model.eval()
-
-        cur_loss_val = AverageMeter()
-        cur_acc_val = AverageMeter()
-
-        with torch.inference_mode(), on_eval():
-            for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loss_step(step, batch, True)
-
-                loss = loss.detach_()
-                acc = acc.detach_()
-
-                cur_loss_val.update(loss, bsz)
-                cur_acc_val.update(acc, bsz)
+        if val_dataloader is not None:
+            model.eval()
 
-                avg_loss_val.update(loss, bsz)
-                avg_acc_val.update(acc, bsz)
+            cur_loss_val = AverageMeter()
+            cur_acc_val = AverageMeter()
 
-                local_progress_bar.update(1)
-                global_progress_bar.update(1)
+            with torch.inference_mode(), on_eval():
+                for step, batch in enumerate(val_dataloader):
+                    loss, acc, bsz = loss_step(step, batch, True)
 
-                logs = {
-                    "val/loss": avg_loss_val.avg.item(),
-                    "val/acc": avg_acc_val.avg.item(),
-                    "val/cur_loss": loss.item(),
-                    "val/cur_acc": acc.item(),
-                }
-                local_progress_bar.set_postfix(**logs)
+                    loss = loss.detach_()
+                    acc = acc.detach_()
 
-            logs["val/cur_loss"] = cur_loss_val.avg.item()
-            logs["val/cur_acc"] = cur_acc_val.avg.item()
+                    cur_loss_val.update(loss, bsz)
 
-        accelerator.log(logs, step=global_step)
+                    cur_acc_val.update(acc, bsz)
 
-        local_progress_bar.clear()
-        global_progress_bar.clear()
+                    avg_loss_val.update(loss, bsz)
+                    avg_acc_val.update(acc, bsz)
 
-        if accelerator.is_main_process:
-            if avg_acc_val.avg.item() > max_acc_val:
-                accelerator.print(
-                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                on_checkpoint(global_step + global_step_offset, "milestone")
-                max_acc_val = avg_acc_val.avg.item()
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
 
+                    logs = {
+                        "val/loss": avg_loss_val.avg.item(),
+                        "val/acc": avg_acc_val.avg.item(),
+                        "val/cur_loss": loss.item(),
+                        "val/cur_acc": acc.item(),
+                    }
+                    local_progress_bar.set_postfix(**logs)
+
+                logs["val/cur_loss"] = cur_loss_val.avg.item()
+                logs["val/cur_acc"] = cur_acc_val.avg.item()
+
+            accelerator.log(logs, step=global_step)
+
+            local_progress_bar.clear()
+            global_progress_bar.clear()
+
+            if accelerator.is_main_process:
+                if avg_acc_val.avg.item() > max_acc_val:
+                    accelerator.print(
+                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                    on_checkpoint(global_step + global_step_offset, "milestone")
+                    max_acc_val = avg_acc_val.avg.item()
+        else:
+            if accelerator.is_main_process:
+                if avg_acc.avg.item() > max_acc:
+                    accelerator.print(
+                        f"Global step {global_step}: Training accuracy reached new maximum: {max_acc:.2e} -> {avg_acc.avg.item():.2e}")
+                    on_checkpoint(global_step + global_step_offset, "milestone")
+                    max_acc = avg_acc.avg.item()
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
@@ -499,7 +510,7 @@ def train(
     seed: int,
     project: str,
     train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
    optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     callbacks_fn: Callable[..., TrainingCallbacks],
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 6e7ebe2..aeaa828 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -26,7 +26,7 @@ def dreambooth_strategy(
     vae: AutoencoderKL,
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
     output_dir: Path,
     seed: int,
     train_text_encoder_epochs: int,
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 753dce0..568f9eb 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -26,7 +26,7 @@ def textual_inversion_strategy(
     vae: AutoencoderKL,
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
     output_dir: Path,
     seed: int,
     placeholder_tokens: list[str],
-- 
cgit v1.2.3-70-g09d2
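
The save_samples() change above does more than accept an Optional[DataLoader]: it also reorders the pools, so "train" is always rendered first, while the seeded "stable" pool and the unseeded "val" pool are appended only when validation data exists. A minimal standalone sketch of that selection logic (build_sample_pools is a hypothetical helper name used here for illustration, not part of this repository):

    from typing import Optional

    import torch
    from torch.utils.data import DataLoader, TensorDataset


    def build_sample_pools(
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader],
        generator: torch.Generator,
    ) -> list[tuple[str, DataLoader, Optional[torch.Generator]]]:
        # The "train" pool is always sampled; the seeded "stable" pool and
        # the unseeded "val" pool exist only when validation data was given.
        pools: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [
            ("train", train_dataloader, None),
        ]
        if val_dataloader is not None:
            pools.append(("stable", val_dataloader, generator))
            pools.append(("val", val_dataloader, None))
        return pools


    train_dl = DataLoader(TensorDataset(torch.zeros(4, 1)))
    gen = torch.Generator().manual_seed(0)
    assert [p[0] for p in build_sample_pools(train_dl, None, gen)] == ["train"]
    assert [p[0] for p in build_sample_pools(train_dl, train_dl, gen)] == ["train", "stable", "val"]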
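
Likewise, the train_loop() change keeps "milestone" checkpoints meaningful when no validation pass runs: a new max_acc tracks the best training accuracy and stands in for max_acc_val. A condensed sketch of that decision, assuming the epoch's meters have already been reduced to plain floats (milestone_reached is an illustrative name, not code from this repository):

    from typing import Optional


    def milestone_reached(
        avg_acc: float,
        avg_acc_val: Optional[float],
        max_acc: float,
        max_acc_val: float,
    ) -> tuple[float, float, bool]:
        # Prefer validation accuracy when an eval pass ran this epoch;
        # otherwise fall back to training accuracy, mirroring the new
        # else-branch in train_loop(). Returns the updated maxima and
        # whether a milestone checkpoint should be written.
        if avg_acc_val is not None:
            if avg_acc_val > max_acc_val:
                return max_acc, avg_acc_val, True
            return max_acc, max_acc_val, False
        if avg_acc > max_acc:
            return avg_acc, max_acc_val, True
        return max_acc, max_acc_val, False


    # Without validation data, a training-accuracy record triggers a checkpoint.
    assert milestone_reached(0.9, None, max_acc=0.8, max_acc_val=0.0) == (0.9, 0.0, True)
    # With validation data, only validation accuracy is considered.
    assert milestone_reached(0.9, 0.5, max_acc=0.8, max_acc_val=0.6) == (0.8, 0.6, False)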