summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-01 15:54:40 +0200
committerVolpeon <git@volpeon.ink>2023-04-01 15:54:40 +0200
commita551a9ac2edd1dc59828749a5e5d73a65b3c9ce7 (patch)
tree7ccca7f3a70b2b34706ddb849e37924aa6ee88e9 /training
parentAdd support for Adafactor, add TI initializer noise (diff)
downloadtextual-inversion-diff-a551a9ac2edd1dc59828749a5e5d73a65b3c9ce7.tar.gz
textual-inversion-diff-a551a9ac2edd1dc59828749a5e5d73a65b3c9ce7.tar.bz2
textual-inversion-diff-a551a9ac2edd1dc59828749a5e5d73a65b3c9ce7.zip
Update
Diffstat (limited to 'training')
-rw-r--r--training/functional.py18
-rw-r--r--training/optimization.py7
2 files changed, 18 insertions, 7 deletions
diff --git a/training/functional.py b/training/functional.py
index ac43847..7104a88 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -484,12 +484,16 @@ def train_loop(
484 avg_loss.update(loss.detach_(), bsz) 484 avg_loss.update(loss.detach_(), bsz)
485 avg_acc.update(acc.detach_(), bsz) 485 avg_acc.update(acc.detach_(), bsz)
486 486
487 lr = lr_scheduler.get_last_lr()[0]
488 if torch.is_tensor(lr):
489 lr = lr.item()
490
487 logs = { 491 logs = {
488 "train/loss": avg_loss.avg.item(), 492 "train/loss": avg_loss.avg.item(),
489 "train/acc": avg_acc.avg.item(), 493 "train/acc": avg_acc.avg.item(),
490 "train/cur_loss": loss.item(), 494 "train/cur_loss": loss.item(),
491 "train/cur_acc": acc.item(), 495 "train/cur_acc": acc.item(),
492 "lr": lr_scheduler.get_last_lr()[0], 496 "lr": lr,
493 } 497 }
494 if isDadaptation: 498 if isDadaptation:
495 logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] 499 logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
@@ -498,13 +502,13 @@ def train_loop(
498 local_progress_bar.set_postfix(**logs) 502 local_progress_bar.set_postfix(**logs)
499 503
500 if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): 504 if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
501 before_optimize_result = on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) 505 before_optimize_result = on_before_optimize(lr, epoch)
502 506
503 optimizer.step() 507 optimizer.step()
504 lr_scheduler.step() 508 lr_scheduler.step()
505 optimizer.zero_grad(set_to_none=True) 509 optimizer.zero_grad(set_to_none=True)
506 510
507 on_after_optimize(before_optimize_result, lr_scheduler.get_last_lr()[0]) 511 on_after_optimize(before_optimize_result, lr)
508 512
509 local_progress_bar.update(1) 513 local_progress_bar.update(1)
510 global_progress_bar.update(1) 514 global_progress_bar.update(1)
@@ -518,9 +522,13 @@ def train_loop(
518 522
519 accelerator.wait_for_everyone() 523 accelerator.wait_for_everyone()
520 524
521 lrs.append(lr_scheduler.get_last_lr()[0]) 525 lr = lr_scheduler.get_last_lr()[0]
526 if torch.is_tensor(lr):
527 lr = lr.item
528
529 lrs.append(lr)
522 530
523 on_after_epoch(lr_scheduler.get_last_lr()[0]) 531 on_after_epoch(lr)
524 532
525 if val_dataloader is not None: 533 if val_dataloader is not None:
526 model.eval() 534 model.eval()
diff --git a/training/optimization.py b/training/optimization.py
index 53d0a6d..d22a900 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,7 +6,7 @@ import torch
6from torch.optim.lr_scheduler import LambdaLR 6from torch.optim.lr_scheduler import LambdaLR
7 7
8from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup 8from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
9import transformers 9from transformers.optimization import get_adafactor_schedule
10 10
11 11
12class OneCyclePhase(NamedTuple): 12class OneCyclePhase(NamedTuple):
@@ -150,7 +150,10 @@ def get_scheduler(
150 num_cycles=cycles, 150 num_cycles=cycles,
151 ) 151 )
152 elif id == "adafactor": 152 elif id == "adafactor":
153 lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr) 153 lr_scheduler = get_adafactor_schedule(
154 optimizer,
155 initial_lr=min_lr
156 )
154 else: 157 else:
155 lr_scheduler = get_scheduler_( 158 lr_scheduler = get_scheduler_(
156 id, 159 id,