Diffstat (limited to 'training/functional.py')
 -rw-r--r--  training/functional.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index ac43847..7104a88 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -484,12 +484,16 @@ def train_loop(
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
+                lr = lr_scheduler.get_last_lr()[0]
+                if torch.is_tensor(lr):
+                    lr = lr.item()
+
                 logs = {
                     "train/loss": avg_loss.avg.item(),
                     "train/acc": avg_acc.avg.item(),
                     "train/cur_loss": loss.item(),
                     "train/cur_acc": acc.item(),
-                    "lr": lr_scheduler.get_last_lr()[0],
+                    "lr": lr,
                 }
                 if isDadaptation:
                     logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
@@ -498,13 +502,13 @@ def train_loop(
                 local_progress_bar.set_postfix(**logs)
 
                 if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
-                    before_optimize_result = on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
+                    before_optimize_result = on_before_optimize(lr, epoch)
 
                     optimizer.step()
                     lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)
 
-                    on_after_optimize(before_optimize_result, lr_scheduler.get_last_lr()[0])
+                    on_after_optimize(before_optimize_result, lr)
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
@@ -518,9 +522,13 @@ def train_loop(
 
         accelerator.wait_for_everyone()
 
-        lrs.append(lr_scheduler.get_last_lr()[0])
+        lr = lr_scheduler.get_last_lr()[0]
+        if torch.is_tensor(lr):
+            lr = lr.item()
+
+        lrs.append(lr)
 
-        on_after_epoch(lr_scheduler.get_last_lr()[0])
+        on_after_epoch(lr)
 
         if val_dataloader is not None:
            model.eval()
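
The pattern repeated in all three hunks (read get_last_lr()[0] once, unwrap a tensor value with .item(), then reuse the resulting float) could be factored into a small helper. Below is a minimal sketch assuming a standard PyTorch scheduler; the helper name scalar_lr is hypothetical and not part of this repository:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

def scalar_lr(lr_scheduler) -> float:
    # Hypothetical helper mirroring the commit's pattern: some optimizer/
    # scheduler combinations keep the learning rate as a 0-dim tensor, so
    # get_last_lr() can return tensors that trip up loggers and JSON dumps.
    lr = lr_scheduler.get_last_lr()[0]
    if torch.is_tensor(lr):
        lr = lr.item()  # unwrap the 0-dim tensor into a plain float
    return lr

# Usage with a toy optimizer/scheduler pair:
model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=10)
print(scalar_lr(scheduler))  # 0.001, a Python float in every case

Converting once and passing the float around also avoids calling get_last_lr() repeatedly, so progress bars, metric loggers, and callbacks such as on_before_optimize and on_after_optimize all see the same plain number.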
