Diffstat (limited to 'training')
-rw-r--r--  training/functional.py    18
-rw-r--r--  training/optimization.py   7
2 files changed, 18 insertions, 7 deletions
diff --git a/training/functional.py b/training/functional.py
index ac43847..7104a88 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -484,12 +484,16 @@ def train_loop(
             avg_loss.update(loss.detach_(), bsz)
             avg_acc.update(acc.detach_(), bsz)

+            lr = lr_scheduler.get_last_lr()[0]
+            if torch.is_tensor(lr):
+                lr = lr.item()
+
             logs = {
                 "train/loss": avg_loss.avg.item(),
                 "train/acc": avg_acc.avg.item(),
                 "train/cur_loss": loss.item(),
                 "train/cur_acc": acc.item(),
-                "lr": lr_scheduler.get_last_lr()[0],
+                "lr": lr,
             }
             if isDadaptation:
                 logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
@@ -498,13 +502,13 @@ def train_loop(
             local_progress_bar.set_postfix(**logs)

             if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
-                before_optimize_result = on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
+                before_optimize_result = on_before_optimize(lr, epoch)

                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

-                on_after_optimize(before_optimize_result, lr_scheduler.get_last_lr()[0])
+                on_after_optimize(before_optimize_result, lr)

                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
@@ -518,9 +522,13 @@ def train_loop(

         accelerator.wait_for_everyone()

-        lrs.append(lr_scheduler.get_last_lr()[0])
+        lr = lr_scheduler.get_last_lr()[0]
+        if torch.is_tensor(lr):
+            lr = lr.item()
+
+        lrs.append(lr)

-        on_after_epoch(lr_scheduler.get_last_lr()[0])
+        on_after_epoch(lr)

         if val_dataloader is not None:
             model.eval()
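For reference, the learning-rate handling added to train_loop boils down to the small pattern sketched below. The scalar_lr helper name is illustrative and not part of the repository, and the motivation (schedulers such as transformers' AdafactorSchedule reporting tensor-valued rates) is an assumption inferred from the adafactor change in optimization.py.

import torch

def scalar_lr(lr_scheduler) -> float:
    # Read the current learning rate from the scheduler and coerce it to a
    # plain float; some schedulers can report it as a tensor, which does not
    # log or format cleanly.
    lr = lr_scheduler.get_last_lr()[0]
    if torch.is_tensor(lr):
        lr = lr.item()
    return lr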
diff --git a/training/optimization.py b/training/optimization.py
index 53d0a6d..d22a900 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,7 +6,7 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR

 from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
-import transformers
+from transformers.optimization import get_adafactor_schedule


 class OneCyclePhase(NamedTuple):
@@ -150,7 +150,10 @@ def get_scheduler(
             num_cycles=cycles,
         )
     elif id == "adafactor":
-        lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr)
+        lr_scheduler = get_adafactor_schedule(
+            optimizer,
+            initial_lr=min_lr
+        )
     else:
         lr_scheduler = get_scheduler_(
             id,
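A minimal usage sketch of the new adafactor branch, assuming a stand-in module and a placeholder min_lr value (neither comes from this repository):

import torch
from transformers.optimization import Adafactor, get_adafactor_schedule

model = torch.nn.Linear(8, 8)  # stand-in module for illustration
optimizer = Adafactor(
    model.parameters(),
    scale_parameter=True,
    relative_step=True,  # Adafactor derives its own step size
    warmup_init=True,
    lr=None,
)

min_lr = 1e-4  # placeholder; get_scheduler() passes its own min_lr here
lr_scheduler = get_adafactor_schedule(optimizer, initial_lr=min_lr)

# The values reported by get_last_lr() can be tensors once training starts,
# which is why train_loop now converts them with .item() before logging.
print(lr_scheduler.get_last_lr())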