From 8e0e47217b7e18288eaa9462c6bbecf7387f3d89 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 28 Apr 2023 23:51:40 +0200
Subject: Support torch.compile

---
 training/functional.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

(limited to 'training')

diff --git a/training/functional.py b/training/functional.py
index 6ae35a0..e7cc20f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -624,24 +624,24 @@ def train_loop(
             accelerator.log(logs, step=global_step)

             if accelerator.is_main_process:
-                if avg_acc_val.avg > best_acc_val and milestone_checkpoints:
+                if avg_acc_val.max > best_acc_val and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()

                     accelerator.print(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}")
                     on_checkpoint(global_step, "milestone")

-                best_acc_val = avg_acc_val.avg
+                best_acc_val = avg_acc_val.max
         else:
             if accelerator.is_main_process:
-                if avg_acc.avg > best_acc and milestone_checkpoints:
+                if avg_acc.max > best_acc and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()

                     accelerator.print(
                         f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}")
                     on_checkpoint(global_step, "milestone")

-                best_acc = avg_acc.avg
+                best_acc = avg_acc.max

     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
@@ -699,6 +699,10 @@ def train(
     vae.requires_grad_(False)
     vae.eval()

+    unet = torch.compile(unet)
+    text_encoder = torch.compile(text_encoder)
+    vae = torch.compile(vae)
+
     callbacks = strategy.callbacks(
         accelerator=accelerator,
         unet=unet,
--
cgit v1.2.3-70-g09d2