-rw-r--r--   train_dreambooth.py       |  1
-rw-r--r--   train_lora.py             |  1
-rw-r--r--   train_ti.py               |  1
-rw-r--r--   training/functional.py    | 18
-rw-r--r--   training/optimization.py  |  7
5 files changed, 18 insertions, 10 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 4456bd1..48b7926 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -539,7 +539,6 @@ def main():
     elif args.optimizer == 'adafactor':
         create_optimizer = partial(
             transformers.optimization.Adafactor,
-            beta1=args.adam_beta1,
             weight_decay=args.adam_weight_decay,
             scale_parameter=True,
             relative_step=True,
diff --git a/train_lora.py b/train_lora.py
index f8dccae..8fc2d69 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -571,7 +571,6 @@ def main():
     elif args.optimizer == 'adafactor':
         create_optimizer = partial(
             transformers.optimization.Adafactor,
-            beta1=args.adam_beta1,
             weight_decay=args.adam_weight_decay,
             scale_parameter=True,
             relative_step=True,
diff --git a/train_ti.py b/train_ti.py
index 274a1ca..5482326 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -669,7 +669,6 @@ def main():
    elif args.optimizer == 'adafactor':
        create_optimizer = partial(
            transformers.optimization.Adafactor,
-           beta1=args.adam_beta1,
            weight_decay=args.adam_weight_decay,
            scale_parameter=True,
            relative_step=True,
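
All three training scripts receive the same change: the Adafactor optimizer is now created without the beta1 argument, so it falls back to Adafactor's default of beta1=None and skips the first-moment (momentum) accumulator instead of receiving args.adam_beta1. A minimal, self-contained sketch of the resulting setup; the concrete weight-decay value and the tiny Linear model are stand-ins for the scripts' args and real models:

from functools import partial

import torch
import transformers

# beta1 is left at its default (None), so Adafactor keeps no momentum buffer.
create_optimizer = partial(
    transformers.optimization.Adafactor,
    weight_decay=1e-2,      # stand-in for args.adam_weight_decay
    scale_parameter=True,
    relative_step=True,     # the LR is computed internally; no explicit lr is passed
)

model = torch.nn.Linear(4, 4)   # toy model, only to make the sketch runnable
optimizer = create_optimizer(model.parameters())
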
diff --git a/training/functional.py b/training/functional.py
index ac43847..7104a88 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -484,12 +484,16 @@ def train_loop(
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
+                lr = lr_scheduler.get_last_lr()[0]
+                if torch.is_tensor(lr):
+                    lr = lr.item()
+
                 logs = {
                     "train/loss": avg_loss.avg.item(),
                     "train/acc": avg_acc.avg.item(),
                     "train/cur_loss": loss.item(),
                     "train/cur_acc": acc.item(),
-                    "lr": lr_scheduler.get_last_lr()[0],
+                    "lr": lr,
                 }
                 if isDadaptation:
                     logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
@@ -498,13 +502,13 @@ def train_loop(
                 local_progress_bar.set_postfix(**logs)
 
                 if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
-                    before_optimize_result = on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
+                    before_optimize_result = on_before_optimize(lr, epoch)
 
                     optimizer.step()
                     lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)
 
-                    on_after_optimize(before_optimize_result, lr_scheduler.get_last_lr()[0])
+                    on_after_optimize(before_optimize_result, lr)
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
@@ -518,9 +522,13 @@ def train_loop(
 
         accelerator.wait_for_everyone()
 
-        lrs.append(lr_scheduler.get_last_lr()[0])
+        lr = lr_scheduler.get_last_lr()[0]
+        if torch.is_tensor(lr):
+            lr = lr.item()
+
+        lrs.append(lr)
 
-        on_after_epoch(lr_scheduler.get_last_lr()[0])
+        on_after_epoch(lr)
 
         if val_dataloader is not None:
             model.eval()
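
The train_loop changes read the scheduler's last learning rate once per step (and once per epoch), convert it to a plain float if it comes back as a tensor, which the Adafactor schedule can do, and reuse that value for logging, the optimize callbacks, and the per-epoch lr list. A standalone sketch of the same conversion, using a hypothetical as_float_lr helper:

import torch

def as_float_lr(last_lr):
    # Normalize a scheduler's last LR to a Python float so it can be logged
    # and passed to callbacks whether it arrives as a tensor or a number.
    if torch.is_tensor(last_lr):
        return last_lr.item()
    return float(last_lr)

print(as_float_lr(torch.tensor(1e-4)))  # 0.0001
print(as_float_lr(3e-4))                # 0.0003
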
diff --git a/training/optimization.py b/training/optimization.py
index 53d0a6d..d22a900 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,7 +6,7 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR
 
 from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
-import transformers
+from transformers.optimization import get_adafactor_schedule
 
 
 class OneCyclePhase(NamedTuple):
@@ -150,7 +150,10 @@ def get_scheduler(
             num_cycles=cycles,
         )
     elif id == "adafactor":
-        lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr)
+        lr_scheduler = get_adafactor_schedule(
+            optimizer,
+            initial_lr=min_lr
+        )
     else:
         lr_scheduler = get_scheduler_(
             id,
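
With the import switched to get_adafactor_schedule, the "adafactor" branch now goes through the public helper rather than instantiating AdafactorSchedule directly, passing min_lr as the schedule's initial_lr. A small sketch of that path, under the same assumptions as the earlier example (toy model, stand-in min_lr value):

import torch
import transformers
from transformers.optimization import get_adafactor_schedule

model = torch.nn.Linear(4, 4)                     # toy model
optimizer = transformers.optimization.Adafactor(
    model.parameters(),
    scale_parameter=True,
    relative_step=True,
)

min_lr = 1e-6                                     # stand-in for get_scheduler()'s min_lr
lr_scheduler = get_adafactor_schedule(optimizer, initial_lr=min_lr)

# The schedule proxies the LR that Adafactor computes internally; before any
# optimizer step it falls back to the given initial_lr. Because the proxied
# value can be a tensor, the training loop above normalizes it with .item().
print(lr_scheduler.get_last_lr())
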