-rw-r--r--   train_dreambooth.py               19
-rw-r--r--   train_ti.py                       42
-rw-r--r--   training/functional.py             3
-rw-r--r--   training/strategy/dreambooth.py    3
-rw-r--r--   training/strategy/ti.py            7
5 files changed, 64 insertions(+), 10 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 929310b..90ca467 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -839,7 +839,10 @@ def main():
 
         create_optimizer = partial(
             prodigyopt.Prodigy,
+            betas=(args.adam_beta1, args.adam_beta2),
             weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
         )
 
         args.learning_rate_unet = 1.0
@@ -965,9 +968,23 @@ def main():
         },
         {
             "params": (
-                param for param in text_encoder.parameters() if param.requires_grad
+                param
+                for param in itertools.chain(
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                )
+                if param.requires_grad
+            ),
+            "lr": learning_rate_text,
+        },
+        {
+            "params": (
+                param
+                for param in text_encoder.text_model.embeddings.token_embedding.parameters()
+                if param.requires_grad
             ),
             "lr": learning_rate_text,
+            "weight_decay": 0,
         },
     ]
     group_labels = ["unet", "text"]
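
As background for the train_dreambooth.py change above, here is a minimal, self-contained sketch of the same pattern: an optimizer factory built with functools.partial, then fed parameter groups where the token embedding gets weight_decay=0. The Prodigy keyword arguments mirror the diff; the toy text-encoder module and the concrete hyperparameter values are illustrative stand-ins rather than code from this repository, and running it assumes torch plus `pip install prodigyopt`.

import itertools
from functools import partial

import torch
import prodigyopt


class ToyTextModel(torch.nn.Module):
    """Hypothetical stand-in for the CLIP text model used in the scripts."""

    def __init__(self):
        super().__init__()
        self.token_embedding = torch.nn.Embedding(16, 8)
        self.encoder = torch.nn.Linear(8, 8)
        self.final_layer_norm = torch.nn.LayerNorm(8)


model = ToyTextModel()

# Factory in the style of the diff: shared hyperparameters are frozen into the
# partial, parameter groups are supplied at call time.
create_optimizer = partial(
    prodigyopt.Prodigy,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-8,
    d0=1e-6,
)

# Prodigy estimates the step size itself, so the per-group "lr" stays at 1.0;
# the embedding group overrides weight_decay to 0 so learned token vectors are
# not pulled toward zero, matching the intent of the new group in the diff.
optimizer = create_optimizer(
    [
        {
            "params": (
                p
                for p in itertools.chain(
                    model.encoder.parameters(),
                    model.final_layer_norm.parameters(),
                )
                if p.requires_grad
            ),
            "lr": 1.0,
        },
        {
            "params": (
                p for p in model.token_embedding.parameters() if p.requires_grad
            ),
            "lr": 1.0,
            "weight_decay": 0,
        },
    ]
)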
diff --git a/train_ti.py b/train_ti.py
index 1d0cb6f..a7d2924 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -337,7 +337,16 @@ def parse_args():
337 "--optimizer", 337 "--optimizer",
338 type=str, 338 type=str,
339 default="adan", 339 default="adan",
340 choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], 340 choices=[
341 "adam",
342 "adam8bit",
343 "adan",
344 "lion",
345 "dadam",
346 "dadan",
347 "adafactor",
348 "prodigy",
349 ],
341 help="Optimizer to use", 350 help="Optimizer to use",
342 ) 351 )
343 parser.add_argument( 352 parser.add_argument(
@@ -819,6 +828,23 @@ def main():
             eps=args.adam_epsilon,
             d0=args.dadaptation_d0,
         )
+    elif args.optimizer == "prodigy":
+        try:
+            import prodigyopt
+        except ImportError:
+            raise ImportError(
+                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`."
+            )
+
+        create_optimizer = partial(
+            prodigyopt.Prodigy,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
+        )
+
+        args.learning_rate = 1.0
     else:
         raise ValueError(f'Unknown --optimizer "{args.optimizer}"')
 
@@ -959,7 +985,11 @@ def main():
     avg_acc_val = AverageMeter()
 
     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
+        (
+            param
+            for param in text_encoder.text_model.embeddings.token_embedding.parameters()
+            if param.requires_grad
+        ),
         lr=args.learning_rate,
     )
 
@@ -973,9 +1003,11 @@ def main():
 
         if response.lower().strip() == "o":
             if args.learning_rate is not None:
-                learning_rate = args.learning_rate * 2
+                learning_rate = (
+                    args.learning_rate * 2 * (args.cycle_decay**training_iter)
+                )
             else:
-                learning_rate = args.learning_rate
+                learning_rate = args.learning_rate * (args.cycle_decay**training_iter)
 
         if response.lower().strip() == "o":
             lr_scheduler = "one_cycle"
@@ -1045,8 +1077,6 @@ def main():
         )
 
         training_iter += 1
-        if learning_rate is not None:
-            learning_rate *= args.cycle_decay
 
     accelerator.end_training()
 
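The last two train_ti.py hunks replace the in-place `learning_rate *= args.cycle_decay` bookkeeping at the end of each cycle with a direct `args.learning_rate * (args.cycle_decay ** training_iter)` computation when the next cycle starts. A tiny, self-contained check of that equivalence (variable names here are illustrative, not the script's):

# Illustrative only: shows that recomputing the cycle learning rate from the
# base value (as the new train_ti.py code does) matches the old in-place
# `learning_rate *= cycle_decay` bookkeeping it replaces.
base_lr = 1e-3
cycle_decay = 0.9

# Old style: carry mutable state across cycles.
old_lr = base_lr
old_history = []
for _ in range(4):
    old_history.append(old_lr)
    old_lr *= cycle_decay

# New style: derive the rate for cycle `training_iter` directly.
new_history = [base_lr * (cycle_decay**training_iter) for training_iter in range(4)]

assert all(abs(a - b) < 1e-12 for a, b in zip(old_history, new_history))
print(new_history)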
diff --git a/training/functional.py b/training/functional.py
index 8917eb7..b60afe3 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -786,7 +786,4 @@ def train(
     accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
     accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
 
-    text_encoder.forward = MethodType(text_encoder.forward, text_encoder)
-    unet.forward = MethodType(unet.forward, unet)
-
     accelerator.free_memory()
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 3d1abf7..7e67589 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -154,6 +154,9 @@ def dreambooth_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
+        text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
+        unet_.forward = MethodType(unet_.forward, unet_)
+
         text_encoder_.text_model.embeddings.persist(False)
 
         with ema_context():
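
The strategy files now re-bind forward with types.MethodType right after accelerator.unwrap_model, instead of once at the very end of train() in training/functional.py. A minimal sketch of what such a re-binding does, under the assumption that unwrapping can leave forward stored on the instance as a plain, unbound function (the accelerate-specific details are not shown in this diff):

from types import MethodType

import torch


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x + 1


model = TinyModel()

# Simulate a wrapper that stashed the underlying function on the instance:
# the attribute is now a plain function, so Python no longer passes `self`.
model.forward = TinyModel.forward

# Calling model.forward(x) here would fail: the tensor would land in `self`
# and `x` would be missing. Re-binding with MethodType restores the implicit
# `self` argument.
model.forward = MethodType(model.forward, model)

print(model.forward(torch.zeros(1)))  # tensor([1.])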
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 7373982..f37dfb4 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -1,4 +1,5 @@
 from typing import Optional
+from types import MethodType
 from functools import partial
 from contextlib import contextmanager, nullcontext
 from pathlib import Path
@@ -139,6 +140,12 @@ def textual_inversion_strategy_callbacks(
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
+        if postfix == "end":
+            text_encoder_ = accelerator.unwrap_model(
+                text_encoder, keep_fp32_wrapper=False
+            )
+            text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
+
         with ema_context():
             for token, ids in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
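
A condensed, hypothetical view of the ti.py callback change above: the unwrap and re-binding run only for the final "end" checkpoint, so intermediate checkpoints keep the training-time wrappers untouched. The save_embeddings parameter is an illustrative placeholder for the embedding-saving code that follows in the real callback.

from types import MethodType


def make_on_checkpoint(accelerator, text_encoder, save_embeddings):
    def on_checkpoint(step, postfix):
        print(f"Saving checkpoint for step {step}...")

        # Only the final checkpoint needs the unwrapped, re-bound encoder.
        if postfix == "end":
            text_encoder_ = accelerator.unwrap_model(
                text_encoder, keep_fp32_wrapper=False
            )
            text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)

        save_embeddings(step, postfix)

    return on_checkpoint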