Diffstat (limited to 'training/functional.py')
-rw-r--r--  training/functional.py  35
1 file changed, 25 insertions(+), 10 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index fb135c4..c373ac9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -7,7 +7,6 @@ from pathlib import Path
 import itertools
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 
@@ -373,8 +372,12 @@ def train_loop(
     avg_loss_val = AverageMeter()
     avg_acc_val = AverageMeter()
 
-    max_acc = 0.0
-    max_acc_val = 0.0
+    best_acc = 0.0
+    best_acc_val = 0.0
+
+    lrs = []
+    losses = []
+    accs = []
 
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
@@ -457,6 +460,8 @@ def train_loop(
 
             accelerator.wait_for_everyone()
 
+            lrs.append(lr_scheduler.get_last_lr()[0])
+
             on_after_epoch(lr_scheduler.get_last_lr()[0])
 
             if val_dataloader is not None:
@@ -498,18 +503,24 @@ def train_loop(
                 global_progress_bar.clear()
 
                 if accelerator.is_main_process:
-                    if avg_acc_val.avg.item() > max_acc_val:
+                    if avg_acc_val.avg.item() > best_acc_val:
                         accelerator.print(
-                            f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                            f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                         on_checkpoint(global_step + global_step_offset, "milestone")
-                        max_acc_val = avg_acc_val.avg.item()
+                        best_acc_val = avg_acc_val.avg.item()
+
+                losses.append(avg_loss_val.avg.item())
+                accs.append(avg_acc_val.avg.item())
             else:
                 if accelerator.is_main_process:
-                    if avg_acc.avg.item() > max_acc:
+                    if avg_acc.avg.item() > best_acc:
                         accelerator.print(
-                            f"Global step {global_step}: Training accuracy reached new maximum: {max_acc:.2e} -> {avg_acc.avg.item():.2e}")
+                            f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
                         on_checkpoint(global_step + global_step_offset, "milestone")
-                        max_acc = avg_acc.avg.item()
+                        best_acc = avg_acc.avg.item()
+
+                losses.append(avg_loss.avg.item())
+                accs.append(avg_acc.avg.item())
 
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
@@ -523,6 +534,8 @@ def train_loop(
             on_checkpoint(global_step + global_step_offset, "end")
         raise KeyboardInterrupt
 
+    return lrs, losses, accs
+
 
 def train(
     accelerator: Accelerator,
@@ -582,7 +595,7 @@ def train(
     if accelerator.is_main_process:
         accelerator.init_trackers(project)
 
-    train_loop(
+    metrics = train_loop(
         accelerator=accelerator,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
@@ -598,3 +611,5 @@ def train(
 
     accelerator.end_training()
     accelerator.free_memory()
+
+    return metrics
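
Usage note: with this change, train_loop() records the learning rate once per epoch and appends a per-epoch loss and accuracy (validation values when a val_dataloader is given, training values otherwise), and train() now passes that tuple through to its caller. The sketch below shows one way a caller might plot the returned metrics. It is illustrative only: matplotlib is assumed to be available, and stand-in values replace a real train() call, since the full argument list of train() is not shown on this page.

    # Hedged sketch: consuming the (lrs, losses, accs) tuple returned by train().
    # In real use this would be: lrs, losses, accs = train(accelerator=..., ...)
    import matplotlib.pyplot as plt

    # Stand-in values, one entry per epoch, in place of a real training run.
    lrs, losses, accs = [1e-4, 8e-5, 5e-5], [0.31, 0.24, 0.21], [0.72, 0.79, 0.83]

    # Three side-by-side panels: learning rate, loss, accuracy over epochs.
    fig, (ax_lr, ax_loss, ax_acc) = plt.subplots(1, 3, figsize=(12, 3))
    ax_lr.plot(lrs)
    ax_lr.set_title("learning rate")
    ax_loss.plot(losses)
    ax_loss.set_title("loss (val if available)")
    ax_acc.plot(accs)
    ax_acc.set_title("accuracy (val if available)")
    fig.savefig("metrics.png")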