path: root/training
author		Volpeon <git@volpeon.ink>	2023-04-28 23:51:40 +0200
committer	Volpeon <git@volpeon.ink>	2023-04-28 23:51:40 +0200
commit		8e0e47217b7e18288eaa9462c6bbecf7387f3d89 (patch)
tree		35f39eb57b55f0be6752e70110541e8c96351963 /training
parent		Fixed loss/acc logging (diff)
download	textual-inversion-diff-8e0e47217b7e18288eaa9462c6bbecf7387f3d89.tar.gz
		textual-inversion-diff-8e0e47217b7e18288eaa9462c6bbecf7387f3d89.tar.bz2
		textual-inversion-diff-8e0e47217b7e18288eaa9462c6bbecf7387f3d89.zip
Support torch.compile
Diffstat (limited to 'training')
-rw-r--r--	training/functional.py	12
1 file changed, 8 insertions(+), 4 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 6ae35a0..e7cc20f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -624,24 +624,24 @@ def train_loop(
             accelerator.log(logs, step=global_step)
 
             if accelerator.is_main_process:
-                if avg_acc_val.avg > best_acc_val and milestone_checkpoints:
+                if avg_acc_val.max > best_acc_val and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
                     accelerator.print(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}")
                     on_checkpoint(global_step, "milestone")
-                    best_acc_val = avg_acc_val.avg
+                    best_acc_val = avg_acc_val.max
         else:
             if accelerator.is_main_process:
-                if avg_acc.avg > best_acc and milestone_checkpoints:
+                if avg_acc.max > best_acc and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
                     accelerator.print(
                         f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}")
                     on_checkpoint(global_step, "milestone")
-                    best_acc = avg_acc.avg
+                    best_acc = avg_acc.max
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
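Note on the first hunk: the milestone check now compares the metric tracker's running maximum instead of its running average against the previous best. The tracker class itself is not part of this diff, so the following is only a minimal sketch of the assumed interface; the real meter in training/functional.py may be named and implemented differently, and only the .avg and .max fields matter here.

# Minimal sketch of the assumed accuracy tracker (hypothetical; not from this diff).
class AverageMeter:
    def __init__(self):
        self.count = 0
        self.sum = 0.0
        self.avg = 0.0
        self.max = float("-inf")

    def update(self, value: float, n: int = 1):
        # Running average over all updates, plus the largest single value seen.
        self.count += n
        self.sum += value * n
        self.avg = self.sum / self.count
        self.max = max(self.max, value)

With such a tracker, `avg_acc_val.max > best_acc_val` saves a milestone checkpoint as soon as any logged accuracy in the current window beats the previous best, rather than waiting for the running average to exceed it.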
@@ -699,6 +699,10 @@ def train(
     vae.requires_grad_(False)
     vae.eval()
 
+    unet = torch.compile(unet)
+    text_encoder = torch.compile(text_encoder)
+    vae = torch.compile(vae)
+
     callbacks = strategy.callbacks(
         accelerator=accelerator,
         unet=unet,
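Background on the second hunk: torch.compile (available since PyTorch 2.0) returns an optimized wrapper around an nn.Module; the first forward pass triggers graph capture and code generation, and later calls reuse the compiled graph. Below is a self-contained sketch with a stand-in model, since the real patch wraps the diffusers UNet, the text encoder, and the VAE, which are not reproduced here.

import torch
import torch.nn as nn

# Stand-in module for illustration only; in the patch the wrapped objects are
# unet, text_encoder and vae.
model = nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64))

# torch.compile returns an optimized callable wrapping the module.
model = torch.compile(model)

x = torch.randn(8, 64)
y = model(x)    # first call compiles, subsequent calls hit the cached graph
print(y.shape)  # torch.Size([8, 64])

One thing worth keeping in mind with this placement: the compiled wrappers are what get passed to strategy.callbacks(...), and a compiled module's state_dict keys typically gain an `_orig_mod.` prefix, which can matter when checkpoints are saved or loaded from the wrapped objects.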