 train_dreambooth.py             | 19 ++++++++++++++++++-
 train_ti.py                     | 42 ++++++++++++++++++++++++++++++++++++------
 training/functional.py          |  3 ---
 training/strategy/dreambooth.py |  3 +++
 training/strategy/ti.py         |  7 +++++++
 5 files changed, 64 insertions(+), 10 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 929310b..90ca467 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -839,7 +839,10 @@ def main():
 
         create_optimizer = partial(
             prodigyopt.Prodigy,
+            betas=(args.adam_beta1, args.adam_beta2),
             weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
         )
 
         args.learning_rate_unet = 1.0
@@ -965,9 +968,23 @@ def main():
         },
         {
             "params": (
-                param for param in text_encoder.parameters() if param.requires_grad
+                param
+                for param in itertools.chain(
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                )
+                if param.requires_grad
+            ),
+            "lr": learning_rate_text,
+        },
+        {
+            "params": (
+                param
+                for param in text_encoder.text_model.embeddings.token_embedding.parameters()
+                if param.requires_grad
             ),
             "lr": learning_rate_text,
+            "weight_decay": 0,
         },
     ]
     group_labels = ["unet", "text"]
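Taken together, these two hunks wire the Adam-style flags through to Prodigy and split the text encoder into two parameter groups so the token embedding trains without weight decay. A minimal sketch of the same per-group override mechanism, using small stand-in modules rather than the real UNet and text encoder:

import torch
from prodigyopt import Prodigy  # pip install prodigyopt

unet = torch.nn.Linear(8, 8)           # stand-in for the UNet
embedding = torch.nn.Embedding(16, 8)  # stand-in for token_embedding

# Per-group keys override the optimizer-wide defaults, so only the
# embedding group escapes weight decay.
optimizer = Prodigy(
    [
        {"params": unet.parameters()},
        {"params": embedding.parameters(), "weight_decay": 0},
    ],
    lr=1.0,  # Prodigy estimates its own step size; lr stays at 1.0
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-8,
    d0=1e-6,  # initial distance estimate, mirroring args.dadaptation_d0
)

Since Prodigy adapts the step size from its estimated distance to the solution, the script pins the nominal learning rates to 1.0 and lets d0 seed that estimate.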
diff --git a/train_ti.py b/train_ti.py
index 1d0cb6f..a7d2924 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -337,7 +337,16 @@ def parse_args():
         "--optimizer",
         type=str,
         default="adan",
-        choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
+        choices=[
+            "adam",
+            "adam8bit",
+            "adan",
+            "lion",
+            "dadam",
+            "dadan",
+            "adafactor",
+            "prodigy",
+        ],
         help="Optimizer to use",
     )
     parser.add_argument(
@@ -819,6 +828,23 @@ def main():
             eps=args.adam_epsilon,
             d0=args.dadaptation_d0,
         )
+    elif args.optimizer == "prodigy":
+        try:
+            import prodigyopt
+        except ImportError:
+            raise ImportError(
+                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`."
+            )
+
+        create_optimizer = partial(
+            prodigyopt.Prodigy,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
+        )
+
+        args.learning_rate = 1.0
     else:
         raise ValueError(f'Unknown --optimizer "{args.optimizer}"')
 
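The new branch reuses the optional-dependency guard of the other optimizer choices. The same pattern in isolation (the require helper is illustrative, not part of the repo):

import importlib

def require(module: str, pip_name: str):
    # Import an optional dependency, or fail with an actionable message.
    try:
        return importlib.import_module(module)
    except ImportError as e:
        raise ImportError(
            f"To use this optimizer, please install {pip_name}: `pip install {pip_name}`."
        ) from e

prodigyopt = require("prodigyopt", "prodigyopt")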
@@ -959,7 +985,11 @@ def main():
     avg_acc_val = AverageMeter()
 
     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
+        (
+            param
+            for param in text_encoder.text_model.embeddings.token_embedding.parameters()
+            if param.requires_grad
+        ),
         lr=args.learning_rate,
     )
 
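With the generator expression, only embedding parameters that still require gradients reach the optimizer. A toy check of the filter, with a small Embedding standing in for the text model's token_embedding:

import torch

token_embedding = torch.nn.Embedding(4, 2)
token_embedding.weight.requires_grad_(False)  # simulate a frozen embedding

trainable = [p for p in token_embedding.parameters() if p.requires_grad]
assert trainable == []  # nothing would be handed to the optimizer once frozen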
@@ -973,9 +1003,11 @@ def main():
 
         if response.lower().strip() == "o":
             if args.learning_rate is not None:
-                learning_rate = args.learning_rate * 2
+                learning_rate = (
+                    args.learning_rate * 2 * (args.cycle_decay**training_iter)
+                )
             else:
-                learning_rate = args.learning_rate
+                learning_rate = args.learning_rate * (args.cycle_decay**training_iter)
 
         if response.lower().strip() == "o":
             lr_scheduler = "one_cycle"
@@ -1045,8 +1077,6 @@ def main():
         )
 
         training_iter += 1
-        if learning_rate is not None:
-            learning_rate *= args.cycle_decay
 
     accelerator.end_training()
 
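These last two hunks trade the in-place learning_rate *= args.cycle_decay at the end of each cycle for a closed-form args.cycle_decay ** training_iter when the next cycle's rate is computed. Both yield the same geometric schedule; a quick check with illustrative numbers:

base_lr, cycle_decay = 1e-3, 0.9

# Old form: multiply in place after every completed cycle.
lr = base_lr
for _ in range(3):
    lr *= cycle_decay

# New form: compute directly from the cycle counter.
training_iter = 3
assert abs(lr - base_lr * cycle_decay**training_iter) < 1e-15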
diff --git a/training/functional.py b/training/functional.py
index 8917eb7..b60afe3 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -786,7 +786,4 @@ def train(
     accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
     accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
 
-    text_encoder.forward = MethodType(text_encoder.forward, text_encoder)
-    unet.forward = MethodType(unet.forward, unet)
-
     accelerator.free_memory()
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 3d1abf7..7e67589 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -154,6 +154,9 @@ def dreambooth_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
+        text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
+        unet_.forward = MethodType(unet_.forward, unet_)
+
         text_encoder_.text_model.embeddings.persist(False)
 
         with ema_context():
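The MethodType rebinding that train() used to apply unconditionally now runs inside the strategy callbacks, right after unwrapping. Rebinding matters because a plain function stored in an instance's __dict__ shadows normal method lookup and never receives self; MethodType restores a bound method. A toy illustration:

from types import MethodType

class Encoder:
    def __init__(self):
        self.offset = 1

    def forward(self, x):
        return x + self.offset

enc = Encoder()

# A wrapper (such as accelerate's mixed-precision forward wrapper) can leave
# a plain function in the instance dict, which shadows normal method binding:
enc.forward = Encoder.forward
# enc.forward(1) would raise TypeError here, since self is never bound.

# MethodType re-binds the function to this instance, restoring self:
enc.forward = MethodType(Encoder.forward, enc)
assert enc.forward(1) == 2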
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 7373982..f37dfb4 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -1,4 +1,5 @@
 from typing import Optional
+from types import MethodType
 from functools import partial
 from contextlib import contextmanager, nullcontext
 from pathlib import Path
@@ -139,6 +140,12 @@ def textual_inversion_strategy_callbacks(
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
+        if postfix == "end":
+            text_encoder_ = accelerator.unwrap_model(
+                text_encoder, keep_fp32_wrapper=False
+            )
+            text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
+
         with ema_context():
             for token, ids in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
