Diffstat (limited to 'training/functional.py')
-rw-r--r--   training/functional.py   35
1 file changed, 25 insertions, 10 deletions
diff --git a/training/functional.py b/training/functional.py
index fb135c4..c373ac9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -7,7 +7,6 @@ from pathlib import Path
 import itertools
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 
@@ -373,8 +372,12 @@ def train_loop(
     avg_loss_val = AverageMeter()
     avg_acc_val = AverageMeter()
 
-    max_acc = 0.0
-    max_acc_val = 0.0
+    best_acc = 0.0
+    best_acc_val = 0.0
+
+    lrs = []
+    losses = []
+    accs = []
 
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
@@ -457,6 +460,8 @@ def train_loop(
 
             accelerator.wait_for_everyone()
 
+            lrs.append(lr_scheduler.get_last_lr()[0])
+
             on_after_epoch(lr_scheduler.get_last_lr()[0])
 
             if val_dataloader is not None:
@@ -498,18 +503,24 @@ def train_loop(
                     global_progress_bar.clear()
 
                 if accelerator.is_main_process:
-                    if avg_acc_val.avg.item() > max_acc_val:
+                    if avg_acc_val.avg.item() > best_acc_val:
                         accelerator.print(
-                            f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                            f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                         on_checkpoint(global_step + global_step_offset, "milestone")
-                        max_acc_val = avg_acc_val.avg.item()
+                        best_acc_val = avg_acc_val.avg.item()
+
+                    losses.append(avg_loss_val.avg.item())
+                    accs.append(avg_acc_val.avg.item())
             else:
                 if accelerator.is_main_process:
-                    if avg_acc.avg.item() > max_acc:
+                    if avg_acc.avg.item() > best_acc:
                         accelerator.print(
-                            f"Global step {global_step}: Training accuracy reached new maximum: {max_acc:.2e} -> {avg_acc.avg.item():.2e}")
+                            f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
                         on_checkpoint(global_step + global_step_offset, "milestone")
-                        max_acc = avg_acc.avg.item()
+                        best_acc = avg_acc.avg.item()
+
+                    losses.append(avg_loss.avg.item())
+                    accs.append(avg_acc.avg.item())
 
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
@@ -523,6 +534,8 @@ def train_loop(
             on_checkpoint(global_step + global_step_offset, "end")
         raise KeyboardInterrupt
 
+    return lrs, losses, accs
+
 
 def train(
     accelerator: Accelerator,
@@ -582,7 +595,7 @@ def train(
     if accelerator.is_main_process:
         accelerator.init_trackers(project)
 
-    train_loop(
+    metrics = train_loop(
        accelerator=accelerator,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
@@ -598,3 +611,5 @@ def train(
 
     accelerator.end_training()
     accelerator.free_memory()
+
+    return metrics
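
With this change, train_loop() records one learning-rate, loss, and accuracy value per epoch and returns them as three lists, and train() passes that tuple through to its caller. A minimal sketch of how a caller might consume the returned metrics; the helper name save_metric_plots, the matplotlib dependency, and the output filename are illustrative assumptions, not part of this commit:

import matplotlib.pyplot as plt


def save_metric_plots(lrs: list[float], losses: list[float], accs: list[float], path: str) -> None:
    """Plot the per-epoch metrics returned by train()/train_loop()."""
    fig, (ax_lr, ax_metrics) = plt.subplots(1, 2, figsize=(10, 4))

    ax_lr.plot(lrs)
    ax_lr.set_xlabel("epoch")
    ax_lr.set_title("learning rate")

    ax_metrics.plot(losses, label="loss")
    ax_metrics.plot(accs, label="accuracy")
    ax_metrics.set_xlabel("epoch")
    ax_metrics.legend()

    fig.savefig(path)
    plt.close(fig)


# Usage sketch: metrics = train(...); save_metric_plots(*metrics, "metrics.png")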
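
The loop reads metrics through avg_loss.avg.item() and avg_acc_val.avg.item(), so AverageMeter is assumed to expose a running average as a tensor-valued .avg attribute. The class itself is not shown in this diff; the following is only a sketch of such a meter under that assumption, not the repository's actual implementation:

import torch


class AverageMeter:
    """Running average of a scalar tensor; .avg mirrors the interface used above (assumed)."""

    def __init__(self) -> None:
        self.sum = torch.tensor(0.0)
        self.count = 0
        self.avg = torch.tensor(0.0)

    def update(self, value: torch.Tensor, n: int = 1) -> None:
        # Accumulate a detached copy so the meter never holds onto the graph.
        self.sum = self.sum + value.detach() * n
        self.count += n
        self.avg = self.sum / self.count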
