 train_dreambooth.py             |  5
 train_lora.py                   | 75
 train_ti.py                     | 12
 training/functional.py          | 57
 training/strategy/dreambooth.py |  2
 training/strategy/lora.py       | 12
 training/strategy/ti.py         |  2
 7 files changed, 80 insertions(+), 85 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 48921d4..f4d4cbb 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -18,7 +18,6 @@ import transformers
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, get_models
-from training.lr import plot_metrics
 from training.strategy.dreambooth import dreambooth_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -692,7 +691,7 @@ def main():
         mid_point=args.lr_mid_point,
     )
 
-    metrics = trainer(
+    trainer(
         strategy=dreambooth_strategy,
         project="dreambooth",
         train_dataloader=datamodule.train_dataloader,
@@ -721,8 +720,6 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
-    plot_metrics(metrics, output_dir / "lr.png")
-
 
 if __name__ == "__main__":
     main()
diff --git a/train_lora.py b/train_lora.py
index 9f17495..1626be6 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -19,7 +19,6 @@ import transformers
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
-from training.lr import plot_metrics
 from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -568,6 +567,9 @@ def parse_args():
     if len(args.placeholder_tokens) == 0:
         args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
 
+    if len(args.initializer_tokens) == 0:
+        args.initializer_tokens = args.placeholder_tokens.copy()
+
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
         raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
 
@@ -918,7 +920,7 @@ def main():
         train_epochs=num_pti_epochs,
     )
 
-    metrics = trainer(
+    trainer(
         strategy=lora_strategy,
         pti_mode=True,
         project="pti",
@@ -929,12 +931,13 @@ def main():
         num_train_epochs=num_pti_epochs,
         gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
         # --
+        group_labels=["emb"],
         sample_output_dir=pti_sample_output_dir,
         checkpoint_output_dir=pti_checkpoint_output_dir,
         sample_frequency=pti_sample_frequency,
     )
 
-    plot_metrics(metrics, pti_output_dir / "lr.png")
+    # embeddings.persist()
 
     # LORA
     # --------------------------------------------------------------------------------
@@ -957,34 +960,39 @@ def main():
         ) * args.gradient_accumulation_steps
         lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps))
 
-    lora_optimizer = create_optimizer(
-        [
-            {
-                "params": (
-                    param
-                    for param in unet.parameters()
-                    if param.requires_grad
-                ),
-                "lr": args.learning_rate_unet,
-            },
-            {
-                "params": (
-                    param
-                    for param in itertools.chain(
-                        text_encoder.text_model.encoder.parameters(),
-                        text_encoder.text_model.final_layer_norm.parameters(),
-                    )
-                    if param.requires_grad
-                ),
-                "lr": args.learning_rate_text,
-            },
-            {
-                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
-                "lr": args.learning_rate_text,
-                "weight_decay": 0,
-            },
-        ]
-    )
+    params_to_optimize = []
+    group_labels = []
+    if len(args.placeholder_tokens) != 0:
+        params_to_optimize.append({
+            "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+            "lr": args.learning_rate_text,
+            "weight_decay": 0,
+        })
+        group_labels.append("emb")
+    params_to_optimize += [
+        {
+            "params": (
+                param
+                for param in unet.parameters()
+                if param.requires_grad
+            ),
+            "lr": args.learning_rate_unet,
+        },
+        {
+            "params": (
+                param
+                for param in itertools.chain(
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                )
+                if param.requires_grad
+            ),
+            "lr": args.learning_rate_text,
+        },
+    ]
+    group_labels += ["unet", "text"]
+
+    lora_optimizer = create_optimizer(params_to_optimize)
 
     lora_lr_scheduler = create_lr_scheduler(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -993,7 +1001,7 @@ def main():
         train_epochs=num_train_epochs,
     )
 
-    metrics = trainer(
+    trainer(
         strategy=lora_strategy,
         project="lora",
         train_dataloader=lora_datamodule.train_dataloader,
@@ -1003,13 +1011,12 @@ def main():
         num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         # --
+        group_labels=group_labels,
         sample_output_dir=lora_sample_output_dir,
         checkpoint_output_dir=lora_checkpoint_output_dir,
         sample_frequency=lora_sample_frequency,
     )
 
-    plot_metrics(metrics, lora_output_dir / "lr.png")
-
 
 if __name__ == "__main__":
     main()
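Note on the hunk above: the optimizer param groups and their labels are now built side by side, so the order of group_labels has to match the order of optimizer.param_groups (and therefore of lr_scheduler.get_last_lr()). A minimal, self-contained sketch of that pairing, using plain AdamW and dummy tensors in place of the repository's create_optimizer and real model parameters:

import torch

# Dummy stand-ins for the embedding, UNet, and text-encoder parameters.
emb_params = [torch.nn.Parameter(torch.zeros(2, 2))]
unet_params = [torch.nn.Parameter(torch.zeros(2, 2))]
text_params = [torch.nn.Parameter(torch.zeros(2, 2))]

params_to_optimize = []
group_labels = []

# The embedding group only exists when placeholder tokens are trained,
# so group and label are appended together to keep the ordering in sync.
params_to_optimize.append({"params": emb_params, "lr": 1e-4, "weight_decay": 0})
group_labels.append("emb")

params_to_optimize += [
    {"params": unet_params, "lr": 1e-5},
    {"params": text_params, "lr": 1e-4},
]
group_labels += ["unet", "text"]

optimizer = torch.optim.AdamW(params_to_optimize)

# group_labels[i] describes optimizer.param_groups[i]
for label, group in zip(group_labels, optimizer.param_groups):
    print(label, group["lr"])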
diff --git a/train_ti.py b/train_ti.py
index c1c0eed..48858cc 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -17,7 +17,6 @@ import transformers
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
-from training.lr import plot_metrics
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -511,12 +510,12 @@ def parse_args():
     if isinstance(args.initializer_tokens, str):
         args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
 
-    if len(args.initializer_tokens) == 0:
-        raise ValueError("You must specify --initializer_tokens")
-
     if len(args.placeholder_tokens) == 0:
         args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
 
+    if len(args.initializer_tokens) == 0:
+        args.initializer_tokens = args.placeholder_tokens.copy()
+
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
         raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
 
@@ -856,7 +855,7 @@ def main():
             mid_point=args.lr_mid_point,
         )
 
-        metrics = trainer(
+        trainer(
            project="textual_inversion",
            train_dataloader=datamodule.train_dataloader,
            val_dataloader=datamodule.val_dataloader,
@@ -864,14 +863,13 @@ def main():
            lr_scheduler=lr_scheduler,
            num_train_epochs=num_train_epochs,
            # --
+           group_labels=["emb"],
            sample_output_dir=sample_output_dir,
            sample_frequency=sample_frequency,
            placeholder_tokens=placeholder_tokens,
            placeholder_token_ids=placeholder_token_ids,
        )
 
-        plot_metrics(metrics, metrics_output_file)
-
     if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
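The reordered checks above (mirrored in train_lora.py) make --initializer_tokens optional: when it is omitted, the placeholder tokens are reused as their own initializers. A standalone illustration of the resulting behaviour, with normalize_tokens as a hypothetical helper name:

def normalize_tokens(placeholder_tokens: list[str], initializer_tokens: list[str]):
    # Derive placeholders from initializers when only the latter are given.
    if len(placeholder_tokens) == 0:
        placeholder_tokens = [f"<*{i}>" for i in range(len(initializer_tokens))]
    # New fallback: reuse the placeholders as initializers.
    if len(initializer_tokens) == 0:
        initializer_tokens = placeholder_tokens.copy()
    if len(placeholder_tokens) != len(initializer_tokens):
        raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
    return placeholder_tokens, initializer_tokens

print(normalize_tokens(["<my-token>"], []))  # (['<my-token>'], ['<my-token>'])
print(normalize_tokens([], ["sks"]))         # (['<*0>'], ['sks'])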
diff --git a/training/functional.py b/training/functional.py
index 4d83df1..71b2fe9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -36,8 +36,8 @@ def const(result=None):
 class TrainingCallbacks():
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
-    on_before_optimize: Callable[[float, int], Any] = const()
-    on_after_optimize: Callable[[Any, float], None] = const()
+    on_before_optimize: Callable[[int], Any] = const()
+    on_after_optimize: Callable[[Any, dict[str, float]], None] = const()
     on_after_epoch: Callable[[], None] = const()
     on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
     on_sample: Callable[[int], None] = const()
@@ -422,6 +422,7 @@ def train_loop(
     global_step_offset: int = 0,
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
+    group_labels: list[str] = [],
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
@@ -442,10 +443,6 @@ def train_loop(
     best_acc = 0.0
     best_acc_val = 0.0
 
-    lrs = []
-    losses = []
-    accs = []
-
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
@@ -496,6 +493,8 @@ def train_loop(
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
 
+            logs = {}
+
             with on_train(epoch):
                 for step, batch in enumerate(train_dataloader):
                     loss, acc, bsz = loss_step(step, batch, cache)
@@ -506,31 +505,36 @@ def train_loop(
                     avg_loss.update(loss.detach_(), bsz)
                     avg_acc.update(acc.detach_(), bsz)
 
-                    lr = lr_scheduler.get_last_lr()[0]
-                    if torch.is_tensor(lr):
-                        lr = lr.item()
-
                     logs = {
                         "train/loss": avg_loss.avg.item(),
                         "train/acc": avg_acc.avg.item(),
                         "train/cur_loss": loss.item(),
                         "train/cur_acc": acc.item(),
-                        "lr": lr,
                     }
-                    if isDadaptation:
-                        logs["lr/d*lr"] = lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
+
+                    lrs: dict[str, float] = {}
+                    for i, lr in enumerate(lr_scheduler.get_last_lr()):
+                        if torch.is_tensor(lr):
+                            lr = lr.item()
+                        label = group_labels[i] if i < len(group_labels) else f"{i}"
+                        logs[f"lr/{label}"] = lr
+                        if isDadaptation:
+                            lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
+                            logs[f"d*lr/{label}"] = lr
+                        lrs[label] = lr
+
                     logs.update(on_log())
 
                     local_progress_bar.set_postfix(**logs)
 
                     if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
-                        before_optimize_result = on_before_optimize(lr, epoch)
+                        before_optimize_result = on_before_optimize(epoch)
 
                         optimizer.step()
                         lr_scheduler.step()
                         optimizer.zero_grad(set_to_none=True)
 
-                        on_after_optimize(before_optimize_result, lr)
+                        on_after_optimize(before_optimize_result, lrs)
 
                         local_progress_bar.update(1)
                         global_progress_bar.update(1)
@@ -544,15 +548,6 @@ def train_loop(
 
             accelerator.wait_for_everyone()
 
-            if isDadaptation:
-                lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
-            else:
-                lr = lr_scheduler.get_last_lr()[0]
-                if torch.is_tensor(lr):
-                    lr = lr.item()
-
-            lrs.append(lr)
-
             on_after_epoch()
 
             if val_dataloader is not None:
@@ -597,9 +592,6 @@ def train_loop(
                            f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                        on_checkpoint(global_step + global_step_offset, "milestone")
                        best_acc_val = avg_acc_val.avg.item()
-
-                    losses.append(avg_loss_val.avg.item())
-                    accs.append(avg_acc_val.avg.item())
             else:
                 if accelerator.is_main_process:
                     if avg_acc.avg.item() > best_acc and milestone_checkpoints:
@@ -611,9 +603,6 @@ def train_loop(
                        on_checkpoint(global_step + global_step_offset, "milestone")
                        best_acc = avg_acc.avg.item()
 
-                    losses.append(avg_loss.avg.item())
-                    accs.append(avg_acc.avg.item())
-
        # Create the pipeline using using the trained modules and save it.
        if accelerator.is_main_process:
            print("Finished!")
@@ -626,8 +615,6 @@ def train_loop(
            on_checkpoint(global_step + global_step_offset, "end")
        raise KeyboardInterrupt
 
-    return lrs, losses, accs
-
 
 def train(
     accelerator: Accelerator,
@@ -646,6 +633,7 @@ def train(
     no_val: bool = False,
     num_train_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
+    group_labels: list[str] = [],
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
@@ -692,7 +680,7 @@ def train(
     if accelerator.is_main_process:
         accelerator.init_trackers(project)
 
-    metrics = train_loop(
+    train_loop(
         accelerator=accelerator,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
@@ -705,10 +693,9 @@ def train(
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
         gradient_accumulation_steps=gradient_accumulation_steps,
+        group_labels=group_labels,
         callbacks=callbacks,
     )
 
     accelerator.end_training()
     accelerator.free_memory()
-
-    return metrics
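The core of the functional.py change is the logging loop above: instead of recording a single scalar learning rate (plus the lrs/losses/accs lists that fed plot_metrics), train_loop now emits one lr/<label> entry per optimizer param group and hands the same per-label dict to on_after_optimize. A self-contained sketch of that loop with a toy optimizer (the isDadaptation branch omitted); any group without a label falls back to its index:

import torch

params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(3)]
optimizer = torch.optim.AdamW(
    [{"params": [p], "lr": lr} for p, lr in zip(params, (1e-4, 1e-5, 1e-4))]
)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda _: 1.0)
group_labels = ["emb", "unet"]  # deliberately one label short

logs: dict[str, float] = {}
lrs: dict[str, float] = {}
for i, lr in enumerate(lr_scheduler.get_last_lr()):
    if torch.is_tensor(lr):
        lr = lr.item()
    label = group_labels[i] if i < len(group_labels) else f"{i}"
    logs[f"lr/{label}"] = lr
    lrs[label] = lr

print(logs)  # {'lr/emb': 0.0001, 'lr/unet': 1e-05, 'lr/2': 0.0001}
# `lrs` is what on_after_optimize(before_optimize_result, lrs) now receives.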
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 0286673..695174a 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -106,7 +106,7 @@ def dreambooth_strategy_callbacks(
         with ema_context():
             yield
 
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         params_to_clip = [unet.parameters()]
         if epoch < train_text_encoder_epochs:
             params_to_clip.append(text_encoder.parameters())
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 912ff26..89269c0 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -79,10 +79,14 @@ def lora_strategy_callbacks(
         tokenizer.eval()
         yield
 
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         if not pti_mode:
             accelerator.clip_grad_norm_(
-                itertools.chain(unet.parameters(), text_encoder.parameters()),
+                itertools.chain(
+                    unet.parameters(),
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                ),
                 max_grad_norm
             )
 
@@ -95,7 +99,9 @@ def lora_strategy_callbacks(
         return torch.stack(params) if len(params) != 0 else None
 
     @torch.no_grad()
-    def on_after_optimize(w, lr: float):
+    def on_after_optimize(w, lrs: dict[str, float]):
+        lr = lrs["emb"] or lrs["0"]
+
         if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr
 
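Across the strategy files, the callback contract is now on_before_optimize(epoch) and on_after_optimize(result, lrs), with the return value of the former threaded through to the latter by train_loop. A toy sketch of that flow (hypothetical callbacks, not the repository's); note that the lora strategy's lrs["emb"] or lrs["0"] lookup above assumes one of those two labels was actually passed in as a group label:

from typing import Any

def on_before_optimize(epoch: int) -> Any:
    # Gradient clipping / embedding bookkeeping would happen here;
    # whatever is returned is handed back after the optimizer step.
    return {"epoch": epoch}

def on_after_optimize(before_optimize_result: Any, lrs: dict[str, float]) -> None:
    # Illustrative lookup with a safe fallback to the first group.
    lr = lrs.get("emb", next(iter(lrs.values())))
    print(f"epoch={before_optimize_result['epoch']}, emb lr={lr}")

result = on_before_optimize(epoch=0)
on_after_optimize(result, {"emb": 1e-4, "unet": 1e-5, "text": 1e-4})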
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6a637c3..d735dac 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
             yield
 
     @torch.no_grad()
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         if use_emb_decay:
             params = [
                 p
