diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-08 07:58:14 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-08 07:58:14 +0200 |
| commit | 5e84594c56237cd2c7d7f80858e5da8c11aa3f89 (patch) | |
| tree | b1483a52fb853aecb7b73635cded3cce61edf125 /training | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.tar.gz textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.tar.bz2 textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.zip | |
Update
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 57 | ||||
| -rw-r--r-- | training/strategy/dreambooth.py | 2 | ||||
| -rw-r--r-- | training/strategy/lora.py | 12 | ||||
| -rw-r--r-- | training/strategy/ti.py | 2 |
4 files changed, 33 insertions, 40 deletions
diff --git a/training/functional.py b/training/functional.py index 4d83df1..71b2fe9 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -36,8 +36,8 @@ def const(result=None): | |||
| 36 | class TrainingCallbacks(): | 36 | class TrainingCallbacks(): |
| 37 | on_log: Callable[[], dict[str, Any]] = const({}) | 37 | on_log: Callable[[], dict[str, Any]] = const({}) |
| 38 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 38 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
| 39 | on_before_optimize: Callable[[float, int], Any] = const() | 39 | on_before_optimize: Callable[[int], Any] = const() |
| 40 | on_after_optimize: Callable[[Any, float], None] = const() | 40 | on_after_optimize: Callable[[Any, dict[str, float]], None] = const() |
| 41 | on_after_epoch: Callable[[], None] = const() | 41 | on_after_epoch: Callable[[], None] = const() |
| 42 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | 42 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) |
| 43 | on_sample: Callable[[int], None] = const() | 43 | on_sample: Callable[[int], None] = const() |
| @@ -422,6 +422,7 @@ def train_loop( | |||
| 422 | global_step_offset: int = 0, | 422 | global_step_offset: int = 0, |
| 423 | num_epochs: int = 100, | 423 | num_epochs: int = 100, |
| 424 | gradient_accumulation_steps: int = 1, | 424 | gradient_accumulation_steps: int = 1, |
| 425 | group_labels: list[str] = [], | ||
| 425 | callbacks: TrainingCallbacks = TrainingCallbacks(), | 426 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
| 426 | ): | 427 | ): |
| 427 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) | 428 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) |
| @@ -442,10 +443,6 @@ def train_loop( | |||
| 442 | best_acc = 0.0 | 443 | best_acc = 0.0 |
| 443 | best_acc_val = 0.0 | 444 | best_acc_val = 0.0 |
| 444 | 445 | ||
| 445 | lrs = [] | ||
| 446 | losses = [] | ||
| 447 | accs = [] | ||
| 448 | |||
| 449 | local_progress_bar = tqdm( | 446 | local_progress_bar = tqdm( |
| 450 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), | 447 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), |
| 451 | disable=not accelerator.is_local_main_process, | 448 | disable=not accelerator.is_local_main_process, |
| @@ -496,6 +493,8 @@ def train_loop( | |||
| 496 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 493 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
| 497 | local_progress_bar.reset() | 494 | local_progress_bar.reset() |
| 498 | 495 | ||
| 496 | logs = {} | ||
| 497 | |||
| 499 | with on_train(epoch): | 498 | with on_train(epoch): |
| 500 | for step, batch in enumerate(train_dataloader): | 499 | for step, batch in enumerate(train_dataloader): |
| 501 | loss, acc, bsz = loss_step(step, batch, cache) | 500 | loss, acc, bsz = loss_step(step, batch, cache) |
| @@ -506,31 +505,36 @@ def train_loop( | |||
| 506 | avg_loss.update(loss.detach_(), bsz) | 505 | avg_loss.update(loss.detach_(), bsz) |
| 507 | avg_acc.update(acc.detach_(), bsz) | 506 | avg_acc.update(acc.detach_(), bsz) |
| 508 | 507 | ||
| 509 | lr = lr_scheduler.get_last_lr()[0] | ||
| 510 | if torch.is_tensor(lr): | ||
| 511 | lr = lr.item() | ||
| 512 | |||
| 513 | logs = { | 508 | logs = { |
| 514 | "train/loss": avg_loss.avg.item(), | 509 | "train/loss": avg_loss.avg.item(), |
| 515 | "train/acc": avg_acc.avg.item(), | 510 | "train/acc": avg_acc.avg.item(), |
| 516 | "train/cur_loss": loss.item(), | 511 | "train/cur_loss": loss.item(), |
| 517 | "train/cur_acc": acc.item(), | 512 | "train/cur_acc": acc.item(), |
| 518 | "lr": lr, | ||
| 519 | } | 513 | } |
| 520 | if isDadaptation: | 514 | |
| 521 | logs["lr/d*lr"] = lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] | 515 | lrs: dict[str, float] = {} |
| 516 | for i, lr in enumerate(lr_scheduler.get_last_lr()): | ||
| 517 | if torch.is_tensor(lr): | ||
| 518 | lr = lr.item() | ||
| 519 | label = group_labels[i] if i < len(group_labels) else f"{i}" | ||
| 520 | logs[f"lr/{label}"] = lr | ||
| 521 | if isDadaptation: | ||
| 522 | lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] | ||
| 523 | logs[f"d*lr/{label}"] = lr | ||
| 524 | lrs[label] = lr | ||
| 525 | |||
| 522 | logs.update(on_log()) | 526 | logs.update(on_log()) |
| 523 | 527 | ||
| 524 | local_progress_bar.set_postfix(**logs) | 528 | local_progress_bar.set_postfix(**logs) |
| 525 | 529 | ||
| 526 | if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): | 530 | if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): |
| 527 | before_optimize_result = on_before_optimize(lr, epoch) | 531 | before_optimize_result = on_before_optimize(epoch) |
| 528 | 532 | ||
| 529 | optimizer.step() | 533 | optimizer.step() |
| 530 | lr_scheduler.step() | 534 | lr_scheduler.step() |
| 531 | optimizer.zero_grad(set_to_none=True) | 535 | optimizer.zero_grad(set_to_none=True) |
| 532 | 536 | ||
| 533 | on_after_optimize(before_optimize_result, lr) | 537 | on_after_optimize(before_optimize_result, lrs) |
| 534 | 538 | ||
| 535 | local_progress_bar.update(1) | 539 | local_progress_bar.update(1) |
| 536 | global_progress_bar.update(1) | 540 | global_progress_bar.update(1) |
| @@ -544,15 +548,6 @@ def train_loop( | |||
| 544 | 548 | ||
| 545 | accelerator.wait_for_everyone() | 549 | accelerator.wait_for_everyone() |
| 546 | 550 | ||
| 547 | if isDadaptation: | ||
| 548 | lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] | ||
| 549 | else: | ||
| 550 | lr = lr_scheduler.get_last_lr()[0] | ||
| 551 | if torch.is_tensor(lr): | ||
| 552 | lr = lr.item() | ||
| 553 | |||
| 554 | lrs.append(lr) | ||
| 555 | |||
| 556 | on_after_epoch() | 551 | on_after_epoch() |
| 557 | 552 | ||
| 558 | if val_dataloader is not None: | 553 | if val_dataloader is not None: |
| @@ -597,9 +592,6 @@ def train_loop( | |||
| 597 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 592 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") |
| 598 | on_checkpoint(global_step + global_step_offset, "milestone") | 593 | on_checkpoint(global_step + global_step_offset, "milestone") |
| 599 | best_acc_val = avg_acc_val.avg.item() | 594 | best_acc_val = avg_acc_val.avg.item() |
| 600 | |||
| 601 | losses.append(avg_loss_val.avg.item()) | ||
| 602 | accs.append(avg_acc_val.avg.item()) | ||
| 603 | else: | 595 | else: |
| 604 | if accelerator.is_main_process: | 596 | if accelerator.is_main_process: |
| 605 | if avg_acc.avg.item() > best_acc and milestone_checkpoints: | 597 | if avg_acc.avg.item() > best_acc and milestone_checkpoints: |
| @@ -611,9 +603,6 @@ def train_loop( | |||
| 611 | on_checkpoint(global_step + global_step_offset, "milestone") | 603 | on_checkpoint(global_step + global_step_offset, "milestone") |
| 612 | best_acc = avg_acc.avg.item() | 604 | best_acc = avg_acc.avg.item() |
| 613 | 605 | ||
| 614 | losses.append(avg_loss.avg.item()) | ||
| 615 | accs.append(avg_acc.avg.item()) | ||
| 616 | |||
| 617 | # Create the pipeline using the trained modules and save it. | 606 | # Create the pipeline using the trained modules and save it. |
| 618 | if accelerator.is_main_process: | 607 | if accelerator.is_main_process: |
| 619 | print("Finished!") | 608 | print("Finished!") |
| @@ -626,8 +615,6 @@ def train_loop( | |||
| 626 | on_checkpoint(global_step + global_step_offset, "end") | 615 | on_checkpoint(global_step + global_step_offset, "end") |
| 627 | raise KeyboardInterrupt | 616 | raise KeyboardInterrupt |
| 628 | 617 | ||
| 629 | return lrs, losses, accs | ||
| 630 | |||
| 631 | 618 | ||
| 632 | def train( | 619 | def train( |
| 633 | accelerator: Accelerator, | 620 | accelerator: Accelerator, |
| @@ -646,6 +633,7 @@ def train( | |||
| 646 | no_val: bool = False, | 633 | no_val: bool = False, |
| 647 | num_train_epochs: int = 100, | 634 | num_train_epochs: int = 100, |
| 648 | gradient_accumulation_steps: int = 1, | 635 | gradient_accumulation_steps: int = 1, |
| 636 | group_labels: list[str] = [], | ||
| 649 | sample_frequency: int = 20, | 637 | sample_frequency: int = 20, |
| 650 | checkpoint_frequency: int = 50, | 638 | checkpoint_frequency: int = 50, |
| 651 | milestone_checkpoints: bool = True, | 639 | milestone_checkpoints: bool = True, |
| @@ -692,7 +680,7 @@ def train( | |||
| 692 | if accelerator.is_main_process: | 680 | if accelerator.is_main_process: |
| 693 | accelerator.init_trackers(project) | 681 | accelerator.init_trackers(project) |
| 694 | 682 | ||
| 695 | metrics = train_loop( | 683 | train_loop( |
| 696 | accelerator=accelerator, | 684 | accelerator=accelerator, |
| 697 | optimizer=optimizer, | 685 | optimizer=optimizer, |
| 698 | lr_scheduler=lr_scheduler, | 686 | lr_scheduler=lr_scheduler, |
| @@ -705,10 +693,9 @@ def train( | |||
| 705 | global_step_offset=global_step_offset, | 693 | global_step_offset=global_step_offset, |
| 706 | num_epochs=num_train_epochs, | 694 | num_epochs=num_train_epochs, |
| 707 | gradient_accumulation_steps=gradient_accumulation_steps, | 695 | gradient_accumulation_steps=gradient_accumulation_steps, |
| 696 | group_labels=group_labels, | ||
| 708 | callbacks=callbacks, | 697 | callbacks=callbacks, |
| 709 | ) | 698 | ) |
| 710 | 699 | ||
| 711 | accelerator.end_training() | 700 | accelerator.end_training() |
| 712 | accelerator.free_memory() | 701 | accelerator.free_memory() |
| 713 | |||
| 714 | return metrics | ||
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 0286673..695174a 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -106,7 +106,7 @@ def dreambooth_strategy_callbacks( | |||
| 106 | with ema_context(): | 106 | with ema_context(): |
| 107 | yield | 107 | yield |
| 108 | 108 | ||
| 109 | def on_before_optimize(lr: float, epoch: int): | 109 | def on_before_optimize(epoch: int): |
| 110 | params_to_clip = [unet.parameters()] | 110 | params_to_clip = [unet.parameters()] |
| 111 | if epoch < train_text_encoder_epochs: | 111 | if epoch < train_text_encoder_epochs: |
| 112 | params_to_clip.append(text_encoder.parameters()) | 112 | params_to_clip.append(text_encoder.parameters()) |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 912ff26..89269c0 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
| @@ -79,10 +79,14 @@ def lora_strategy_callbacks( | |||
| 79 | tokenizer.eval() | 79 | tokenizer.eval() |
| 80 | yield | 80 | yield |
| 81 | 81 | ||
| 82 | def on_before_optimize(lr: float, epoch: int): | 82 | def on_before_optimize(epoch: int): |
| 83 | if not pti_mode: | 83 | if not pti_mode: |
| 84 | accelerator.clip_grad_norm_( | 84 | accelerator.clip_grad_norm_( |
| 85 | itertools.chain(unet.parameters(), text_encoder.parameters()), | 85 | itertools.chain( |
| 86 | unet.parameters(), | ||
| 87 | text_encoder.text_model.encoder.parameters(), | ||
| 88 | text_encoder.text_model.final_layer_norm.parameters(), | ||
| 89 | ), | ||
| 86 | max_grad_norm | 90 | max_grad_norm |
| 87 | ) | 91 | ) |
| 88 | 92 | ||
| @@ -95,7 +99,9 @@ def lora_strategy_callbacks( | |||
| 95 | return torch.stack(params) if len(params) != 0 else None | 99 | return torch.stack(params) if len(params) != 0 else None |
| 96 | 100 | ||
| 97 | @torch.no_grad() | 101 | @torch.no_grad() |
| 98 | def on_after_optimize(w, lr: float): | 102 | def on_after_optimize(w, lrs: dict[str, float]): |
| 103 | lr = lrs["emb"] or lrs["0"] | ||
| 104 | |||
| 99 | if use_emb_decay and w is not None: | 105 | if use_emb_decay and w is not None: |
| 100 | lambda_ = emb_decay * lr | 106 | lambda_ = emb_decay * lr |
| 101 | 107 | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 6a637c3..d735dac 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks( | |||
| 104 | yield | 104 | yield |
| 105 | 105 | ||
| 106 | @torch.no_grad() | 106 | @torch.no_grad() |
| 107 | def on_before_optimize(lr: float, epoch: int): | 107 | def on_before_optimize(epoch: int): |
| 108 | if use_emb_decay: | 108 | if use_emb_decay: |
| 109 | params = [ | 109 | params = [ |
| 110 | p | 110 | p |
