diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 57 | ||||
-rw-r--r-- | training/strategy/dreambooth.py | 2 | ||||
-rw-r--r-- | training/strategy/lora.py | 12 | ||||
-rw-r--r-- | training/strategy/ti.py | 2 |
4 files changed, 33 insertions, 40 deletions
diff --git a/training/functional.py b/training/functional.py index 4d83df1..71b2fe9 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -36,8 +36,8 @@ def const(result=None): | |||
36 | class TrainingCallbacks(): | 36 | class TrainingCallbacks(): |
37 | on_log: Callable[[], dict[str, Any]] = const({}) | 37 | on_log: Callable[[], dict[str, Any]] = const({}) |
38 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 38 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
39 | on_before_optimize: Callable[[float, int], Any] = const() | 39 | on_before_optimize: Callable[[int], Any] = const() |
40 | on_after_optimize: Callable[[Any, float], None] = const() | 40 | on_after_optimize: Callable[[Any, dict[str, float]], None] = const() |
41 | on_after_epoch: Callable[[], None] = const() | 41 | on_after_epoch: Callable[[], None] = const() |
42 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | 42 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) |
43 | on_sample: Callable[[int], None] = const() | 43 | on_sample: Callable[[int], None] = const() |
@@ -422,6 +422,7 @@ def train_loop( | |||
422 | global_step_offset: int = 0, | 422 | global_step_offset: int = 0, |
423 | num_epochs: int = 100, | 423 | num_epochs: int = 100, |
424 | gradient_accumulation_steps: int = 1, | 424 | gradient_accumulation_steps: int = 1, |
425 | group_labels: list[str] = [], | ||
425 | callbacks: TrainingCallbacks = TrainingCallbacks(), | 426 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
426 | ): | 427 | ): |
427 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) | 428 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) |
@@ -442,10 +443,6 @@ def train_loop( | |||
442 | best_acc = 0.0 | 443 | best_acc = 0.0 |
443 | best_acc_val = 0.0 | 444 | best_acc_val = 0.0 |
444 | 445 | ||
445 | lrs = [] | ||
446 | losses = [] | ||
447 | accs = [] | ||
448 | |||
449 | local_progress_bar = tqdm( | 446 | local_progress_bar = tqdm( |
450 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), | 447 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), |
451 | disable=not accelerator.is_local_main_process, | 448 | disable=not accelerator.is_local_main_process, |
@@ -496,6 +493,8 @@ def train_loop( | |||
496 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 493 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
497 | local_progress_bar.reset() | 494 | local_progress_bar.reset() |
498 | 495 | ||
496 | logs = {} | ||
497 | |||
499 | with on_train(epoch): | 498 | with on_train(epoch): |
500 | for step, batch in enumerate(train_dataloader): | 499 | for step, batch in enumerate(train_dataloader): |
501 | loss, acc, bsz = loss_step(step, batch, cache) | 500 | loss, acc, bsz = loss_step(step, batch, cache) |
@@ -506,31 +505,36 @@ def train_loop( | |||
506 | avg_loss.update(loss.detach_(), bsz) | 505 | avg_loss.update(loss.detach_(), bsz) |
507 | avg_acc.update(acc.detach_(), bsz) | 506 | avg_acc.update(acc.detach_(), bsz) |
508 | 507 | ||
509 | lr = lr_scheduler.get_last_lr()[0] | ||
510 | if torch.is_tensor(lr): | ||
511 | lr = lr.item() | ||
512 | |||
513 | logs = { | 508 | logs = { |
514 | "train/loss": avg_loss.avg.item(), | 509 | "train/loss": avg_loss.avg.item(), |
515 | "train/acc": avg_acc.avg.item(), | 510 | "train/acc": avg_acc.avg.item(), |
516 | "train/cur_loss": loss.item(), | 511 | "train/cur_loss": loss.item(), |
517 | "train/cur_acc": acc.item(), | 512 | "train/cur_acc": acc.item(), |
518 | "lr": lr, | ||
519 | } | 513 | } |
520 | if isDadaptation: | 514 | |
521 | logs["lr/d*lr"] = lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] | 515 | lrs: dict[str, float] = {} |
516 | for i, lr in enumerate(lr_scheduler.get_last_lr()): | ||
517 | if torch.is_tensor(lr): | ||
518 | lr = lr.item() | ||
519 | label = group_labels[i] if i < len(group_labels) else f"{i}" | ||
520 | logs[f"lr/{label}"] = lr | ||
521 | if isDadaptation: | ||
522 | lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] | ||
523 | logs[f"d*lr/{label}"] = lr | ||
524 | lrs[label] = lr | ||
525 | |||
522 | logs.update(on_log()) | 526 | logs.update(on_log()) |
523 | 527 | ||
524 | local_progress_bar.set_postfix(**logs) | 528 | local_progress_bar.set_postfix(**logs) |
525 | 529 | ||
526 | if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): | 530 | if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): |
527 | before_optimize_result = on_before_optimize(lr, epoch) | 531 | before_optimize_result = on_before_optimize(epoch) |
528 | 532 | ||
529 | optimizer.step() | 533 | optimizer.step() |
530 | lr_scheduler.step() | 534 | lr_scheduler.step() |
531 | optimizer.zero_grad(set_to_none=True) | 535 | optimizer.zero_grad(set_to_none=True) |
532 | 536 | ||
533 | on_after_optimize(before_optimize_result, lr) | 537 | on_after_optimize(before_optimize_result, lrs) |
534 | 538 | ||
535 | local_progress_bar.update(1) | 539 | local_progress_bar.update(1) |
536 | global_progress_bar.update(1) | 540 | global_progress_bar.update(1) |
@@ -544,15 +548,6 @@ def train_loop( | |||
544 | 548 | ||
545 | accelerator.wait_for_everyone() | 549 | accelerator.wait_for_everyone() |
546 | 550 | ||
547 | if isDadaptation: | ||
548 | lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] | ||
549 | else: | ||
550 | lr = lr_scheduler.get_last_lr()[0] | ||
551 | if torch.is_tensor(lr): | ||
552 | lr = lr.item() | ||
553 | |||
554 | lrs.append(lr) | ||
555 | |||
556 | on_after_epoch() | 551 | on_after_epoch() |
557 | 552 | ||
558 | if val_dataloader is not None: | 553 | if val_dataloader is not None: |
@@ -597,9 +592,6 @@ def train_loop( | |||
597 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 592 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") |
598 | on_checkpoint(global_step + global_step_offset, "milestone") | 593 | on_checkpoint(global_step + global_step_offset, "milestone") |
599 | best_acc_val = avg_acc_val.avg.item() | 594 | best_acc_val = avg_acc_val.avg.item() |
600 | |||
601 | losses.append(avg_loss_val.avg.item()) | ||
602 | accs.append(avg_acc_val.avg.item()) | ||
603 | else: | 595 | else: |
604 | if accelerator.is_main_process: | 596 | if accelerator.is_main_process: |
605 | if avg_acc.avg.item() > best_acc and milestone_checkpoints: | 597 | if avg_acc.avg.item() > best_acc and milestone_checkpoints: |
@@ -611,9 +603,6 @@ def train_loop( | |||
611 | on_checkpoint(global_step + global_step_offset, "milestone") | 603 | on_checkpoint(global_step + global_step_offset, "milestone") |
612 | best_acc = avg_acc.avg.item() | 604 | best_acc = avg_acc.avg.item() |
613 | 605 | ||
614 | losses.append(avg_loss.avg.item()) | ||
615 | accs.append(avg_acc.avg.item()) | ||
616 | |||
617 | # Create the pipeline using the trained modules and save it. | 606 | # Create the pipeline using the trained modules and save it. |
618 | if accelerator.is_main_process: | 607 | if accelerator.is_main_process: |
619 | print("Finished!") | 608 | print("Finished!") |
@@ -626,8 +615,6 @@ def train_loop( | |||
626 | on_checkpoint(global_step + global_step_offset, "end") | 615 | on_checkpoint(global_step + global_step_offset, "end") |
627 | raise KeyboardInterrupt | 616 | raise KeyboardInterrupt |
628 | 617 | ||
629 | return lrs, losses, accs | ||
630 | |||
631 | 618 | ||
632 | def train( | 619 | def train( |
633 | accelerator: Accelerator, | 620 | accelerator: Accelerator, |
@@ -646,6 +633,7 @@ def train( | |||
646 | no_val: bool = False, | 633 | no_val: bool = False, |
647 | num_train_epochs: int = 100, | 634 | num_train_epochs: int = 100, |
648 | gradient_accumulation_steps: int = 1, | 635 | gradient_accumulation_steps: int = 1, |
636 | group_labels: list[str] = [], | ||
649 | sample_frequency: int = 20, | 637 | sample_frequency: int = 20, |
650 | checkpoint_frequency: int = 50, | 638 | checkpoint_frequency: int = 50, |
651 | milestone_checkpoints: bool = True, | 639 | milestone_checkpoints: bool = True, |
@@ -692,7 +680,7 @@ def train( | |||
692 | if accelerator.is_main_process: | 680 | if accelerator.is_main_process: |
693 | accelerator.init_trackers(project) | 681 | accelerator.init_trackers(project) |
694 | 682 | ||
695 | metrics = train_loop( | 683 | train_loop( |
696 | accelerator=accelerator, | 684 | accelerator=accelerator, |
697 | optimizer=optimizer, | 685 | optimizer=optimizer, |
698 | lr_scheduler=lr_scheduler, | 686 | lr_scheduler=lr_scheduler, |
@@ -705,10 +693,9 @@ def train( | |||
705 | global_step_offset=global_step_offset, | 693 | global_step_offset=global_step_offset, |
706 | num_epochs=num_train_epochs, | 694 | num_epochs=num_train_epochs, |
707 | gradient_accumulation_steps=gradient_accumulation_steps, | 695 | gradient_accumulation_steps=gradient_accumulation_steps, |
696 | group_labels=group_labels, | ||
708 | callbacks=callbacks, | 697 | callbacks=callbacks, |
709 | ) | 698 | ) |
710 | 699 | ||
711 | accelerator.end_training() | 700 | accelerator.end_training() |
712 | accelerator.free_memory() | 701 | accelerator.free_memory() |
713 | |||
714 | return metrics | ||
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 0286673..695174a 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -106,7 +106,7 @@ def dreambooth_strategy_callbacks( | |||
106 | with ema_context(): | 106 | with ema_context(): |
107 | yield | 107 | yield |
108 | 108 | ||
109 | def on_before_optimize(lr: float, epoch: int): | 109 | def on_before_optimize(epoch: int): |
110 | params_to_clip = [unet.parameters()] | 110 | params_to_clip = [unet.parameters()] |
111 | if epoch < train_text_encoder_epochs: | 111 | if epoch < train_text_encoder_epochs: |
112 | params_to_clip.append(text_encoder.parameters()) | 112 | params_to_clip.append(text_encoder.parameters()) |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 912ff26..89269c0 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -79,10 +79,14 @@ def lora_strategy_callbacks( | |||
79 | tokenizer.eval() | 79 | tokenizer.eval() |
80 | yield | 80 | yield |
81 | 81 | ||
82 | def on_before_optimize(lr: float, epoch: int): | 82 | def on_before_optimize(epoch: int): |
83 | if not pti_mode: | 83 | if not pti_mode: |
84 | accelerator.clip_grad_norm_( | 84 | accelerator.clip_grad_norm_( |
85 | itertools.chain(unet.parameters(), text_encoder.parameters()), | 85 | itertools.chain( |
86 | unet.parameters(), | ||
87 | text_encoder.text_model.encoder.parameters(), | ||
88 | text_encoder.text_model.final_layer_norm.parameters(), | ||
89 | ), | ||
86 | max_grad_norm | 90 | max_grad_norm |
87 | ) | 91 | ) |
88 | 92 | ||
@@ -95,7 +99,9 @@ def lora_strategy_callbacks( | |||
95 | return torch.stack(params) if len(params) != 0 else None | 99 | return torch.stack(params) if len(params) != 0 else None |
96 | 100 | ||
97 | @torch.no_grad() | 101 | @torch.no_grad() |
98 | def on_after_optimize(w, lr: float): | 102 | def on_after_optimize(w, lrs: dict[str, float]): |
103 | lr = lrs["emb"] or lrs["0"] | ||
104 | |||
99 | if use_emb_decay and w is not None: | 105 | if use_emb_decay and w is not None: |
100 | lambda_ = emb_decay * lr | 106 | lambda_ = emb_decay * lr |
101 | 107 | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 6a637c3..d735dac 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks( | |||
104 | yield | 104 | yield |
105 | 105 | ||
106 | @torch.no_grad() | 106 | @torch.no_grad() |
107 | def on_before_optimize(lr: float, epoch: int): | 107 | def on_before_optimize(epoch: int): |
108 | if use_emb_decay: | 108 | if use_emb_decay: |
109 | params = [ | 109 | params = [ |
110 | p | 110 | p |