summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py18
1 files changed, 13 insertions, 5 deletions
diff --git a/training/functional.py b/training/functional.py
index ac43847..7104a88 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -484,12 +484,16 @@ def train_loop(
484 avg_loss.update(loss.detach_(), bsz) 484 avg_loss.update(loss.detach_(), bsz)
485 avg_acc.update(acc.detach_(), bsz) 485 avg_acc.update(acc.detach_(), bsz)
486 486
487 lr = lr_scheduler.get_last_lr()[0]
488 if torch.is_tensor(lr):
489 lr = lr.item()
490
487 logs = { 491 logs = {
488 "train/loss": avg_loss.avg.item(), 492 "train/loss": avg_loss.avg.item(),
489 "train/acc": avg_acc.avg.item(), 493 "train/acc": avg_acc.avg.item(),
490 "train/cur_loss": loss.item(), 494 "train/cur_loss": loss.item(),
491 "train/cur_acc": acc.item(), 495 "train/cur_acc": acc.item(),
492 "lr": lr_scheduler.get_last_lr()[0], 496 "lr": lr,
493 } 497 }
494 if isDadaptation: 498 if isDadaptation:
495 logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] 499 logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
@@ -498,13 +502,13 @@ def train_loop(
498 local_progress_bar.set_postfix(**logs) 502 local_progress_bar.set_postfix(**logs)
499 503
500 if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): 504 if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
501 before_optimize_result = on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) 505 before_optimize_result = on_before_optimize(lr, epoch)
502 506
503 optimizer.step() 507 optimizer.step()
504 lr_scheduler.step() 508 lr_scheduler.step()
505 optimizer.zero_grad(set_to_none=True) 509 optimizer.zero_grad(set_to_none=True)
506 510
507 on_after_optimize(before_optimize_result, lr_scheduler.get_last_lr()[0]) 511 on_after_optimize(before_optimize_result, lr)
508 512
509 local_progress_bar.update(1) 513 local_progress_bar.update(1)
510 global_progress_bar.update(1) 514 global_progress_bar.update(1)
@@ -518,9 +522,13 @@ def train_loop(
518 522
519 accelerator.wait_for_everyone() 523 accelerator.wait_for_everyone()
520 524
521 lrs.append(lr_scheduler.get_last_lr()[0]) 525 lr = lr_scheduler.get_last_lr()[0]
526 if torch.is_tensor(lr):
527 lr = lr.item
528
529 lrs.append(lr)
522 530
523 on_after_epoch(lr_scheduler.get_last_lr()[0]) 531 on_after_epoch(lr)
524 532
525 if val_dataloader is not None: 533 if val_dataloader is not None:
526 model.eval() 534 model.eval()