-rw-r--r--  train_dreambooth.py              | 19
-rw-r--r--  train_ti.py                      | 42
-rw-r--r--  training/functional.py           |  3
-rw-r--r--  training/strategy/dreambooth.py  |  3
-rw-r--r--  training/strategy/ti.py          |  7
5 files changed, 64 insertions, 10 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 929310b..90ca467 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -839,7 +839,10 @@ def main():
 
         create_optimizer = partial(
             prodigyopt.Prodigy,
+            betas=(args.adam_beta1, args.adam_beta2),
             weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
         )
 
         args.learning_rate_unet = 1.0
@@ -965,9 +968,23 @@ def main():
             },
             {
                 "params": (
-                    param for param in text_encoder.parameters() if param.requires_grad
+                    param
+                    for param in itertools.chain(
+                        text_encoder.text_model.encoder.parameters(),
+                        text_encoder.text_model.final_layer_norm.parameters(),
+                    )
+                    if param.requires_grad
+                ),
+                "lr": learning_rate_text,
+            },
+            {
+                "params": (
+                    param
+                    for param in text_encoder.text_model.embeddings.token_embedding.parameters()
+                    if param.requires_grad
                 ),
                 "lr": learning_rate_text,
+                "weight_decay": 0,
             },
         ]
         group_labels = ["unet", "text"]
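
In the second hunk, the text-encoder parameters are split into two optimizer groups: the transformer body and final layer norm keep the regular settings, while the token embedding gets `"weight_decay": 0` so the learned embedding vectors are not pulled toward zero by decay. A minimal sketch of the same grouping pattern, using plain PyTorch stand-ins rather than the repository's modules:

    import itertools

    import torch

    # Stand-ins for text_encoder.text_model.{encoder, final_layer_norm,
    # embeddings.token_embedding}; only the grouping pattern matters here.
    encoder = torch.nn.Linear(8, 8)
    final_layer_norm = torch.nn.LayerNorm(8)
    token_embedding = torch.nn.Embedding(16, 8)

    param_groups = [
        {
            # Transformer body: the optimizer-level weight decay applies.
            "params": (
                p
                for p in itertools.chain(
                    encoder.parameters(), final_layer_norm.parameters()
                )
                if p.requires_grad
            ),
            "lr": 1e-4,
        },
        {
            # Token embeddings: the per-group key overrides the default.
            "params": (p for p in token_embedding.parameters() if p.requires_grad),
            "lr": 1e-4,
            "weight_decay": 0,
        },
    ]
    optimizer = torch.optim.AdamW(param_groups, weight_decay=1e-2)

Per-group keys override the optimizer-level defaults, which is why only the embedding group escapes decay while both groups share the text learning rate.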
diff --git a/train_ti.py b/train_ti.py
index 1d0cb6f..a7d2924 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -337,7 +337,16 @@ def parse_args():
         "--optimizer",
         type=str,
         default="adan",
-        choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
+        choices=[
+            "adam",
+            "adam8bit",
+            "adan",
+            "lion",
+            "dadam",
+            "dadan",
+            "adafactor",
+            "prodigy",
+        ],
         help="Optimizer to use",
     )
     parser.add_argument(
@@ -819,6 +828,23 @@ def main():
             eps=args.adam_epsilon,
             d0=args.dadaptation_d0,
         )
+    elif args.optimizer == "prodigy":
+        try:
+            import prodigyopt
+        except ImportError:
+            raise ImportError(
+                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`."
+            )
+
+        create_optimizer = partial(
+            prodigyopt.Prodigy,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
+        )
+
+        args.learning_rate = 1.0
     else:
         raise ValueError(f'Unknown --optimizer "{args.optimizer}"')
 
@@ -959,7 +985,11 @@ def main():
         avg_acc_val = AverageMeter()
 
         optimizer = create_optimizer(
-            text_encoder.text_model.embeddings.token_embedding.parameters(),
+            (
+                param
+                for param in text_encoder.text_model.embeddings.token_embedding.parameters()
+                if param.requires_grad
+            ),
             lr=args.learning_rate,
         )
 
@@ -973,9 +1003,11 @@ def main():
 
         if response.lower().strip() == "o":
             if args.learning_rate is not None:
-                learning_rate = args.learning_rate * 2
+                learning_rate = (
+                    args.learning_rate * 2 * (args.cycle_decay**training_iter)
+                )
             else:
-                learning_rate = args.learning_rate
+                learning_rate = args.learning_rate * (args.cycle_decay**training_iter)
 
         if response.lower().strip() == "o":
             lr_scheduler = "one_cycle"
@@ -1045,8 +1077,6 @@ def main():
         )
 
         training_iter += 1
-        if learning_rate is not None:
-            learning_rate *= args.cycle_decay
 
     accelerator.end_training()
 
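Prodigy, like the D-Adaptation optimizers this script already supports, estimates its own step size, which is why the new branch pins `args.learning_rate = 1.0` after construction and reuses the existing `adam_*` and `dadaptation_d0` flags. A rough sketch of the same wiring, assuming `prodigyopt` is installed and substituting illustrative hyperparameter values for the script's flags:

    from functools import partial

    import torch

    try:
        import prodigyopt
    except ImportError:
        raise ImportError(
            "To use Prodigy, please install the prodigyopt library: "
            "`pip install prodigyopt`."
        )

    # Prodigy adapts the effective step size internally; the conventional
    # setting is lr=1.0, and d0 merely seeds the internal estimate.
    create_optimizer = partial(
        prodigyopt.Prodigy,
        betas=(0.9, 0.999),
        weight_decay=1e-2,
        eps=1e-8,
        d0=1e-6,
    )

    token_embedding = torch.nn.Embedding(16, 8)  # stand-in for the trained weights
    optimizer = create_optimizer(
        (p for p in token_embedding.parameters() if p.requires_grad),
        lr=1.0,
    )

The later hunks make two related changes: the optimizer now receives only parameters with `requires_grad` set, and the cumulative `learning_rate *= args.cycle_decay` at the end of each training cycle is replaced by the closed form `args.learning_rate * args.cycle_decay**training_iter`, which derives the same schedule from the cycle index instead of mutating the rate between cycles.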
diff --git a/training/functional.py b/training/functional.py
index 8917eb7..b60afe3 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -786,7 +786,4 @@ def train(
     accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
     accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
 
-    text_encoder.forward = MethodType(text_encoder.forward, text_encoder)
-    unet.forward = MethodType(unet.forward, unet)
-
     accelerator.free_memory()
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 3d1abf7..7e67589 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -154,6 +154,9 @@ def dreambooth_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
+        text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
+        unet_.forward = MethodType(unet_.forward, unet_)
+
         text_encoder_.text_model.embeddings.persist(False)
 
         with ema_context():
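
The `MethodType` rebinding dropped from `training/functional.py` above reappears here, in the callback that actually exports the models, so it now runs only on save paths. The pattern itself: stripping accelerate's precision wrappers can apparently leave `forward` on the instance as a plain, unbound function, and `MethodType(obj.forward, obj)` turns it back into a bound method that receives the model as `self`. A self-contained illustration of the mechanics (the "wrapper" here is simulated; accelerate's real wrappers are more involved):

    from types import MethodType

    import torch

    model = torch.nn.Linear(4, 4)

    # Simulate a wrapper leaving a plain (unbound) function on the instance;
    # at this point model.forward(x) would fail because nothing supplies self.
    model.forward = type(model).forward

    # The diff's pattern: rebind the attribute as a real bound method.
    model.forward = MethodType(model.forward, model)

    print(model.forward(torch.zeros(4)).shape)  # torch.Size([4])

`training/strategy/ti.py` below applies the same fix, but only for the final checkpoint (`postfix == "end"`), since intermediate saves only persist the learned embeddings.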
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 7373982..f37dfb4 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -1,4 +1,5 @@
 from typing import Optional
+from types import MethodType
 from functools import partial
 from contextlib import contextmanager, nullcontext
 from pathlib import Path
@@ -139,6 +140,12 @@ def textual_inversion_strategy_callbacks(
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
+        if postfix == "end":
+            text_encoder_ = accelerator.unwrap_model(
+                text_encoder, keep_fp32_wrapper=False
+            )
+            text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
+
         with ema_context():
             for token, ids in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(