diff options
author | Volpeon <git@volpeon.ink> | 2023-04-28 23:51:40 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-28 23:51:40 +0200 |
commit | 8e0e47217b7e18288eaa9462c6bbecf7387f3d89 (patch) | |
tree | 35f39eb57b55f0be6752e70110541e8c96351963 /training | |
parent | Fixed loss/acc logging (diff) | |
download | textual-inversion-diff-8e0e47217b7e18288eaa9462c6bbecf7387f3d89.tar.gz textual-inversion-diff-8e0e47217b7e18288eaa9462c6bbecf7387f3d89.tar.bz2 textual-inversion-diff-8e0e47217b7e18288eaa9462c6bbecf7387f3d89.zip |
Support torch.compile
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/training/functional.py b/training/functional.py index 6ae35a0..e7cc20f 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -624,24 +624,24 @@ def train_loop( | |||
624 | accelerator.log(logs, step=global_step) | 624 | accelerator.log(logs, step=global_step) |
625 | 625 | ||
626 | if accelerator.is_main_process: | 626 | if accelerator.is_main_process: |
627 | if avg_acc_val.avg > best_acc_val and milestone_checkpoints: | 627 | if avg_acc_val.max > best_acc_val and milestone_checkpoints: |
628 | local_progress_bar.clear() | 628 | local_progress_bar.clear() |
629 | global_progress_bar.clear() | 629 | global_progress_bar.clear() |
630 | 630 | ||
631 | accelerator.print( | 631 | accelerator.print( |
632 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}") | 632 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}") |
633 | on_checkpoint(global_step, "milestone") | 633 | on_checkpoint(global_step, "milestone") |
634 | best_acc_val = avg_acc_val.avg | 634 | best_acc_val = avg_acc_val.max |
635 | else: | 635 | else: |
636 | if accelerator.is_main_process: | 636 | if accelerator.is_main_process: |
637 | if avg_acc.avg > best_acc and milestone_checkpoints: | 637 | if avg_acc.max > best_acc and milestone_checkpoints: |
638 | local_progress_bar.clear() | 638 | local_progress_bar.clear() |
639 | global_progress_bar.clear() | 639 | global_progress_bar.clear() |
640 | 640 | ||
641 | accelerator.print( | 641 | accelerator.print( |
642 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}") | 642 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}") |
643 | on_checkpoint(global_step, "milestone") | 643 | on_checkpoint(global_step, "milestone") |
644 | best_acc = avg_acc.avg | 644 | best_acc = avg_acc.max |
645 | 645 | ||
646 | # Create the pipeline using using the trained modules and save it. | 646 | # Create the pipeline using using the trained modules and save it. |
647 | if accelerator.is_main_process: | 647 | if accelerator.is_main_process: |
@@ -699,6 +699,10 @@ def train( | |||
699 | vae.requires_grad_(False) | 699 | vae.requires_grad_(False) |
700 | vae.eval() | 700 | vae.eval() |
701 | 701 | ||
702 | unet = torch.compile(unet) | ||
703 | text_encoder = torch.compile(text_encoder) | ||
704 | vae = torch.compile(vae) | ||
705 | |||
702 | callbacks = strategy.callbacks( | 706 | callbacks = strategy.callbacks( |
703 | accelerator=accelerator, | 707 | accelerator=accelerator, |
704 | unet=unet, | 708 | unet=unet, |