diff options
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 35 |
1 file changed, 25 insertions, 10 deletions
diff --git a/training/functional.py b/training/functional.py index fb135c4..c373ac9 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -7,7 +7,6 @@ from pathlib import Path | |||
7 | import itertools | 7 | import itertools |
8 | 8 | ||
9 | import torch | 9 | import torch |
10 | import torch.nn as nn | ||
11 | import torch.nn.functional as F | 10 | import torch.nn.functional as F |
12 | from torch.utils.data import DataLoader | 11 | from torch.utils.data import DataLoader |
13 | 12 | ||
@@ -373,8 +372,12 @@ def train_loop( | |||
373 | avg_loss_val = AverageMeter() | 372 | avg_loss_val = AverageMeter() |
374 | avg_acc_val = AverageMeter() | 373 | avg_acc_val = AverageMeter() |
375 | 374 | ||
376 | max_acc = 0.0 | 375 | best_acc = 0.0 |
377 | max_acc_val = 0.0 | 376 | best_acc_val = 0.0 |
377 | |||
378 | lrs = [] | ||
379 | losses = [] | ||
380 | accs = [] | ||
378 | 381 | ||
379 | local_progress_bar = tqdm( | 382 | local_progress_bar = tqdm( |
380 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), | 383 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), |
@@ -457,6 +460,8 @@ def train_loop( | |||
457 | 460 | ||
458 | accelerator.wait_for_everyone() | 461 | accelerator.wait_for_everyone() |
459 | 462 | ||
463 | lrs.append(lr_scheduler.get_last_lr()[0]) | ||
464 | |||
460 | on_after_epoch(lr_scheduler.get_last_lr()[0]) | 465 | on_after_epoch(lr_scheduler.get_last_lr()[0]) |
461 | 466 | ||
462 | if val_dataloader is not None: | 467 | if val_dataloader is not None: |
@@ -498,18 +503,24 @@ def train_loop( | |||
498 | global_progress_bar.clear() | 503 | global_progress_bar.clear() |
499 | 504 | ||
500 | if accelerator.is_main_process: | 505 | if accelerator.is_main_process: |
501 | if avg_acc_val.avg.item() > max_acc_val: | 506 | if avg_acc_val.avg.item() > best_acc_val: |
502 | accelerator.print( | 507 | accelerator.print( |
503 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 508 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") |
504 | on_checkpoint(global_step + global_step_offset, "milestone") | 509 | on_checkpoint(global_step + global_step_offset, "milestone") |
505 | max_acc_val = avg_acc_val.avg.item() | 510 | best_acc_val = avg_acc_val.avg.item() |
511 | |||
512 | losses.append(avg_loss_val.avg.item()) | ||
513 | accs.append(avg_acc_val.avg.item()) | ||
506 | else: | 514 | else: |
507 | if accelerator.is_main_process: | 515 | if accelerator.is_main_process: |
508 | if avg_acc.avg.item() > max_acc: | 516 | if avg_acc.avg.item() > best_acc: |
509 | accelerator.print( | 517 | accelerator.print( |
510 | f"Global step {global_step}: Training accuracy reached new maximum: {max_acc:.2e} -> {avg_acc.avg.item():.2e}") | 518 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}") |
511 | on_checkpoint(global_step + global_step_offset, "milestone") | 519 | on_checkpoint(global_step + global_step_offset, "milestone") |
512 | max_acc = avg_acc.avg.item() | 520 | best_acc = avg_acc.avg.item() |
521 | |||
522 | losses.append(avg_loss.avg.item()) | ||
523 | accs.append(avg_acc.avg.item()) | ||
513 | 524 | ||
514 | # Create the pipeline using using the trained modules and save it. | 525 | # Create the pipeline using using the trained modules and save it. |
515 | if accelerator.is_main_process: | 526 | if accelerator.is_main_process: |
@@ -523,6 +534,8 @@ def train_loop( | |||
523 | on_checkpoint(global_step + global_step_offset, "end") | 534 | on_checkpoint(global_step + global_step_offset, "end") |
524 | raise KeyboardInterrupt | 535 | raise KeyboardInterrupt |
525 | 536 | ||
537 | return lrs, losses, accs | ||
538 | |||
526 | 539 | ||
527 | def train( | 540 | def train( |
528 | accelerator: Accelerator, | 541 | accelerator: Accelerator, |
@@ -582,7 +595,7 @@ def train( | |||
582 | if accelerator.is_main_process: | 595 | if accelerator.is_main_process: |
583 | accelerator.init_trackers(project) | 596 | accelerator.init_trackers(project) |
584 | 597 | ||
585 | train_loop( | 598 | metrics = train_loop( |
586 | accelerator=accelerator, | 599 | accelerator=accelerator, |
587 | optimizer=optimizer, | 600 | optimizer=optimizer, |
588 | lr_scheduler=lr_scheduler, | 601 | lr_scheduler=lr_scheduler, |
@@ -598,3 +611,5 @@ def train( | |||
598 | 611 | ||
599 | accelerator.end_training() | 612 | accelerator.end_training() |
600 | accelerator.free_memory() | 613 | accelerator.free_memory() |
614 | |||
615 | return metrics | ||