From 6d46bf79bd7710cea799fbfe27c12d06d12cd53f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 27 Apr 2023 07:47:59 +0200 Subject: Update --- training/functional.py | 58 ++++++++++++++++++++++++++++---------------------- 1 file changed, 32 insertions(+), 26 deletions(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index 695a24f..3036ed9 100644 --- a/training/functional.py +++ b/training/functional.py @@ -461,6 +461,10 @@ def train_loop( num_epochs: int = 100, gradient_accumulation_steps: int = 1, group_labels: list[str] = [], + avg_loss: AverageMeter = AverageMeter(), + avg_acc: AverageMeter = AverageMeter(), + avg_loss_val: AverageMeter = AverageMeter(), + avg_acc_val: AverageMeter = AverageMeter(), callbacks: TrainingCallbacks = TrainingCallbacks(), ): num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) @@ -472,14 +476,8 @@ def train_loop( global_step = 0 cache = {} - avg_loss = AverageMeter() - avg_acc = AverageMeter() - - avg_loss_val = AverageMeter() - avg_acc_val = AverageMeter() - - best_acc = 0.0 - best_acc_val = 0.0 + best_acc = avg_acc.avg + best_acc_val = avg_acc_val.avg local_progress_bar = tqdm( range(num_training_steps_per_epoch + num_val_steps_per_epoch), @@ -544,12 +542,12 @@ def train_loop( accelerator.backward(loss) - avg_loss.update(loss.detach_(), bsz) - avg_acc.update(acc.detach_(), bsz) + avg_loss.update(loss.item(), bsz) + avg_acc.update(acc.item(), bsz) logs = { - "train/loss": avg_loss.avg.item(), - "train/acc": avg_acc.avg.item(), + "train/loss": avg_loss.avg, + "train/acc": avg_acc.avg, "train/cur_loss": loss.item(), "train/cur_acc": acc.item(), } @@ -603,47 +601,47 @@ def train_loop( loss = loss.detach_() acc = acc.detach_() - cur_loss_val.update(loss, bsz) - cur_acc_val.update(acc, bsz) + cur_loss_val.update(loss.item(), bsz) + cur_acc_val.update(acc.item(), bsz) - avg_loss_val.update(loss, bsz) - avg_acc_val.update(acc, bsz) + avg_loss_val.update(loss.item(), bsz) + avg_acc_val.update(acc.item(), bsz) local_progress_bar.update(1) global_progress_bar.update(1) logs = { - "val/loss": avg_loss_val.avg.item(), - "val/acc": avg_acc_val.avg.item(), + "val/loss": avg_loss_val.avg, + "val/acc": avg_acc_val.avg, "val/cur_loss": loss.item(), "val/cur_acc": acc.item(), } local_progress_bar.set_postfix(**logs) - logs["val/cur_loss"] = cur_loss_val.avg.item() - logs["val/cur_acc"] = cur_acc_val.avg.item() + logs["val/cur_loss"] = cur_loss_val.avg + logs["val/cur_acc"] = cur_acc_val.avg accelerator.log(logs, step=global_step) if accelerator.is_main_process: - if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints: + if avg_acc_val.avg > best_acc_val and milestone_checkpoints: local_progress_bar.clear() global_progress_bar.clear() accelerator.print( - f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") + f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}") on_checkpoint(global_step, "milestone") - best_acc_val = avg_acc_val.avg.item() + best_acc_val = avg_acc_val.avg else: if accelerator.is_main_process: - if avg_acc.avg.item() > best_acc and milestone_checkpoints: + if avg_acc.avg > best_acc and milestone_checkpoints: local_progress_bar.clear() global_progress_bar.clear() accelerator.print( - f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}") + f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}") on_checkpoint(global_step, "milestone") - best_acc = avg_acc.avg.item() + best_acc = avg_acc.avg # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: @@ -688,6 +686,10 @@ def train( offset_noise_strength: float = 0.15, disc: Optional[ConvNeXtDiscriminator] = None, min_snr_gamma: int = 5, + avg_loss: AverageMeter = AverageMeter(), + avg_acc: AverageMeter = AverageMeter(), + avg_loss_val: AverageMeter = AverageMeter(), + avg_acc_val: AverageMeter = AverageMeter(), **kwargs, ): text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare( @@ -737,6 +739,10 @@ def train( num_epochs=num_train_epochs, gradient_accumulation_steps=gradient_accumulation_steps, group_labels=group_labels, + avg_loss=avg_loss, + avg_acc=avg_acc, + avg_loss_val=avg_loss_val, + avg_acc_val=avg_acc_val, callbacks=callbacks, ) -- cgit v1.2.3-54-g00ecf