diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/training/functional.py b/training/functional.py index 6ae35a0..e7cc20f 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -624,24 +624,24 @@ def train_loop( | |||
| 624 | accelerator.log(logs, step=global_step) | 624 | accelerator.log(logs, step=global_step) |
| 625 | 625 | ||
| 626 | if accelerator.is_main_process: | 626 | if accelerator.is_main_process: |
| 627 | if avg_acc_val.avg > best_acc_val and milestone_checkpoints: | 627 | if avg_acc_val.max > best_acc_val and milestone_checkpoints: |
| 628 | local_progress_bar.clear() | 628 | local_progress_bar.clear() |
| 629 | global_progress_bar.clear() | 629 | global_progress_bar.clear() |
| 630 | 630 | ||
| 631 | accelerator.print( | 631 | accelerator.print( |
| 632 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}") | 632 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}") |
| 633 | on_checkpoint(global_step, "milestone") | 633 | on_checkpoint(global_step, "milestone") |
| 634 | best_acc_val = avg_acc_val.avg | 634 | best_acc_val = avg_acc_val.max |
| 635 | else: | 635 | else: |
| 636 | if accelerator.is_main_process: | 636 | if accelerator.is_main_process: |
| 637 | if avg_acc.avg > best_acc and milestone_checkpoints: | 637 | if avg_acc.max > best_acc and milestone_checkpoints: |
| 638 | local_progress_bar.clear() | 638 | local_progress_bar.clear() |
| 639 | global_progress_bar.clear() | 639 | global_progress_bar.clear() |
| 640 | 640 | ||
| 641 | accelerator.print( | 641 | accelerator.print( |
| 642 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}") | 642 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}") |
| 643 | on_checkpoint(global_step, "milestone") | 643 | on_checkpoint(global_step, "milestone") |
| 644 | best_acc = avg_acc.avg | 644 | best_acc = avg_acc.max |
| 645 | 645 | ||
| 646 | # Create the pipeline using using the trained modules and save it. | 646 | # Create the pipeline using using the trained modules and save it. |
| 647 | if accelerator.is_main_process: | 647 | if accelerator.is_main_process: |
| @@ -699,6 +699,10 @@ def train( | |||
| 699 | vae.requires_grad_(False) | 699 | vae.requires_grad_(False) |
| 700 | vae.eval() | 700 | vae.eval() |
| 701 | 701 | ||
| 702 | unet = torch.compile(unet) | ||
| 703 | text_encoder = torch.compile(text_encoder) | ||
| 704 | vae = torch.compile(vae) | ||
| 705 | |||
| 702 | callbacks = strategy.callbacks( | 706 | callbacks = strategy.callbacks( |
| 703 | accelerator=accelerator, | 707 | accelerator=accelerator, |
| 704 | unet=unet, | 708 | unet=unet, |
