diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-16 19:03:25 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-16 19:03:25 +0200 |
| commit | 71f4a40bb48be4f2759ba2d83faff39691cb2955 (patch) | |
| tree | 29c704ca549a4c4323403b6cbb0e62f54040ae22 | |
| parent | Added option to use constant LR on cycles > 1 (diff) | |
| download | textual-inversion-diff-71f4a40bb48be4f2759ba2d83faff39691cb2955.tar.gz textual-inversion-diff-71f4a40bb48be4f2759ba2d83faff39691cb2955.tar.bz2 textual-inversion-diff-71f4a40bb48be4f2759ba2d83faff39691cb2955.zip | |
Improved automation caps
| -rw-r--r-- | train_lora.py | 53 | ||||
| -rw-r--r-- | train_ti.py | 53 | ||||
| -rw-r--r-- | training/functional.py | 17 | ||||
| -rw-r--r-- | training/strategy/dreambooth.py | 4 | ||||
| -rw-r--r-- | training/strategy/lora.py | 4 | ||||
| -rw-r--r-- | training/strategy/ti.py | 23 |
6 files changed, 100 insertions, 54 deletions
diff --git a/train_lora.py b/train_lora.py index 4d4c16a..ba5aee1 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -84,9 +84,9 @@ def parse_args(): | |||
| 84 | ) | 84 | ) |
| 85 | parser.add_argument( | 85 | parser.add_argument( |
| 86 | "--auto_cycles", | 86 | "--auto_cycles", |
| 87 | type=int, | 87 | type=str, |
| 88 | default=1, | 88 | default="o", |
| 89 | help="How many cycles to run automatically." | 89 | help="Cycles to run automatically." |
| 90 | ) | 90 | ) |
| 91 | parser.add_argument( | 91 | parser.add_argument( |
| 92 | "--cycle_decay", | 92 | "--cycle_decay", |
| @@ -95,11 +95,6 @@ def parse_args(): | |||
| 95 | help="Learning rate decay per cycle." | 95 | help="Learning rate decay per cycle." |
| 96 | ) | 96 | ) |
| 97 | parser.add_argument( | 97 | parser.add_argument( |
| 98 | "--cycle_constant", | ||
| 99 | action="store_true", | ||
| 100 | help="Use constant LR on cycles > 1." | ||
| 101 | ) | ||
| 102 | parser.add_argument( | ||
| 103 | "--placeholder_tokens", | 98 | "--placeholder_tokens", |
| 104 | type=str, | 99 | type=str, |
| 105 | nargs='*', | 100 | nargs='*', |
| @@ -920,7 +915,6 @@ def main(): | |||
| 920 | annealing_func=args.lr_annealing_func, | 915 | annealing_func=args.lr_annealing_func, |
| 921 | warmup_exp=args.lr_warmup_exp, | 916 | warmup_exp=args.lr_warmup_exp, |
| 922 | annealing_exp=args.lr_annealing_exp, | 917 | annealing_exp=args.lr_annealing_exp, |
| 923 | cycles=args.lr_cycles, | ||
| 924 | end_lr=1e2, | 918 | end_lr=1e2, |
| 925 | mid_point=args.lr_mid_point, | 919 | mid_point=args.lr_mid_point, |
| 926 | ) | 920 | ) |
| @@ -964,20 +958,38 @@ def main(): | |||
| 964 | 958 | ||
| 965 | lora_sample_output_dir = output_dir / lora_project / "samples" | 959 | lora_sample_output_dir = output_dir / lora_project / "samples" |
| 966 | 960 | ||
| 961 | auto_cycles = list(args.auto_cycles) | ||
| 962 | lr_scheduler = args.lr_scheduler | ||
| 963 | lr_warmup_epochs = args.lr_warmup_epochs | ||
| 964 | lr_cycles = args.lr_cycles | ||
| 965 | |||
| 967 | while True: | 966 | while True: |
| 968 | if training_iter >= args.auto_cycles: | 967 | if len(auto_cycles) != 0: |
| 969 | response = input("Run another cycle? [y/n] ") | 968 | response = auto_cycles.pop(0) |
| 970 | if response.lower().strip() == "n": | 969 | else: |
| 971 | break | 970 | response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") |
| 971 | |||
| 972 | if response.lower().strip() == "o": | ||
| 973 | lr_scheduler = "one_cycle" | ||
| 974 | lr_warmup_epochs = args.lr_warmup_epochs | ||
| 975 | lr_cycles = args.lr_cycles | ||
| 976 | if response.lower().strip() == "w": | ||
| 977 | lr_scheduler = "constant" | ||
| 978 | lr_warmup_epochs = num_train_epochs | ||
| 979 | if response.lower().strip() == "c": | ||
| 980 | lr_scheduler = "constant" | ||
| 981 | lr_warmup_epochs = 0 | ||
| 982 | if response.lower().strip() == "d": | ||
| 983 | lr_scheduler = "cosine" | ||
| 984 | lr_warmup_epochs = 0 | ||
| 985 | lr_cycles = 1 | ||
| 986 | elif response.lower().strip() == "s": | ||
| 987 | break | ||
| 972 | 988 | ||
| 973 | print("") | 989 | print("") |
| 974 | print(f"============ LoRA cycle {training_iter + 1} ============") | 990 | print(f"============ LoRA cycle {training_iter + 1} ============") |
| 975 | print("") | 991 | print("") |
| 976 | 992 | ||
| 977 | if args.cycle_constant and training_iter == 1: | ||
| 978 | args.lr_scheduler = "constant" | ||
| 979 | args.lr_warmup_epochs = 0 | ||
| 980 | |||
| 981 | params_to_optimize = [] | 993 | params_to_optimize = [] |
| 982 | 994 | ||
| 983 | if len(args.placeholder_tokens) != 0: | 995 | if len(args.placeholder_tokens) != 0: |
| @@ -1012,12 +1024,13 @@ def main(): | |||
| 1012 | lora_optimizer = create_optimizer(params_to_optimize) | 1024 | lora_optimizer = create_optimizer(params_to_optimize) |
| 1013 | 1025 | ||
| 1014 | lora_lr_scheduler = create_lr_scheduler( | 1026 | lora_lr_scheduler = create_lr_scheduler( |
| 1015 | args.lr_scheduler, | 1027 | lr_scheduler, |
| 1016 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 1028 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| 1017 | optimizer=lora_optimizer, | 1029 | optimizer=lora_optimizer, |
| 1018 | num_training_steps_per_epoch=len(lora_datamodule.train_dataloader), | 1030 | num_training_steps_per_epoch=len(lora_datamodule.train_dataloader), |
| 1019 | train_epochs=num_train_epochs, | 1031 | train_epochs=num_train_epochs, |
| 1020 | warmup_epochs=args.lr_warmup_epochs, | 1032 | cycles=lr_cycles, |
| 1033 | warmup_epochs=lr_warmup_epochs, | ||
| 1021 | ) | 1034 | ) |
| 1022 | 1035 | ||
| 1023 | lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}" | 1036 | lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}" |
| @@ -1031,7 +1044,7 @@ def main(): | |||
| 1031 | num_train_epochs=num_train_epochs, | 1044 | num_train_epochs=num_train_epochs, |
| 1032 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 1045 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| 1033 | global_step_offset=training_iter * num_train_steps, | 1046 | global_step_offset=training_iter * num_train_steps, |
| 1034 | initial_samples=training_iter == 0, | 1047 | cycle=training_iter, |
| 1035 | # -- | 1048 | # -- |
| 1036 | group_labels=group_labels, | 1049 | group_labels=group_labels, |
| 1037 | sample_output_dir=lora_sample_output_dir, | 1050 | sample_output_dir=lora_sample_output_dir, |
diff --git a/train_ti.py b/train_ti.py index c452269..880320f 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -68,9 +68,9 @@ def parse_args(): | |||
| 68 | ) | 68 | ) |
| 69 | parser.add_argument( | 69 | parser.add_argument( |
| 70 | "--auto_cycles", | 70 | "--auto_cycles", |
| 71 | type=int, | 71 | type=str, |
| 72 | default=1, | 72 | default="o", |
| 73 | help="How many cycles to run automatically." | 73 | help="Cycles to run automatically." |
| 74 | ) | 74 | ) |
| 75 | parser.add_argument( | 75 | parser.add_argument( |
| 76 | "--cycle_decay", | 76 | "--cycle_decay", |
| @@ -79,11 +79,6 @@ def parse_args(): | |||
| 79 | help="Learning rate decay per cycle." | 79 | help="Learning rate decay per cycle." |
| 80 | ) | 80 | ) |
| 81 | parser.add_argument( | 81 | parser.add_argument( |
| 82 | "--cycle_constant", | ||
| 83 | action="store_true", | ||
| 84 | help="Use constant LR on cycles > 1." | ||
| 85 | ) | ||
| 86 | parser.add_argument( | ||
| 87 | "--placeholder_tokens", | 82 | "--placeholder_tokens", |
| 88 | type=str, | 83 | type=str, |
| 89 | nargs='*', | 84 | nargs='*', |
| @@ -921,27 +916,45 @@ def main(): | |||
| 921 | 916 | ||
| 922 | sample_output_dir = output_dir / project / "samples" | 917 | sample_output_dir = output_dir / project / "samples" |
| 923 | 918 | ||
| 919 | auto_cycles = list(args.auto_cycles) | ||
| 920 | lr_scheduler = args.lr_scheduler | ||
| 921 | lr_warmup_epochs = args.lr_warmup_epochs | ||
| 922 | lr_cycles = args.lr_cycles | ||
| 923 | |||
| 924 | while True: | 924 | while True: |
| 925 | if training_iter >= args.auto_cycles: | 925 | if len(auto_cycles) != 0: |
| 926 | response = input("Run another cycle? [y/n] ") | 926 | response = auto_cycles.pop(0) |
| 927 | if response.lower().strip() == "n": | 927 | else: |
| 928 | break | 928 | response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") |
| 929 | |||
| 930 | if response.lower().strip() == "o": | ||
| 931 | lr_scheduler = "one_cycle" | ||
| 932 | lr_warmup_epochs = args.lr_warmup_epochs | ||
| 933 | lr_cycles = args.lr_cycles | ||
| 934 | if response.lower().strip() == "w": | ||
| 935 | lr_scheduler = "constant" | ||
| 936 | lr_warmup_epochs = num_train_epochs | ||
| 937 | if response.lower().strip() == "c": | ||
| 938 | lr_scheduler = "constant" | ||
| 939 | lr_warmup_epochs = 0 | ||
| 940 | if response.lower().strip() == "d": | ||
| 941 | lr_scheduler = "cosine" | ||
| 942 | lr_warmup_epochs = 0 | ||
| 943 | lr_cycles = 1 | ||
| 944 | elif response.lower().strip() == "s": | ||
| 945 | break | ||
| 929 | 946 | ||
| 930 | print("") | 947 | print("") |
| 931 | print(f"------------ TI cycle {training_iter + 1} ------------") | 948 | print(f"------------ TI cycle {training_iter + 1} ------------") |
| 932 | print("") | 949 | print("") |
| 933 | 950 | ||
| 934 | if args.cycle_constant and training_iter == 1: | ||
| 935 | args.lr_scheduler = "constant" | ||
| 936 | args.lr_warmup_epochs = 0 | ||
| 937 | |||
| 938 | optimizer = create_optimizer( | 951 | optimizer = create_optimizer( |
| 939 | text_encoder.text_model.embeddings.token_embedding.parameters(), | 952 | text_encoder.text_model.embeddings.token_embedding.parameters(), |
| 940 | lr=learning_rate, | 953 | lr=learning_rate, |
| 941 | ) | 954 | ) |
| 942 | 955 | ||
| 943 | lr_scheduler = get_scheduler( | 956 | lr_scheduler = get_scheduler( |
| 944 | args.lr_scheduler, | 957 | lr_scheduler, |
| 945 | optimizer=optimizer, | 958 | optimizer=optimizer, |
| 946 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | 959 | num_training_steps_per_epoch=len(datamodule.train_dataloader), |
| 947 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 960 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| @@ -950,10 +963,10 @@ def main(): | |||
| 950 | annealing_func=args.lr_annealing_func, | 963 | annealing_func=args.lr_annealing_func, |
| 951 | warmup_exp=args.lr_warmup_exp, | 964 | warmup_exp=args.lr_warmup_exp, |
| 952 | annealing_exp=args.lr_annealing_exp, | 965 | annealing_exp=args.lr_annealing_exp, |
| 953 | cycles=args.lr_cycles, | 966 | cycles=lr_cycles, |
| 954 | end_lr=1e3, | 967 | end_lr=1e3, |
| 955 | train_epochs=num_train_epochs, | 968 | train_epochs=num_train_epochs, |
| 956 | warmup_epochs=args.lr_warmup_epochs, | 969 | warmup_epochs=lr_warmup_epochs, |
| 957 | mid_point=args.lr_mid_point, | 970 | mid_point=args.lr_mid_point, |
| 958 | ) | 971 | ) |
| 959 | 972 | ||
| @@ -966,7 +979,7 @@ def main(): | |||
| 966 | lr_scheduler=lr_scheduler, | 979 | lr_scheduler=lr_scheduler, |
| 967 | num_train_epochs=num_train_epochs, | 980 | num_train_epochs=num_train_epochs, |
| 968 | global_step_offset=training_iter * num_train_steps, | 981 | global_step_offset=training_iter * num_train_steps, |
| 969 | initial_samples=training_iter == 0, | 982 | cycle=training_iter, |
| 970 | # -- | 983 | # -- |
| 971 | group_labels=["emb"], | 984 | group_labels=["emb"], |
| 972 | checkpoint_output_dir=checkpoint_output_dir, | 985 | checkpoint_output_dir=checkpoint_output_dir, |
diff --git a/training/functional.py b/training/functional.py index 2da0f69..ebc40de 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -42,7 +42,7 @@ class TrainingCallbacks(): | |||
| 42 | on_after_optimize: Callable[[Any, dict[str, float]], None] = const() | 42 | on_after_optimize: Callable[[Any, dict[str, float]], None] = const() |
| 43 | on_after_epoch: Callable[[], None] = const() | 43 | on_after_epoch: Callable[[], None] = const() |
| 44 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | 44 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) |
| 45 | on_sample: Callable[[int], None] = const() | 45 | on_sample: Callable[[int, int], None] = const() |
| 46 | on_checkpoint: Callable[[int, str], None] = const() | 46 | on_checkpoint: Callable[[int, str], None] = const() |
| 47 | 47 | ||
| 48 | 48 | ||
| @@ -96,6 +96,7 @@ def save_samples( | |||
| 96 | output_dir: Path, | 96 | output_dir: Path, |
| 97 | seed: int, | 97 | seed: int, |
| 98 | step: int, | 98 | step: int, |
| 99 | cycle: int = 1, | ||
| 99 | batch_size: int = 1, | 100 | batch_size: int = 1, |
| 100 | num_batches: int = 1, | 101 | num_batches: int = 1, |
| 101 | num_steps: int = 20, | 102 | num_steps: int = 20, |
| @@ -125,7 +126,7 @@ def save_samples( | |||
| 125 | 126 | ||
| 126 | for pool, data, gen in datasets: | 127 | for pool, data, gen in datasets: |
| 127 | all_samples = [] | 128 | all_samples = [] |
| 128 | file_path = output_dir / pool / f"step_{step}.jpg" | 129 | file_path = output_dir / pool / f"step_{cycle}_{step}.jpg" |
| 129 | file_path.parent.mkdir(parents=True, exist_ok=True) | 130 | file_path.parent.mkdir(parents=True, exist_ok=True) |
| 130 | 131 | ||
| 131 | batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) | 132 | batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) |
| @@ -455,7 +456,7 @@ def train_loop( | |||
| 455 | sample_frequency: int = 10, | 456 | sample_frequency: int = 10, |
| 456 | checkpoint_frequency: int = 50, | 457 | checkpoint_frequency: int = 50, |
| 457 | milestone_checkpoints: bool = True, | 458 | milestone_checkpoints: bool = True, |
| 458 | initial_samples: bool = True, | 459 | cycle: int = 1, |
| 459 | global_step_offset: int = 0, | 460 | global_step_offset: int = 0, |
| 460 | num_epochs: int = 100, | 461 | num_epochs: int = 100, |
| 461 | gradient_accumulation_steps: int = 1, | 462 | gradient_accumulation_steps: int = 1, |
| @@ -518,12 +519,12 @@ def train_loop( | |||
| 518 | try: | 519 | try: |
| 519 | for epoch in range(num_epochs): | 520 | for epoch in range(num_epochs): |
| 520 | if accelerator.is_main_process: | 521 | if accelerator.is_main_process: |
| 521 | if epoch % sample_frequency == 0 and (initial_samples or epoch != 0): | 522 | if epoch % sample_frequency == 0 and (cycle == 1 or epoch != 0): |
| 522 | local_progress_bar.clear() | 523 | local_progress_bar.clear() |
| 523 | global_progress_bar.clear() | 524 | global_progress_bar.clear() |
| 524 | 525 | ||
| 525 | with on_eval(): | 526 | with on_eval(): |
| 526 | on_sample(global_step) | 527 | on_sample(cycle, global_step) |
| 527 | 528 | ||
| 528 | if epoch % checkpoint_frequency == 0 and epoch != 0: | 529 | if epoch % checkpoint_frequency == 0 and epoch != 0: |
| 529 | local_progress_bar.clear() | 530 | local_progress_bar.clear() |
| @@ -648,7 +649,7 @@ def train_loop( | |||
| 648 | if accelerator.is_main_process: | 649 | if accelerator.is_main_process: |
| 649 | print("Finished!") | 650 | print("Finished!") |
| 650 | with on_eval(): | 651 | with on_eval(): |
| 651 | on_sample(global_step) | 652 | on_sample(cycle, global_step) |
| 652 | on_checkpoint(global_step, "end") | 653 | on_checkpoint(global_step, "end") |
| 653 | 654 | ||
| 654 | except KeyboardInterrupt: | 655 | except KeyboardInterrupt: |
| @@ -680,7 +681,7 @@ def train( | |||
| 680 | sample_frequency: int = 20, | 681 | sample_frequency: int = 20, |
| 681 | checkpoint_frequency: int = 50, | 682 | checkpoint_frequency: int = 50, |
| 682 | milestone_checkpoints: bool = True, | 683 | milestone_checkpoints: bool = True, |
| 683 | initial_samples: bool = True, | 684 | cycle: int = 1, |
| 684 | global_step_offset: int = 0, | 685 | global_step_offset: int = 0, |
| 685 | guidance_scale: float = 0.0, | 686 | guidance_scale: float = 0.0, |
| 686 | prior_loss_weight: float = 1.0, | 687 | prior_loss_weight: float = 1.0, |
| @@ -731,7 +732,7 @@ def train( | |||
| 731 | sample_frequency=sample_frequency, | 732 | sample_frequency=sample_frequency, |
| 732 | checkpoint_frequency=checkpoint_frequency, | 733 | checkpoint_frequency=checkpoint_frequency, |
| 733 | milestone_checkpoints=milestone_checkpoints, | 734 | milestone_checkpoints=milestone_checkpoints, |
| 734 | initial_samples=initial_samples, | 735 | cycle=cycle, |
| 735 | global_step_offset=global_step_offset, | 736 | global_step_offset=global_step_offset, |
| 736 | num_epochs=num_train_epochs, | 737 | num_epochs=num_train_epochs, |
| 737 | gradient_accumulation_steps=gradient_accumulation_steps, | 738 | gradient_accumulation_steps=gradient_accumulation_steps, |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 4ae28b7..e6fcc89 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -148,7 +148,7 @@ def dreambooth_strategy_callbacks( | |||
| 148 | torch.cuda.empty_cache() | 148 | torch.cuda.empty_cache() |
| 149 | 149 | ||
| 150 | @torch.no_grad() | 150 | @torch.no_grad() |
| 151 | def on_sample(step): | 151 | def on_sample(cycle, step): |
| 152 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 152 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
| 153 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 153 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
| 154 | 154 | ||
| @@ -158,7 +158,7 @@ def dreambooth_strategy_callbacks( | |||
| 158 | unet_.to(dtype=weight_dtype) | 158 | unet_.to(dtype=weight_dtype) |
| 159 | text_encoder_.to(dtype=weight_dtype) | 159 | text_encoder_.to(dtype=weight_dtype) |
| 160 | 160 | ||
| 161 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 161 | save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_) |
| 162 | 162 | ||
| 163 | unet_.to(dtype=orig_unet_dtype) | 163 | unet_.to(dtype=orig_unet_dtype) |
| 164 | text_encoder_.to(dtype=orig_text_encoder_dtype) | 164 | text_encoder_.to(dtype=orig_text_encoder_dtype) |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 48236fb..5c3012e 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
| @@ -146,11 +146,11 @@ def lora_strategy_callbacks( | |||
| 146 | torch.cuda.empty_cache() | 146 | torch.cuda.empty_cache() |
| 147 | 147 | ||
| 148 | @torch.no_grad() | 148 | @torch.no_grad() |
| 149 | def on_sample(step): | 149 | def on_sample(cycle, step): |
| 150 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 150 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
| 151 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 151 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
| 152 | 152 | ||
| 153 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 153 | save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_) |
| 154 | 154 | ||
| 155 | del unet_, text_encoder_ | 155 | del unet_, text_encoder_ |
| 156 | 156 | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index f0b84b5..6bbff64 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -104,10 +104,28 @@ def textual_inversion_strategy_callbacks( | |||
| 104 | yield | 104 | yield |
| 105 | 105 | ||
| 106 | @torch.no_grad() | 106 | @torch.no_grad() |
| 107 | def on_before_optimize(epoch: int): | ||
| 108 | if use_emb_decay: | ||
| 109 | params = [ | ||
| 110 | p | ||
| 111 | for p in text_encoder.text_model.embeddings.token_embedding.parameters() | ||
| 112 | if p.grad is not None | ||
| 113 | ] | ||
| 114 | return torch.stack(params) if len(params) != 0 else None | ||
| 115 | |||
| 116 | @torch.no_grad() | ||
| 107 | def on_after_optimize(w, lrs: dict[str, float]): | 117 | def on_after_optimize(w, lrs: dict[str, float]): |
| 108 | if ema_embeddings is not None: | 118 | if ema_embeddings is not None: |
| 109 | ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) | 119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) |
| 110 | 120 | ||
| 121 | if use_emb_decay and w is not None: | ||
| 122 | lr = lrs["emb"] or lrs["0"] | ||
| 123 | lambda_ = emb_decay * lr | ||
| 124 | |||
| 125 | if lambda_ != 0: | ||
| 126 | norm = w[:, :].norm(dim=-1, keepdim=True) | ||
| 127 | w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | ||
| 128 | |||
| 111 | def on_log(): | 129 | def on_log(): |
| 112 | if ema_embeddings is not None: | 130 | if ema_embeddings is not None: |
| 113 | return {"ema_decay": ema_embeddings.decay} | 131 | return {"ema_decay": ema_embeddings.decay} |
| @@ -125,7 +143,7 @@ def textual_inversion_strategy_callbacks( | |||
| 125 | ) | 143 | ) |
| 126 | 144 | ||
| 127 | @torch.no_grad() | 145 | @torch.no_grad() |
| 128 | def on_sample(step): | 146 | def on_sample(cycle, step): |
| 129 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 147 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
| 130 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 148 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
| 131 | 149 | ||
| @@ -135,7 +153,7 @@ def textual_inversion_strategy_callbacks( | |||
| 135 | unet_.to(dtype=weight_dtype) | 153 | unet_.to(dtype=weight_dtype) |
| 136 | text_encoder_.to(dtype=weight_dtype) | 154 | text_encoder_.to(dtype=weight_dtype) |
| 137 | 155 | ||
| 138 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 156 | save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_) |
| 139 | 157 | ||
| 140 | unet_.to(dtype=orig_unet_dtype) | 158 | unet_.to(dtype=orig_unet_dtype) |
| 141 | text_encoder_.to(dtype=orig_text_encoder_dtype) | 159 | text_encoder_.to(dtype=orig_text_encoder_dtype) |
| @@ -148,6 +166,7 @@ def textual_inversion_strategy_callbacks( | |||
| 148 | return TrainingCallbacks( | 166 | return TrainingCallbacks( |
| 149 | on_train=on_train, | 167 | on_train=on_train, |
| 150 | on_eval=on_eval, | 168 | on_eval=on_eval, |
| 169 | on_before_optimize=on_before_optimize, | ||
| 151 | on_after_optimize=on_after_optimize, | 170 | on_after_optimize=on_after_optimize, |
| 152 | on_log=on_log, | 171 | on_log=on_log, |
| 153 | on_checkpoint=on_checkpoint, | 172 | on_checkpoint=on_checkpoint, |
