summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--train_dreambooth.py1
-rw-r--r--train_lora.py1
-rw-r--r--train_ti.py1
-rw-r--r--training/functional.py18
-rw-r--r--training/optimization.py7
5 files changed, 18 insertions, 10 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 4456bd1..48b7926 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -539,7 +539,6 @@ def main():
539 elif args.optimizer == 'adafactor': 539 elif args.optimizer == 'adafactor':
540 create_optimizer = partial( 540 create_optimizer = partial(
541 transformers.optimization.Adafactor, 541 transformers.optimization.Adafactor,
542 beta1=args.adam_beta1,
543 weight_decay=args.adam_weight_decay, 542 weight_decay=args.adam_weight_decay,
544 scale_parameter=True, 543 scale_parameter=True,
545 relative_step=True, 544 relative_step=True,
diff --git a/train_lora.py b/train_lora.py
index f8dccae..8fc2d69 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -571,7 +571,6 @@ def main():
571 elif args.optimizer == 'adafactor': 571 elif args.optimizer == 'adafactor':
572 create_optimizer = partial( 572 create_optimizer = partial(
573 transformers.optimization.Adafactor, 573 transformers.optimization.Adafactor,
574 beta1=args.adam_beta1,
575 weight_decay=args.adam_weight_decay, 574 weight_decay=args.adam_weight_decay,
576 scale_parameter=True, 575 scale_parameter=True,
577 relative_step=True, 576 relative_step=True,
diff --git a/train_ti.py b/train_ti.py
index 274a1ca..5482326 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -669,7 +669,6 @@ def main():
669 elif args.optimizer == 'adafactor': 669 elif args.optimizer == 'adafactor':
670 create_optimizer = partial( 670 create_optimizer = partial(
671 transformers.optimization.Adafactor, 671 transformers.optimization.Adafactor,
672 beta1=args.adam_beta1,
673 weight_decay=args.adam_weight_decay, 672 weight_decay=args.adam_weight_decay,
674 scale_parameter=True, 673 scale_parameter=True,
675 relative_step=True, 674 relative_step=True,
diff --git a/training/functional.py b/training/functional.py
index ac43847..7104a88 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -484,12 +484,16 @@ def train_loop(
484 avg_loss.update(loss.detach_(), bsz) 484 avg_loss.update(loss.detach_(), bsz)
485 avg_acc.update(acc.detach_(), bsz) 485 avg_acc.update(acc.detach_(), bsz)
486 486
487 lr = lr_scheduler.get_last_lr()[0]
488 if torch.is_tensor(lr):
489 lr = lr.item()
490
487 logs = { 491 logs = {
488 "train/loss": avg_loss.avg.item(), 492 "train/loss": avg_loss.avg.item(),
489 "train/acc": avg_acc.avg.item(), 493 "train/acc": avg_acc.avg.item(),
490 "train/cur_loss": loss.item(), 494 "train/cur_loss": loss.item(),
491 "train/cur_acc": acc.item(), 495 "train/cur_acc": acc.item(),
492 "lr": lr_scheduler.get_last_lr()[0], 496 "lr": lr,
493 } 497 }
494 if isDadaptation: 498 if isDadaptation:
495 logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] 499 logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
@@ -498,13 +502,13 @@ def train_loop(
498 local_progress_bar.set_postfix(**logs) 502 local_progress_bar.set_postfix(**logs)
499 503
500 if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): 504 if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
501 before_optimize_result = on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) 505 before_optimize_result = on_before_optimize(lr, epoch)
502 506
503 optimizer.step() 507 optimizer.step()
504 lr_scheduler.step() 508 lr_scheduler.step()
505 optimizer.zero_grad(set_to_none=True) 509 optimizer.zero_grad(set_to_none=True)
506 510
507 on_after_optimize(before_optimize_result, lr_scheduler.get_last_lr()[0]) 511 on_after_optimize(before_optimize_result, lr)
508 512
509 local_progress_bar.update(1) 513 local_progress_bar.update(1)
510 global_progress_bar.update(1) 514 global_progress_bar.update(1)
@@ -518,9 +522,13 @@ def train_loop(
518 522
519 accelerator.wait_for_everyone() 523 accelerator.wait_for_everyone()
520 524
521 lrs.append(lr_scheduler.get_last_lr()[0]) 525 lr = lr_scheduler.get_last_lr()[0]
526 if torch.is_tensor(lr):
527 lr = lr.item()
528
529 lrs.append(lr)
522 530
523 on_after_epoch(lr_scheduler.get_last_lr()[0]) 531 on_after_epoch(lr)
524 532
525 if val_dataloader is not None: 533 if val_dataloader is not None:
526 model.eval() 534 model.eval()
diff --git a/training/optimization.py b/training/optimization.py
index 53d0a6d..d22a900 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,7 +6,7 @@ import torch
6from torch.optim.lr_scheduler import LambdaLR 6from torch.optim.lr_scheduler import LambdaLR
7 7
8from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup 8from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
9import transformers 9from transformers.optimization import get_adafactor_schedule
10 10
11 11
12class OneCyclePhase(NamedTuple): 12class OneCyclePhase(NamedTuple):
@@ -150,7 +150,10 @@ def get_scheduler(
150 num_cycles=cycles, 150 num_cycles=cycles,
151 ) 151 )
152 elif id == "adafactor": 152 elif id == "adafactor":
153 lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr) 153 lr_scheduler = get_adafactor_schedule(
154 optimizer,
155 initial_lr=min_lr
156 )
154 else: 157 else:
155 lr_scheduler = get_scheduler_( 158 lr_scheduler = get_scheduler_(
156 id, 159 id,