diff options
-rw-r--r-- | train_dreambooth.py | 5 | ||||
-rw-r--r-- | train_lora.py | 75 | ||||
-rw-r--r-- | train_ti.py | 12 | ||||
-rw-r--r-- | training/functional.py | 57 | ||||
-rw-r--r-- | training/strategy/dreambooth.py | 2 | ||||
-rw-r--r-- | training/strategy/lora.py | 12 | ||||
-rw-r--r-- | training/strategy/ti.py | 2 |
7 files changed, 80 insertions, 85 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 48921d4..f4d4cbb 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -18,7 +18,6 @@ import transformers | |||
18 | from util.files import load_config, load_embeddings_from_dir | 18 | from util.files import load_config, load_embeddings_from_dir |
19 | from data.csv import VlpnDataModule, keyword_filter | 19 | from data.csv import VlpnDataModule, keyword_filter |
20 | from training.functional import train, get_models | 20 | from training.functional import train, get_models |
21 | from training.lr import plot_metrics | ||
22 | from training.strategy.dreambooth import dreambooth_strategy | 21 | from training.strategy.dreambooth import dreambooth_strategy |
23 | from training.optimization import get_scheduler | 22 | from training.optimization import get_scheduler |
24 | from training.util import save_args | 23 | from training.util import save_args |
@@ -692,7 +691,7 @@ def main(): | |||
692 | mid_point=args.lr_mid_point, | 691 | mid_point=args.lr_mid_point, |
693 | ) | 692 | ) |
694 | 693 | ||
695 | metrics = trainer( | 694 | trainer( |
696 | strategy=dreambooth_strategy, | 695 | strategy=dreambooth_strategy, |
697 | project="dreambooth", | 696 | project="dreambooth", |
698 | train_dataloader=datamodule.train_dataloader, | 697 | train_dataloader=datamodule.train_dataloader, |
@@ -721,8 +720,6 @@ def main(): | |||
721 | sample_image_size=args.sample_image_size, | 720 | sample_image_size=args.sample_image_size, |
722 | ) | 721 | ) |
723 | 722 | ||
724 | plot_metrics(metrics, output_dir / "lr.png") | ||
725 | |||
726 | 723 | ||
727 | if __name__ == "__main__": | 724 | if __name__ == "__main__": |
728 | main() | 725 | main() |
diff --git a/train_lora.py b/train_lora.py index 9f17495..1626be6 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -19,7 +19,6 @@ import transformers | |||
19 | from util.files import load_config, load_embeddings_from_dir | 19 | from util.files import load_config, load_embeddings_from_dir |
20 | from data.csv import VlpnDataModule, keyword_filter | 20 | from data.csv import VlpnDataModule, keyword_filter |
21 | from training.functional import train, add_placeholder_tokens, get_models | 21 | from training.functional import train, add_placeholder_tokens, get_models |
22 | from training.lr import plot_metrics | ||
23 | from training.strategy.lora import lora_strategy | 22 | from training.strategy.lora import lora_strategy |
24 | from training.optimization import get_scheduler | 23 | from training.optimization import get_scheduler |
25 | from training.util import save_args | 24 | from training.util import save_args |
@@ -568,6 +567,9 @@ def parse_args(): | |||
568 | if len(args.placeholder_tokens) == 0: | 567 | if len(args.placeholder_tokens) == 0: |
569 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | 568 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] |
570 | 569 | ||
570 | if len(args.initializer_tokens) == 0: | ||
571 | args.initializer_tokens = args.placeholder_tokens.copy() | ||
572 | |||
571 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | 573 | if len(args.placeholder_tokens) != len(args.initializer_tokens): |
572 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | 574 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") |
573 | 575 | ||
@@ -918,7 +920,7 @@ def main(): | |||
918 | train_epochs=num_pti_epochs, | 920 | train_epochs=num_pti_epochs, |
919 | ) | 921 | ) |
920 | 922 | ||
921 | metrics = trainer( | 923 | trainer( |
922 | strategy=lora_strategy, | 924 | strategy=lora_strategy, |
923 | pti_mode=True, | 925 | pti_mode=True, |
924 | project="pti", | 926 | project="pti", |
@@ -929,12 +931,13 @@ def main(): | |||
929 | num_train_epochs=num_pti_epochs, | 931 | num_train_epochs=num_pti_epochs, |
930 | gradient_accumulation_steps=args.pti_gradient_accumulation_steps, | 932 | gradient_accumulation_steps=args.pti_gradient_accumulation_steps, |
931 | # -- | 933 | # -- |
934 | group_labels=["emb"], | ||
932 | sample_output_dir=pti_sample_output_dir, | 935 | sample_output_dir=pti_sample_output_dir, |
933 | checkpoint_output_dir=pti_checkpoint_output_dir, | 936 | checkpoint_output_dir=pti_checkpoint_output_dir, |
934 | sample_frequency=pti_sample_frequency, | 937 | sample_frequency=pti_sample_frequency, |
935 | ) | 938 | ) |
936 | 939 | ||
937 | plot_metrics(metrics, pti_output_dir / "lr.png") | 940 | # embeddings.persist() |
938 | 941 | ||
939 | # LORA | 942 | # LORA |
940 | # -------------------------------------------------------------------------------- | 943 | # -------------------------------------------------------------------------------- |
@@ -957,34 +960,39 @@ def main(): | |||
957 | ) * args.gradient_accumulation_steps | 960 | ) * args.gradient_accumulation_steps |
958 | lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) | 961 | lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) |
959 | 962 | ||
960 | lora_optimizer = create_optimizer( | 963 | params_to_optimize = [] |
961 | [ | 964 | group_labels = [] |
962 | { | 965 | if len(args.placeholder_tokens) != 0: |
963 | "params": ( | 966 | params_to_optimize.append({ |
964 | param | 967 | "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), |
965 | for param in unet.parameters() | 968 | "lr": args.learning_rate_text, |
966 | if param.requires_grad | 969 | "weight_decay": 0, |
967 | ), | 970 | }) |
968 | "lr": args.learning_rate_unet, | 971 | group_labels.append("emb") |
969 | }, | 972 | params_to_optimize += [ |
970 | { | 973 | { |
971 | "params": ( | 974 | "params": ( |
972 | param | 975 | param |
973 | for param in itertools.chain( | 976 | for param in unet.parameters() |
974 | text_encoder.text_model.encoder.parameters(), | 977 | if param.requires_grad |
975 | text_encoder.text_model.final_layer_norm.parameters(), | 978 | ), |
976 | ) | 979 | "lr": args.learning_rate_unet, |
977 | if param.requires_grad | 980 | }, |
978 | ), | 981 | { |
979 | "lr": args.learning_rate_text, | 982 | "params": ( |
980 | }, | 983 | param |
981 | { | 984 | for param in itertools.chain( |
982 | "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), | 985 | text_encoder.text_model.encoder.parameters(), |
983 | "lr": args.learning_rate_text, | 986 | text_encoder.text_model.final_layer_norm.parameters(), |
984 | "weight_decay": 0, | 987 | ) |
985 | }, | 988 | if param.requires_grad |
986 | ] | 989 | ), |
987 | ) | 990 | "lr": args.learning_rate_text, |
991 | }, | ||
992 | ] | ||
993 | group_labels += ["unet", "text"] | ||
994 | |||
995 | lora_optimizer = create_optimizer(params_to_optimize) | ||
988 | 996 | ||
989 | lora_lr_scheduler = create_lr_scheduler( | 997 | lora_lr_scheduler = create_lr_scheduler( |
990 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 998 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
@@ -993,7 +1001,7 @@ def main(): | |||
993 | train_epochs=num_train_epochs, | 1001 | train_epochs=num_train_epochs, |
994 | ) | 1002 | ) |
995 | 1003 | ||
996 | metrics = trainer( | 1004 | trainer( |
997 | strategy=lora_strategy, | 1005 | strategy=lora_strategy, |
998 | project="lora", | 1006 | project="lora", |
999 | train_dataloader=lora_datamodule.train_dataloader, | 1007 | train_dataloader=lora_datamodule.train_dataloader, |
@@ -1003,13 +1011,12 @@ def main(): | |||
1003 | num_train_epochs=num_train_epochs, | 1011 | num_train_epochs=num_train_epochs, |
1004 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 1012 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
1005 | # -- | 1013 | # -- |
1014 | group_labels=group_labels, | ||
1006 | sample_output_dir=lora_sample_output_dir, | 1015 | sample_output_dir=lora_sample_output_dir, |
1007 | checkpoint_output_dir=lora_checkpoint_output_dir, | 1016 | checkpoint_output_dir=lora_checkpoint_output_dir, |
1008 | sample_frequency=lora_sample_frequency, | 1017 | sample_frequency=lora_sample_frequency, |
1009 | ) | 1018 | ) |
1010 | 1019 | ||
1011 | plot_metrics(metrics, lora_output_dir / "lr.png") | ||
1012 | |||
1013 | 1020 | ||
1014 | if __name__ == "__main__": | 1021 | if __name__ == "__main__": |
1015 | main() | 1022 | main() |
diff --git a/train_ti.py b/train_ti.py index c1c0eed..48858cc 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -17,7 +17,6 @@ import transformers | |||
17 | from util.files import load_config, load_embeddings_from_dir | 17 | from util.files import load_config, load_embeddings_from_dir |
18 | from data.csv import VlpnDataModule, keyword_filter | 18 | from data.csv import VlpnDataModule, keyword_filter |
19 | from training.functional import train, add_placeholder_tokens, get_models | 19 | from training.functional import train, add_placeholder_tokens, get_models |
20 | from training.lr import plot_metrics | ||
21 | from training.strategy.ti import textual_inversion_strategy | 20 | from training.strategy.ti import textual_inversion_strategy |
22 | from training.optimization import get_scheduler | 21 | from training.optimization import get_scheduler |
23 | from training.util import save_args | 22 | from training.util import save_args |
@@ -511,12 +510,12 @@ def parse_args(): | |||
511 | if isinstance(args.initializer_tokens, str): | 510 | if isinstance(args.initializer_tokens, str): |
512 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) | 511 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) |
513 | 512 | ||
514 | if len(args.initializer_tokens) == 0: | ||
515 | raise ValueError("You must specify --initializer_tokens") | ||
516 | |||
517 | if len(args.placeholder_tokens) == 0: | 513 | if len(args.placeholder_tokens) == 0: |
518 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | 514 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] |
519 | 515 | ||
516 | if len(args.initializer_tokens) == 0: | ||
517 | args.initializer_tokens = args.placeholder_tokens.copy() | ||
518 | |||
520 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | 519 | if len(args.placeholder_tokens) != len(args.initializer_tokens): |
521 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | 520 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") |
522 | 521 | ||
@@ -856,7 +855,7 @@ def main(): | |||
856 | mid_point=args.lr_mid_point, | 855 | mid_point=args.lr_mid_point, |
857 | ) | 856 | ) |
858 | 857 | ||
859 | metrics = trainer( | 858 | trainer( |
860 | project="textual_inversion", | 859 | project="textual_inversion", |
861 | train_dataloader=datamodule.train_dataloader, | 860 | train_dataloader=datamodule.train_dataloader, |
862 | val_dataloader=datamodule.val_dataloader, | 861 | val_dataloader=datamodule.val_dataloader, |
@@ -864,14 +863,13 @@ def main(): | |||
864 | lr_scheduler=lr_scheduler, | 863 | lr_scheduler=lr_scheduler, |
865 | num_train_epochs=num_train_epochs, | 864 | num_train_epochs=num_train_epochs, |
866 | # -- | 865 | # -- |
866 | group_labels=["emb"], | ||
867 | sample_output_dir=sample_output_dir, | 867 | sample_output_dir=sample_output_dir, |
868 | sample_frequency=sample_frequency, | 868 | sample_frequency=sample_frequency, |
869 | placeholder_tokens=placeholder_tokens, | 869 | placeholder_tokens=placeholder_tokens, |
870 | placeholder_token_ids=placeholder_token_ids, | 870 | placeholder_token_ids=placeholder_token_ids, |
871 | ) | 871 | ) |
872 | 872 | ||
873 | plot_metrics(metrics, metrics_output_file) | ||
874 | |||
875 | if not args.sequential: | 873 | if not args.sequential: |
876 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) | 874 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) |
877 | else: | 875 | else: |
diff --git a/training/functional.py b/training/functional.py index 4d83df1..71b2fe9 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -36,8 +36,8 @@ def const(result=None): | |||
36 | class TrainingCallbacks(): | 36 | class TrainingCallbacks(): |
37 | on_log: Callable[[], dict[str, Any]] = const({}) | 37 | on_log: Callable[[], dict[str, Any]] = const({}) |
38 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 38 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
39 | on_before_optimize: Callable[[float, int], Any] = const() | 39 | on_before_optimize: Callable[[int], Any] = const() |
40 | on_after_optimize: Callable[[Any, float], None] = const() | 40 | on_after_optimize: Callable[[Any, dict[str, float]], None] = const() |
41 | on_after_epoch: Callable[[], None] = const() | 41 | on_after_epoch: Callable[[], None] = const() |
42 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | 42 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) |
43 | on_sample: Callable[[int], None] = const() | 43 | on_sample: Callable[[int], None] = const() |
@@ -422,6 +422,7 @@ def train_loop( | |||
422 | global_step_offset: int = 0, | 422 | global_step_offset: int = 0, |
423 | num_epochs: int = 100, | 423 | num_epochs: int = 100, |
424 | gradient_accumulation_steps: int = 1, | 424 | gradient_accumulation_steps: int = 1, |
425 | group_labels: list[str] = [], | ||
425 | callbacks: TrainingCallbacks = TrainingCallbacks(), | 426 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
426 | ): | 427 | ): |
427 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) | 428 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) |
@@ -442,10 +443,6 @@ def train_loop( | |||
442 | best_acc = 0.0 | 443 | best_acc = 0.0 |
443 | best_acc_val = 0.0 | 444 | best_acc_val = 0.0 |
444 | 445 | ||
445 | lrs = [] | ||
446 | losses = [] | ||
447 | accs = [] | ||
448 | |||
449 | local_progress_bar = tqdm( | 446 | local_progress_bar = tqdm( |
450 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), | 447 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), |
451 | disable=not accelerator.is_local_main_process, | 448 | disable=not accelerator.is_local_main_process, |
@@ -496,6 +493,8 @@ def train_loop( | |||
496 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 493 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
497 | local_progress_bar.reset() | 494 | local_progress_bar.reset() |
498 | 495 | ||
496 | logs = {} | ||
497 | |||
499 | with on_train(epoch): | 498 | with on_train(epoch): |
500 | for step, batch in enumerate(train_dataloader): | 499 | for step, batch in enumerate(train_dataloader): |
501 | loss, acc, bsz = loss_step(step, batch, cache) | 500 | loss, acc, bsz = loss_step(step, batch, cache) |
@@ -506,31 +505,36 @@ def train_loop( | |||
506 | avg_loss.update(loss.detach_(), bsz) | 505 | avg_loss.update(loss.detach_(), bsz) |
507 | avg_acc.update(acc.detach_(), bsz) | 506 | avg_acc.update(acc.detach_(), bsz) |
508 | 507 | ||
509 | lr = lr_scheduler.get_last_lr()[0] | ||
510 | if torch.is_tensor(lr): | ||
511 | lr = lr.item() | ||
512 | |||
513 | logs = { | 508 | logs = { |
514 | "train/loss": avg_loss.avg.item(), | 509 | "train/loss": avg_loss.avg.item(), |
515 | "train/acc": avg_acc.avg.item(), | 510 | "train/acc": avg_acc.avg.item(), |
516 | "train/cur_loss": loss.item(), | 511 | "train/cur_loss": loss.item(), |
517 | "train/cur_acc": acc.item(), | 512 | "train/cur_acc": acc.item(), |
518 | "lr": lr, | ||
519 | } | 513 | } |
520 | if isDadaptation: | 514 | |
521 | logs["lr/d*lr"] = lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] | 515 | lrs: dict[str, float] = {} |
516 | for i, lr in enumerate(lr_scheduler.get_last_lr()): | ||
517 | if torch.is_tensor(lr): | ||
518 | lr = lr.item() | ||
519 | label = group_labels[i] if i < len(group_labels) else f"{i}" | ||
520 | logs[f"lr/{label}"] = lr | ||
521 | if isDadaptation: | ||
522 | lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] | ||
523 | logs[f"d*lr/{label}"] = lr | ||
524 | lrs[label] = lr | ||
525 | |||
522 | logs.update(on_log()) | 526 | logs.update(on_log()) |
523 | 527 | ||
524 | local_progress_bar.set_postfix(**logs) | 528 | local_progress_bar.set_postfix(**logs) |
525 | 529 | ||
526 | if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): | 530 | if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): |
527 | before_optimize_result = on_before_optimize(lr, epoch) | 531 | before_optimize_result = on_before_optimize(epoch) |
528 | 532 | ||
529 | optimizer.step() | 533 | optimizer.step() |
530 | lr_scheduler.step() | 534 | lr_scheduler.step() |
531 | optimizer.zero_grad(set_to_none=True) | 535 | optimizer.zero_grad(set_to_none=True) |
532 | 536 | ||
533 | on_after_optimize(before_optimize_result, lr) | 537 | on_after_optimize(before_optimize_result, lrs) |
534 | 538 | ||
535 | local_progress_bar.update(1) | 539 | local_progress_bar.update(1) |
536 | global_progress_bar.update(1) | 540 | global_progress_bar.update(1) |
@@ -544,15 +548,6 @@ def train_loop( | |||
544 | 548 | ||
545 | accelerator.wait_for_everyone() | 549 | accelerator.wait_for_everyone() |
546 | 550 | ||
547 | if isDadaptation: | ||
548 | lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] | ||
549 | else: | ||
550 | lr = lr_scheduler.get_last_lr()[0] | ||
551 | if torch.is_tensor(lr): | ||
552 | lr = lr.item() | ||
553 | |||
554 | lrs.append(lr) | ||
555 | |||
556 | on_after_epoch() | 551 | on_after_epoch() |
557 | 552 | ||
558 | if val_dataloader is not None: | 553 | if val_dataloader is not None: |
@@ -597,9 +592,6 @@ def train_loop( | |||
597 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 592 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") |
598 | on_checkpoint(global_step + global_step_offset, "milestone") | 593 | on_checkpoint(global_step + global_step_offset, "milestone") |
599 | best_acc_val = avg_acc_val.avg.item() | 594 | best_acc_val = avg_acc_val.avg.item() |
600 | |||
601 | losses.append(avg_loss_val.avg.item()) | ||
602 | accs.append(avg_acc_val.avg.item()) | ||
603 | else: | 595 | else: |
604 | if accelerator.is_main_process: | 596 | if accelerator.is_main_process: |
605 | if avg_acc.avg.item() > best_acc and milestone_checkpoints: | 597 | if avg_acc.avg.item() > best_acc and milestone_checkpoints: |
@@ -611,9 +603,6 @@ def train_loop( | |||
611 | on_checkpoint(global_step + global_step_offset, "milestone") | 603 | on_checkpoint(global_step + global_step_offset, "milestone") |
612 | best_acc = avg_acc.avg.item() | 604 | best_acc = avg_acc.avg.item() |
613 | 605 | ||
614 | losses.append(avg_loss.avg.item()) | ||
615 | accs.append(avg_acc.avg.item()) | ||
616 | |||
617 | # Create the pipeline using using the trained modules and save it. | 606 | # Create the pipeline using using the trained modules and save it. |
618 | if accelerator.is_main_process: | 607 | if accelerator.is_main_process: |
619 | print("Finished!") | 608 | print("Finished!") |
@@ -626,8 +615,6 @@ def train_loop( | |||
626 | on_checkpoint(global_step + global_step_offset, "end") | 615 | on_checkpoint(global_step + global_step_offset, "end") |
627 | raise KeyboardInterrupt | 616 | raise KeyboardInterrupt |
628 | 617 | ||
629 | return lrs, losses, accs | ||
630 | |||
631 | 618 | ||
632 | def train( | 619 | def train( |
633 | accelerator: Accelerator, | 620 | accelerator: Accelerator, |
@@ -646,6 +633,7 @@ def train( | |||
646 | no_val: bool = False, | 633 | no_val: bool = False, |
647 | num_train_epochs: int = 100, | 634 | num_train_epochs: int = 100, |
648 | gradient_accumulation_steps: int = 1, | 635 | gradient_accumulation_steps: int = 1, |
636 | group_labels: list[str] = [], | ||
649 | sample_frequency: int = 20, | 637 | sample_frequency: int = 20, |
650 | checkpoint_frequency: int = 50, | 638 | checkpoint_frequency: int = 50, |
651 | milestone_checkpoints: bool = True, | 639 | milestone_checkpoints: bool = True, |
@@ -692,7 +680,7 @@ def train( | |||
692 | if accelerator.is_main_process: | 680 | if accelerator.is_main_process: |
693 | accelerator.init_trackers(project) | 681 | accelerator.init_trackers(project) |
694 | 682 | ||
695 | metrics = train_loop( | 683 | train_loop( |
696 | accelerator=accelerator, | 684 | accelerator=accelerator, |
697 | optimizer=optimizer, | 685 | optimizer=optimizer, |
698 | lr_scheduler=lr_scheduler, | 686 | lr_scheduler=lr_scheduler, |
@@ -705,10 +693,9 @@ def train( | |||
705 | global_step_offset=global_step_offset, | 693 | global_step_offset=global_step_offset, |
706 | num_epochs=num_train_epochs, | 694 | num_epochs=num_train_epochs, |
707 | gradient_accumulation_steps=gradient_accumulation_steps, | 695 | gradient_accumulation_steps=gradient_accumulation_steps, |
696 | group_labels=group_labels, | ||
708 | callbacks=callbacks, | 697 | callbacks=callbacks, |
709 | ) | 698 | ) |
710 | 699 | ||
711 | accelerator.end_training() | 700 | accelerator.end_training() |
712 | accelerator.free_memory() | 701 | accelerator.free_memory() |
713 | |||
714 | return metrics | ||
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 0286673..695174a 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -106,7 +106,7 @@ def dreambooth_strategy_callbacks( | |||
106 | with ema_context(): | 106 | with ema_context(): |
107 | yield | 107 | yield |
108 | 108 | ||
109 | def on_before_optimize(lr: float, epoch: int): | 109 | def on_before_optimize(epoch: int): |
110 | params_to_clip = [unet.parameters()] | 110 | params_to_clip = [unet.parameters()] |
111 | if epoch < train_text_encoder_epochs: | 111 | if epoch < train_text_encoder_epochs: |
112 | params_to_clip.append(text_encoder.parameters()) | 112 | params_to_clip.append(text_encoder.parameters()) |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 912ff26..89269c0 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -79,10 +79,14 @@ def lora_strategy_callbacks( | |||
79 | tokenizer.eval() | 79 | tokenizer.eval() |
80 | yield | 80 | yield |
81 | 81 | ||
82 | def on_before_optimize(lr: float, epoch: int): | 82 | def on_before_optimize(epoch: int): |
83 | if not pti_mode: | 83 | if not pti_mode: |
84 | accelerator.clip_grad_norm_( | 84 | accelerator.clip_grad_norm_( |
85 | itertools.chain(unet.parameters(), text_encoder.parameters()), | 85 | itertools.chain( |
86 | unet.parameters(), | ||
87 | text_encoder.text_model.encoder.parameters(), | ||
88 | text_encoder.text_model.final_layer_norm.parameters(), | ||
89 | ), | ||
86 | max_grad_norm | 90 | max_grad_norm |
87 | ) | 91 | ) |
88 | 92 | ||
@@ -95,7 +99,9 @@ def lora_strategy_callbacks( | |||
95 | return torch.stack(params) if len(params) != 0 else None | 99 | return torch.stack(params) if len(params) != 0 else None |
96 | 100 | ||
97 | @torch.no_grad() | 101 | @torch.no_grad() |
98 | def on_after_optimize(w, lr: float): | 102 | def on_after_optimize(w, lrs: dict[str, float]): |
103 | lr = lrs["emb"] or lrs["0"] | ||
104 | |||
99 | if use_emb_decay and w is not None: | 105 | if use_emb_decay and w is not None: |
100 | lambda_ = emb_decay * lr | 106 | lambda_ = emb_decay * lr |
101 | 107 | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 6a637c3..d735dac 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks( | |||
104 | yield | 104 | yield |
105 | 105 | ||
106 | @torch.no_grad() | 106 | @torch.no_grad() |
107 | def on_before_optimize(lr: float, epoch: int): | 107 | def on_before_optimize(epoch: int): |
108 | if use_emb_decay: | 108 | if use_emb_decay: |
109 | params = [ | 109 | params = [ |
110 | p | 110 | p |