 environment.yaml                |   2
 train_dreambooth.py             |  10
 train_ti.py                     |  21
 training/functional.py          |  35
 training/lr.py                  | 256
 training/optimization.py        |  19
 training/strategy/dreambooth.py |   4
 training/strategy/ti.py         |   5
 training/util.py                | 146
 9 files changed, 106 insertions(+), 392 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index 03345c6..c992759 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -25,4 +25,4 @@ dependencies:
 - test-tube>=0.7.5
 - transformers==4.25.1
 - triton==2.0.0.dev20221202
-- xformers==0.0.16rc403
+- xformers==0.0.16.dev430
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 9c1e41c..a70c80e 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -16,6 +16,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, get_models
+from training.lr import plot_metrics
 from training.strategy.dreambooth import dreambooth_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -524,6 +525,10 @@ def main():
         args.train_batch_size * accelerator.num_processes
     )

+    if args.find_lr:
+        args.learning_rate = 1e-6
+        args.lr_scheduler = "exponential_growth"
+
     if args.use_8bit_adam:
         try:
             import bitsandbytes as bnb
@@ -602,11 +607,12 @@ def main():
         warmup_exp=args.lr_warmup_exp,
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
+        end_lr=1e2,
         train_epochs=args.num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
     )

-    trainer(
+    metrics = trainer(
         strategy=dreambooth_strategy,
         project="dreambooth",
         train_dataloader=datamodule.train_dataloader,
@@ -634,6 +640,8 @@ def main():
         sample_image_size=args.sample_image_size,
     )

+    plot_metrics(metrics, output_dir.joinpath("lr.png"))
+

 if __name__ == "__main__":
     main()
diff --git a/train_ti.py b/train_ti.py
index 451b61b..c118aab 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -15,6 +15,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
+from training.lr import plot_metrics
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -61,6 +62,12 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
+        "--skip_first",
+        type=int,
+        default=0,
+        help="Tokens to skip training for.",
+    )
+    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -407,7 +414,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay",
-        default=10,
+        default=1e0,
         type=float,
         help="Embedding decay factor."
     )
@@ -543,6 +550,10 @@ def main():
         args.train_batch_size * accelerator.num_processes
     )

+    if args.find_lr:
+        args.learning_rate = 1e-5
+        args.lr_scheduler = "exponential_growth"
+
     if args.use_8bit_adam:
         try:
             import bitsandbytes as bnb
@@ -596,6 +607,9 @@ def main():
     )

     def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
+        if i < args.skip_first:
+            return
+
         if len(placeholder_tokens) == 1:
             sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}")
         else:
@@ -656,11 +670,12 @@ def main():
             warmup_exp=args.lr_warmup_exp,
             annealing_exp=args.lr_annealing_exp,
             cycles=args.lr_cycles,
+            end_lr=1e3,
             train_epochs=args.num_train_epochs,
             warmup_epochs=args.lr_warmup_epochs,
         )

-        trainer(
+        metrics = trainer(
             project="textual_inversion",
             train_dataloader=datamodule.train_dataloader,
             val_dataloader=datamodule.val_dataloader,
@@ -672,6 +687,8 @@ def main():
             placeholder_token_ids=placeholder_token_ids,
         )

+        plot_metrics(metrics, output_dir.joinpath("lr.png"))
+
     if args.simultaneous:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
diff --git a/training/functional.py b/training/functional.py
index fb135c4..c373ac9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -7,7 +7,6 @@ from pathlib import Path
 import itertools

 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.data import DataLoader

@@ -373,8 +372,12 @@ def train_loop(
     avg_loss_val = AverageMeter()
     avg_acc_val = AverageMeter()

-    max_acc = 0.0
-    max_acc_val = 0.0
+    best_acc = 0.0
+    best_acc_val = 0.0
+
+    lrs = []
+    losses = []
+    accs = []

     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
@@ -457,6 +460,8 @@ def train_loop(

         accelerator.wait_for_everyone()

+        lrs.append(lr_scheduler.get_last_lr()[0])
+
         on_after_epoch(lr_scheduler.get_last_lr()[0])

         if val_dataloader is not None:
@@ -498,18 +503,24 @@ def train_loop(
             global_progress_bar.clear()

             if accelerator.is_main_process:
-                if avg_acc_val.avg.item() > max_acc_val:
+                if avg_acc_val.avg.item() > best_acc_val:
                     accelerator.print(
-                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                        f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                     on_checkpoint(global_step + global_step_offset, "milestone")
-                    max_acc_val = avg_acc_val.avg.item()
+                    best_acc_val = avg_acc_val.avg.item()
+
+                losses.append(avg_loss_val.avg.item())
+                accs.append(avg_acc_val.avg.item())
         else:
             if accelerator.is_main_process:
-                if avg_acc.avg.item() > max_acc:
+                if avg_acc.avg.item() > best_acc:
                     accelerator.print(
-                        f"Global step {global_step}: Training accuracy reached new maximum: {max_acc:.2e} -> {avg_acc.avg.item():.2e}")
+                        f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
                     on_checkpoint(global_step + global_step_offset, "milestone")
-                    max_acc = avg_acc.avg.item()
+                    best_acc = avg_acc.avg.item()
+
+                losses.append(avg_loss.avg.item())
+                accs.append(avg_acc.avg.item())

     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
@@ -523,6 +534,8 @@ def train_loop(
             on_checkpoint(global_step + global_step_offset, "end")
             raise KeyboardInterrupt

+    return lrs, losses, accs
+

 def train(
     accelerator: Accelerator,
@@ -582,7 +595,7 @@ def train(
     if accelerator.is_main_process:
         accelerator.init_trackers(project)

-    train_loop(
+    metrics = train_loop(
         accelerator=accelerator,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
@@ -598,3 +611,5 @@

     accelerator.end_training()
     accelerator.free_memory()
+
+    return metrics
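With this change `train_loop()` (and therefore `train()`) no longer returns nothing: it hands back three parallel per-epoch lists that the training scripts capture as `metrics`. A small self-contained sketch of that contract, with made-up numbers purely to illustrate the shape:

```python
# Illustration only: the values are invented, but the shape matches what
# train_loop() now returns and what train_dreambooth.py / train_ti.py receive.
metrics = (
    [1e-6, 1e-5, 1e-4],   # lrs: last learning rate of each epoch
    [0.42, 0.37, 0.35],   # losses: averaged validation (or training) loss per epoch
    [0.61, 0.66, 0.69],   # accs: averaged validation (or training) accuracy per epoch
)

lrs, losses, accs = metrics
assert len(lrs) == len(losses) == len(accs)  # one entry per completed epoch
```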
diff --git a/training/lr.py b/training/lr.py
index 9690738..f5b362f 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,238 +1,36 @@
-import math
-from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union
-from functools import partial
+from pathlib import Path

 import matplotlib.pyplot as plt
-import numpy as np
-import torch
-from torch.optim.lr_scheduler import LambdaLR
-from tqdm.auto import tqdm

-from training.functional import TrainingCallbacks
-from training.util import AverageMeter

+def plot_metrics(
+    metrics: tuple[list[float], list[float], list[float]],
+    output_file: Path,
+    skip_start: int = 10,
+    skip_end: int = 5,
+):
+    lrs, losses, accs = metrics

-def noop(*args, **kwards):
-    pass
+    if skip_end == 0:
+        lrs = lrs[skip_start:]
+        losses = losses[skip_start:]
+        accs = accs[skip_start:]
+    else:
+        lrs = lrs[skip_start:-skip_end]
+        losses = losses[skip_start:-skip_end]
+        accs = accs[skip_start:-skip_end]

+    fig, ax_loss = plt.subplots()
+    ax_acc = ax_loss.twinx()

-def noop_ctx(*args, **kwards):
-    return nullcontext()
+    ax_loss.plot(lrs, losses, color='red')
+    ax_loss.set_xscale("log")
+    ax_loss.set_xlabel(f"Learning rate")
+    ax_loss.set_ylabel("Loss")

+    ax_acc.plot(lrs, accs, color='blue')
+    ax_acc.set_xscale("log")
+    ax_acc.set_ylabel("Accuracy")

-class LRFinder():
-    def __init__(
-        self,
-        accelerator,
-        optimizer,
-        train_dataloader,
-        val_dataloader,
-        loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
-        callbacks: TrainingCallbacks = TrainingCallbacks()
-    ):
-        self.accelerator = accelerator
-        self.model = callbacks.on_model()
-        self.optimizer = optimizer
-        self.train_dataloader = train_dataloader
-        self.val_dataloader = val_dataloader
-        self.loss_fn = loss_fn
-        self.callbacks = callbacks
-
-        # self.model_state = copy.deepcopy(model.state_dict())
-        # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
-
-    def run(
-        self,
-        end_lr,
-        skip_start: int = 10,
-        skip_end: int = 5,
-        num_epochs: int = 100,
-        num_train_batches: int = math.inf,
-        num_val_batches: int = math.inf,
-        smooth_f: float = 0.05,
-    ):
-        best_loss = None
-        best_acc = None
-
-        lrs = []
-        losses = []
-        accs = []
-
-        lr_scheduler = get_exponential_schedule(
-            self.optimizer,
-            end_lr,
-            num_epochs * min(num_train_batches, len(self.train_dataloader))
-        )
-
-        steps = min(num_train_batches, len(self.train_dataloader))
-        steps += min(num_val_batches, len(self.val_dataloader))
-        steps *= num_epochs
-
-        progress_bar = tqdm(
-            range(steps),
-            disable=not self.accelerator.is_local_main_process,
-            dynamic_ncols=True
-        )
-        progress_bar.set_description("Epoch X / Y")
-
-        self.callbacks.on_prepare()
-
-        on_train = self.callbacks.on_train
-        on_before_optimize = self.callbacks.on_before_optimize
-        on_after_optimize = self.callbacks.on_after_optimize
-        on_eval = self.callbacks.on_eval
-
-        for epoch in range(num_epochs):
-            progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
-
-            avg_loss = AverageMeter()
-            avg_acc = AverageMeter()
-
-            self.model.train()
-
-            with on_train(epoch):
-                for step, batch in enumerate(self.train_dataloader):
-                    if step >= num_train_batches:
-                        break
-
-                    with self.accelerator.accumulate(self.model):
-                        loss, acc, bsz = self.loss_fn(step, batch)
-
-                        self.accelerator.backward(loss)
-
-                        on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
-
-                        self.optimizer.step()
-                        lr_scheduler.step()
-                        self.optimizer.zero_grad(set_to_none=True)
-
-                    if self.accelerator.sync_gradients:
-                        on_after_optimize(lr_scheduler.get_last_lr()[0])
-
-                    progress_bar.update(1)
-
-            self.model.eval()
-
-            with torch.inference_mode():
-                with on_eval():
-                    for step, batch in enumerate(self.val_dataloader):
-                        if step >= num_val_batches:
-                            break
-
-                        loss, acc, bsz = self.loss_fn(step, batch, True)
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
-
-                        progress_bar.update(1)
-
-            loss = avg_loss.avg.item()
-            acc = avg_acc.avg.item()
-
-            if epoch == 0:
-                best_loss = loss
-                best_acc = acc
-            else:
-                if smooth_f > 0:
-                    loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
-                    acc = smooth_f * acc + (1 - smooth_f) * accs[-1]
-                if loss < best_loss:
-                    best_loss = loss
-                if acc > best_acc:
-                    best_acc = acc
-
-            lr = lr_scheduler.get_last_lr()[0]
-
-            lrs.append(lr)
-            losses.append(loss)
-            accs.append(acc)
-
-            self.accelerator.log({
-                "loss": loss,
-                "acc": acc,
-                "lr": lr,
-            }, step=epoch)
-
-            progress_bar.set_postfix({
-                "loss": loss,
-                "loss/best": best_loss,
-                "acc": acc,
-                "acc/best": best_acc,
-                "lr": lr,
-            })
-
-        # self.model.load_state_dict(self.model_state)
-        # self.optimizer.load_state_dict(self.optimizer_state)
-
-        if skip_end == 0:
-            lrs = lrs[skip_start:]
-            losses = losses[skip_start:]
-            accs = accs[skip_start:]
-        else:
-            lrs = lrs[skip_start:-skip_end]
-            losses = losses[skip_start:-skip_end]
-            accs = accs[skip_start:-skip_end]
-
-        fig, ax_loss = plt.subplots()
-        ax_acc = ax_loss.twinx()
-
-        ax_loss.plot(lrs, losses, color='red')
-        ax_loss.set_xscale("log")
-        ax_loss.set_xlabel(f"Learning rate")
-        ax_loss.set_ylabel("Loss")
-
-        ax_acc.plot(lrs, accs, color='blue')
-        ax_acc.set_xscale("log")
-        ax_acc.set_ylabel("Accuracy")
-
-        print("LR suggestion: steepest gradient")
-        min_grad_idx = None
-
-        try:
-            min_grad_idx = np.gradient(np.array(losses)).argmin()
-        except ValueError:
-            print(
-                "Failed to compute the gradients, there might not be enough points."
-            )
-
-        try:
-            max_val_idx = np.array(accs).argmax()
-        except ValueError:
-            print(
-                "Failed to compute the gradients, there might not be enough points."
-            )
-
-        if min_grad_idx is not None:
-            print("Suggested LR (loss): {:.2E}".format(lrs[min_grad_idx]))
-            ax_loss.scatter(
-                lrs[min_grad_idx],
-                losses[min_grad_idx],
-                s=75,
-                marker="o",
-                color="red",
-                zorder=3,
-                label="steepest gradient",
-            )
-            ax_loss.legend()
-
-        if max_val_idx is not None:
-            print("Suggested LR (acc): {:.2E}".format(lrs[max_val_idx]))
-            ax_acc.scatter(
-                lrs[max_val_idx],
-                accs[max_val_idx],
-                s=75,
-                marker="o",
-                color="blue",
-                zorder=3,
-                label="maximum",
-            )
-            ax_acc.legend()
-
-
-def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
-    def lr_lambda(base_lr: float, current_epoch: int):
-        return (end_lr / base_lr) ** (current_epoch / num_epochs)
-
-    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
-
-    return LambdaLR(optimizer, lr_lambdas, last_epoch)
+    plt.savefig(output_file, dpi=300)
+    plt.close()
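training/lr.py now exposes only `plot_metrics()`, which trims the noisiest leading and trailing epochs (`skip_start=10`, `skip_end=5` by default) and draws loss and accuracy against the learning rate on a log x-axis. A minimal usage sketch with synthetic values (the training scripts instead pass the tuple returned by the trainer):

```python
from pathlib import Path

from training.lr import plot_metrics

# Synthetic per-epoch metrics purely for illustration; pass more points than
# skip_start + skip_end, otherwise everything gets trimmed away.
epochs = 40
lrs = [1e-6 * 10 ** (4 * i / epochs) for i in range(epochs)]  # grows ~1e-6 up to ~1e-2
losses = [0.50 - 0.003 * i for i in range(epochs)]
accs = [0.55 + 0.003 * i for i in range(epochs)]

plot_metrics((lrs, losses, accs), Path("lr.png"))  # same call the scripts make
```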
diff --git a/training/optimization.py b/training/optimization.py
index 6dee4bc..6c9a35d 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -87,6 +87,15 @@ def get_one_cycle_schedule(
     return LambdaLR(optimizer, lr_lambda, last_epoch)


+def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1):
+    def lr_lambda(base_lr: float, current_step: int):
+        return (end_lr / base_lr) ** (current_step / num_training_steps)
+
+    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
+
+    return LambdaLR(optimizer, lr_lambdas, last_epoch)
+
+
 def get_scheduler(
     id: str,
     optimizer: torch.optim.Optimizer,
@@ -97,6 +106,7 @@ def get_scheduler(
     annealing_func: Literal["cos", "half_cos", "linear"] = "cos",
     warmup_exp: int = 1,
     annealing_exp: int = 1,
+    end_lr: float = 1e3,
     cycles: int = 1,
     train_epochs: int = 100,
     warmup_epochs: int = 10,
@@ -117,6 +127,15 @@ def get_scheduler(
             annealing_exp=annealing_exp,
             min_lr=min_lr,
         )
+    elif id == "exponential_growth":
+        if cycles is None:
+            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+
+        lr_scheduler = get_exponential_growing_schedule(
+            optimizer=optimizer,
+            end_lr=end_lr,
+            num_training_steps=num_training_steps,
+        )
    elif id == "cosine_with_restarts":
        if cycles is None:
            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
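The new `exponential_growth` schedule multiplies each parameter group's base LR by `(end_lr / base_lr) ** (step / num_training_steps)`, i.e. it sweeps the learning rate geometrically from the configured LR up to `end_lr`, which is what the LR-finder mode in the training scripts relies on. A quick standalone illustration of that multiplier (a base LR of 1e-6 matches the `find_lr` default in train_dreambooth.py; the numbers are only illustrative):

```python
base_lr = 1e-6
end_lr = 1e2            # the value train_dreambooth.py passes as end_lr
num_training_steps = 100

for step in (0, 25, 50, 75, 100):
    lr = base_lr * (end_lr / base_lr) ** (step / num_training_steps)
    print(f"step {step:3d}: lr = {lr:.0e}")
# step   0: lr = 1e-06
# step  25: lr = 1e-04
# step  50: lr = 1e-02
# step  75: lr = 1e+00
# step 100: lr = 1e+02
```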
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 1277939..e88bf90 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -193,9 +193,7 @@ def dreambooth_prepare(
     unet: UNet2DConditionModel,
     *args
 ):
-    prep = [text_encoder, unet] + list(args)
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep)
-    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    return accelerator.prepare(text_encoder, unet, *args)


 dreambooth_strategy = TrainingStrategy(
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6a76f98..14bdafd 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -176,10 +176,9 @@ def textual_inversion_prepare(
     elif accelerator.state.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

-    prep = [text_encoder] + list(args)
-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep)
+    prepped = accelerator.prepare(text_encoder, *args)
     unet.to(accelerator.device, dtype=weight_dtype)
-    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    return (prepped[0], unet) + prepped[1:]


 textual_inversion_strategy = TrainingStrategy(
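Both prepare functions now lean on the fact that `accelerator.prepare()` returns its arguments in the order they were passed. For textual inversion the UNet is deliberately kept out of `prepare()` (it is only moved to the device and cast), so the function splices it back into position 1 to preserve the old return shape. A toy sketch of that splice with stand-in values (the real objects are the prepared text encoder, optimizer, dataloaders and LR scheduler):

```python
# Stand-ins for what accelerator.prepare(text_encoder, *args) returns,
# in the same order the objects were passed in.
prepped = ("text_encoder", "optimizer", "train_dataloader", "val_dataloader", "lr_scheduler")
unet = "unet"  # kept out of prepare(); only moved to the device and cast

result = (prepped[0], unet) + prepped[1:]
print(result)
# ('text_encoder', 'unet', 'optimizer', 'train_dataloader', 'val_dataloader', 'lr_scheduler')
```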
diff --git a/training/util.py b/training/util.py
index 237626f..c8524de 100644
--- a/training/util.py
+++ b/training/util.py
@@ -6,6 +6,8 @@ from contextlib import contextmanager

 import torch

+from diffusers.training_utils import EMAModel as EMAModel_
+

 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
@@ -30,149 +32,7 @@ class AverageMeter:
         self.avg = self.sum / self.count


-# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
-class EMAModel:
-    """
-    Exponential Moving Average of models weights
-    """
-
-    def __init__(
-        self,
-        parameters: Iterable[torch.nn.Parameter],
-        update_after_step: int = 0,
-        inv_gamma: float = 1.0,
-        power: float = 2 / 3,
-        min_value: float = 0.0,
-        max_value: float = 0.9999,
-    ):
-        """
-        @crowsonkb's notes on EMA Warmup:
-        If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
-        to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
-        gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
-        at 215.4k steps).
-        Args:
-            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
-            power (float): Exponential factor of EMA warmup. Default: 2/3.
-            min_value (float): The minimum EMA decay rate. Default: 0.
-        """
-        parameters = list(parameters)
-        self.shadow_params = [p.clone().detach() for p in parameters]
-
-        self.collected_params = None
-
-        self.update_after_step = update_after_step
-        self.inv_gamma = inv_gamma
-        self.power = power
-        self.min_value = min_value
-        self.max_value = max_value
-
-        self.decay = 0.0
-        self.optimization_step = 0
-
-    def get_decay(self, optimization_step: int):
-        """
-        Compute the decay factor for the exponential moving average.
-        """
-        step = max(0, optimization_step - self.update_after_step - 1)
-        value = 1 - (1 + step / self.inv_gamma) ** -self.power
-
-        if step <= 0:
-            return 0.0
-
-        return max(self.min_value, min(value, self.max_value))
-
-    @torch.no_grad()
-    def step(self, parameters):
-        parameters = list(parameters)
-
-        self.optimization_step += 1
-
-        # Compute the decay factor for the exponential moving average.
-        self.decay = self.get_decay(self.optimization_step)
-
-        for s_param, param in zip(self.shadow_params, parameters):
-            if param.requires_grad:
-                s_param.mul_(self.decay)
-                s_param.add_(param.data, alpha=1 - self.decay)
-            else:
-                s_param.copy_(param)
-
-        torch.cuda.empty_cache()
-
-    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
-        """
-        Copy current averaged parameters into given collection of parameters.
-        Args:
-            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-                updated with the stored moving averages. If `None`, the
-                parameters with which this `ExponentialMovingAverage` was
-                initialized will be used.
-        """
-        parameters = list(parameters)
-        for s_param, param in zip(self.shadow_params, parameters):
-            param.data.copy_(s_param.data)
-
-    def to(self, device=None, dtype=None) -> None:
-        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
-        Args:
-            device: like `device` argument to `torch.Tensor.to`
-        """
-        # .to() on the tensors handles None correctly
-        self.shadow_params = [
-            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
-            for p in self.shadow_params
-        ]
-
-    def state_dict(self) -> dict:
-        r"""
-        Returns the state of the ExponentialMovingAverage as a dict.
-        This method is used by accelerate during checkpointing to save the ema state dict.
-        """
-        # Following PyTorch conventions, references to tensors are returned:
-        # "returns a reference to the state and not its copy!" -
-        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
-        return {
-            "decay": self.decay,
-            "optimization_step": self.optimization_step,
-            "shadow_params": self.shadow_params,
-            "collected_params": self.collected_params,
-        }
-
-    def load_state_dict(self, state_dict: dict) -> None:
-        r"""
-        Loads the ExponentialMovingAverage state.
-        This method is used by accelerate during checkpointing to save the ema state dict.
-        Args:
-            state_dict (dict): EMA state. Should be an object returned
-                from a call to :meth:`state_dict`.
-        """
-        # deepcopy, to be consistent with module API
-        state_dict = copy.deepcopy(state_dict)
-
-        self.decay = state_dict["decay"]
-        if self.decay < 0.0 or self.decay > 1.0:
-            raise ValueError("Decay must be between 0 and 1")
-
-        self.optimization_step = state_dict["optimization_step"]
-        if not isinstance(self.optimization_step, int):
-            raise ValueError("Invalid optimization_step")
-
-        self.shadow_params = state_dict["shadow_params"]
-        if not isinstance(self.shadow_params, list):
-            raise ValueError("shadow_params must be a list")
-        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
-            raise ValueError("shadow_params must all be Tensors")
-
-        self.collected_params = state_dict["collected_params"]
-        if self.collected_params is not None:
-            if not isinstance(self.collected_params, list):
-                raise ValueError("collected_params must be a list")
-            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
-                raise ValueError("collected_params must all be Tensors")
-            if len(self.collected_params) != len(self.shadow_params):
-                raise ValueError("collected_params and shadow_params must have the same length")
-
+class EMAModel(EMAModel_):
     @contextmanager
     def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
         parameters = list(parameters)