diff options
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/training/functional.py b/training/functional.py index ac43847..7104a88 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -484,12 +484,16 @@ def train_loop( | |||
484 | avg_loss.update(loss.detach_(), bsz) | 484 | avg_loss.update(loss.detach_(), bsz) |
485 | avg_acc.update(acc.detach_(), bsz) | 485 | avg_acc.update(acc.detach_(), bsz) |
486 | 486 | ||
487 | lr = lr_scheduler.get_last_lr()[0] | ||
488 | if torch.is_tensor(lr): | ||
489 | lr = lr.item() | ||
490 | |||
487 | logs = { | 491 | logs = { |
488 | "train/loss": avg_loss.avg.item(), | 492 | "train/loss": avg_loss.avg.item(), |
489 | "train/acc": avg_acc.avg.item(), | 493 | "train/acc": avg_acc.avg.item(), |
490 | "train/cur_loss": loss.item(), | 494 | "train/cur_loss": loss.item(), |
491 | "train/cur_acc": acc.item(), | 495 | "train/cur_acc": acc.item(), |
492 | "lr": lr_scheduler.get_last_lr()[0], | 496 | "lr": lr, |
493 | } | 497 | } |
494 | if isDadaptation: | 498 | if isDadaptation: |
495 | logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] | 499 | logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] |
@@ -498,13 +502,13 @@ def train_loop( | |||
498 | local_progress_bar.set_postfix(**logs) | 502 | local_progress_bar.set_postfix(**logs) |
499 | 503 | ||
500 | if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): | 504 | if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): |
501 | before_optimize_result = on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) | 505 | before_optimize_result = on_before_optimize(lr, epoch) |
502 | 506 | ||
503 | optimizer.step() | 507 | optimizer.step() |
504 | lr_scheduler.step() | 508 | lr_scheduler.step() |
505 | optimizer.zero_grad(set_to_none=True) | 509 | optimizer.zero_grad(set_to_none=True) |
506 | 510 | ||
507 | on_after_optimize(before_optimize_result, lr_scheduler.get_last_lr()[0]) | 511 | on_after_optimize(before_optimize_result, lr) |
508 | 512 | ||
509 | local_progress_bar.update(1) | 513 | local_progress_bar.update(1) |
510 | global_progress_bar.update(1) | 514 | global_progress_bar.update(1) |
@@ -518,9 +522,13 @@ def train_loop( | |||
518 | 522 | ||
519 | accelerator.wait_for_everyone() | 523 | accelerator.wait_for_everyone() |
520 | 524 | ||
521 | lrs.append(lr_scheduler.get_last_lr()[0]) | 525 | lr = lr_scheduler.get_last_lr()[0] |
526 | if torch.is_tensor(lr): | ||
527 | lr = lr.item() | ||
528 | |||
529 | lrs.append(lr) | ||
522 | 530 | ||
523 | on_after_epoch(lr_scheduler.get_last_lr()[0]) | 531 | on_after_epoch(lr) |
524 | 532 | ||
525 | if val_dataloader is not None: | 533 | if val_dataloader is not None: |
526 | model.eval() | 534 | model.eval() |