Diffstat (limited to 'training')
-rw-r--r--  training/functional.py          | 57
-rw-r--r--  training/strategy/dreambooth.py |  2
-rw-r--r--  training/strategy/lora.py       | 12
-rw-r--r--  training/strategy/ti.py         |  2
4 files changed, 33 insertions, 40 deletions
diff --git a/training/functional.py b/training/functional.py
index 4d83df1..71b2fe9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -36,8 +36,8 @@ def const(result=None):
 class TrainingCallbacks():
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
-    on_before_optimize: Callable[[float, int], Any] = const()
-    on_after_optimize: Callable[[Any, float], None] = const()
+    on_before_optimize: Callable[[int], Any] = const()
+    on_after_optimize: Callable[[Any, dict[str, float]], None] = const()
     on_after_epoch: Callable[[], None] = const()
     on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
     on_sample: Callable[[int], None] = const()
@@ -422,6 +422,7 @@ def train_loop(
     global_step_offset: int = 0,
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
+    group_labels: list[str] = [],
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
@@ -442,10 +443,6 @@ def train_loop(
     best_acc = 0.0
     best_acc_val = 0.0
 
-    lrs = []
-    losses = []
-    accs = []
-
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
@@ -496,6 +493,8 @@ def train_loop(
         local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
         local_progress_bar.reset()
 
+        logs = {}
+
         with on_train(epoch):
             for step, batch in enumerate(train_dataloader):
                 loss, acc, bsz = loss_step(step, batch, cache)
@@ -506,31 +505,36 @@ def train_loop(
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
-                lr = lr_scheduler.get_last_lr()[0]
-                if torch.is_tensor(lr):
-                    lr = lr.item()
-
                 logs = {
                     "train/loss": avg_loss.avg.item(),
                     "train/acc": avg_acc.avg.item(),
                     "train/cur_loss": loss.item(),
                     "train/cur_acc": acc.item(),
-                    "lr": lr,
                 }
-                if isDadaptation:
-                    logs["lr/d*lr"] = lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
+
+                lrs: dict[str, float] = {}
+                for i, lr in enumerate(lr_scheduler.get_last_lr()):
+                    if torch.is_tensor(lr):
+                        lr = lr.item()
+                    label = group_labels[i] if i < len(group_labels) else f"{i}"
+                    logs[f"lr/{label}"] = lr
+                    if isDadaptation:
+                        lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
+                        logs[f"d*lr/{label}"] = lr
+                    lrs[label] = lr
+
                 logs.update(on_log())
 
                 local_progress_bar.set_postfix(**logs)
 
                 if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
-                    before_optimize_result = on_before_optimize(lr, epoch)
+                    before_optimize_result = on_before_optimize(epoch)
 
                     optimizer.step()
                     lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)
 
-                    on_after_optimize(before_optimize_result, lr)
+                    on_after_optimize(before_optimize_result, lrs)
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
@@ -544,15 +548,6 @@ def train_loop(
 
         accelerator.wait_for_everyone()
 
-        if isDadaptation:
-            lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
-        else:
-            lr = lr_scheduler.get_last_lr()[0]
-            if torch.is_tensor(lr):
-                lr = lr.item()
-
-        lrs.append(lr)
-
         on_after_epoch()
 
         if val_dataloader is not None:
@@ -597,9 +592,6 @@ def train_loop(
597 f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") 592 f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
598 on_checkpoint(global_step + global_step_offset, "milestone") 593 on_checkpoint(global_step + global_step_offset, "milestone")
599 best_acc_val = avg_acc_val.avg.item() 594 best_acc_val = avg_acc_val.avg.item()
600
601 losses.append(avg_loss_val.avg.item())
602 accs.append(avg_acc_val.avg.item())
603 else: 595 else:
604 if accelerator.is_main_process: 596 if accelerator.is_main_process:
605 if avg_acc.avg.item() > best_acc and milestone_checkpoints: 597 if avg_acc.avg.item() > best_acc and milestone_checkpoints:
@@ -611,9 +603,6 @@ def train_loop(
                     on_checkpoint(global_step + global_step_offset, "milestone")
                 best_acc = avg_acc.avg.item()
 
-            losses.append(avg_loss.avg.item())
-            accs.append(avg_acc.avg.item())
-
     # Create the pipeline using the trained modules and save it.
     if accelerator.is_main_process:
         print("Finished!")
@@ -626,8 +615,6 @@ def train_loop(
             on_checkpoint(global_step + global_step_offset, "end")
             raise KeyboardInterrupt
 
-    return lrs, losses, accs
-
 
 def train(
     accelerator: Accelerator,
@@ -646,6 +633,7 @@ def train(
     no_val: bool = False,
     num_train_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
+    group_labels: list[str] = [],
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
@@ -692,7 +680,7 @@ def train(
     if accelerator.is_main_process:
         accelerator.init_trackers(project)
 
-    metrics = train_loop(
+    train_loop(
         accelerator=accelerator,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
@@ -705,10 +693,9 @@ def train(
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
         gradient_accumulation_steps=gradient_accumulation_steps,
+        group_labels=group_labels,
         callbacks=callbacks,
     )
 
     accelerator.end_training()
     accelerator.free_memory()
-
-    return metrics
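
Taken together, the functional.py changes drop the single scalar `lr` (and the accumulated `lrs`/`losses`/`accs` return lists) in favor of one learning rate per optimizer parameter group, logged as `lr/<label>` (plus `d*lr/<label>` under Dadaptation) and handed to `on_after_optimize` as a dict. A minimal, self-contained sketch of the new logging loop; the optimizer, scheduler, and the "emb" label are illustrative assumptions, not values fixed by this commit:

import torch

# Two parameter groups; group_labels has fewer entries than groups,
# so the second group falls back to its stringified index ("1").
optimizer = torch.optim.SGD([
    {"params": [torch.nn.Parameter(torch.zeros(2))], "lr": 1e-2},
    {"params": [torch.nn.Parameter(torch.zeros(2))], "lr": 1e-4},
])
lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
group_labels = ["emb"]

logs: dict[str, float] = {}
lrs: dict[str, float] = {}
for i, lr in enumerate(lr_scheduler.get_last_lr()):
    if torch.is_tensor(lr):
        lr = lr.item()
    label = group_labels[i] if i < len(group_labels) else f"{i}"
    logs[f"lr/{label}"] = lr
    lrs[label] = lr  # this dict is what on_after_optimize now receives

print(logs)  # {'lr/emb': 0.01, 'lr/1': 0.0001}

Because unlabeled groups still get distinct index-based keys, callers that pass no `group_labels` keep working unchanged.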
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 0286673..695174a 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -106,7 +106,7 @@ def dreambooth_strategy_callbacks(
     with ema_context():
         yield
 
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         params_to_clip = [unet.parameters()]
         if epoch < train_text_encoder_epochs:
             params_to_clip.append(text_encoder.parameters())
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 912ff26..89269c0 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -79,10 +79,14 @@ def lora_strategy_callbacks(
         tokenizer.eval()
         yield
 
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         if not pti_mode:
             accelerator.clip_grad_norm_(
-                itertools.chain(unet.parameters(), text_encoder.parameters()),
+                itertools.chain(
+                    unet.parameters(),
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                ),
                 max_grad_norm
             )
 
@@ -95,7 +99,9 @@ def lora_strategy_callbacks(
         return torch.stack(params) if len(params) != 0 else None
 
     @torch.no_grad()
-    def on_after_optimize(w, lr: float):
+    def on_after_optimize(w, lrs: dict[str, float]):
+        lr = lrs["emb"] or lrs["0"]
+
         if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr
 
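
On the receiving side, the LoRA strategy now recovers the embedding-group learning rate from that dict. `lrs["emb"] or lrs["0"]` expects the embeddings group to be labeled "emb" via `group_labels`; Python's `or` consults the `"0"` fallback (the key of the first, unlabeled group) only when `lrs["emb"]` is present but 0.0, and a missing "emb" key raises KeyError. A tiny illustration with hypothetical labels and values:

# Hypothetical dict as delivered by train_loop's on_after_optimize:
lrs = {"emb": 1e-4, "unet": 1e-5}
lr = lrs["emb"] or lrs["0"]  # -> 1e-4; "0" is only consulted when lrs["emb"] == 0.0
assert lr == 1e-4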
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6a637c3..d735dac 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
         yield
 
     @torch.no_grad()
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         if use_emb_decay:
             params = [
                 p
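
The same signature change applies across all three strategies, so porting a custom strategy means dropping the `lr` parameter from `on_before_optimize` and accepting the label-keyed dict in `on_after_optimize`. A hedged sketch of the new callback shapes (the bodies are placeholders, not code from this repository):

import torch

def on_before_optimize(epoch: int):
    # Runs right before optimizer.step(); gradient clipping etc. goes here.
    # Whatever is returned is forwarded unchanged to on_after_optimize.
    return None

@torch.no_grad()
def on_after_optimize(result, lrs: dict[str, float]):
    # Per-group learning rates arrive keyed by group label
    # (or by stringified group index when no label was given).
    for label, lr in lrs.items():
        print(f"{label}: {lr:.2e}")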