diff options
-rw-r--r-- | train_lora.py | 73 | ||||
-rw-r--r-- | training/functional.py | 9 | ||||
-rw-r--r-- | training/strategy/dreambooth.py | 17 | ||||
-rw-r--r-- | training/strategy/lora.py | 49 | ||||
-rw-r--r-- | training/strategy/ti.py | 22 |
5 files changed, 104 insertions, 66 deletions
diff --git a/train_lora.py b/train_lora.py index 8dd3c86..fa24cee 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -11,6 +11,7 @@ import torch.utils.checkpoint | |||
11 | from accelerate import Accelerator | 11 | from accelerate import Accelerator |
12 | from accelerate.logging import get_logger | 12 | from accelerate.logging import get_logger |
13 | from accelerate.utils import LoggerType, set_seed | 13 | from accelerate.utils import LoggerType, set_seed |
14 | from peft import LoraConfig, LoraModel | ||
14 | from slugify import slugify | 15 | from slugify import slugify |
15 | 16 | ||
16 | from util.files import load_config, load_embeddings_from_dir | 17 | from util.files import load_config, load_embeddings_from_dir |
@@ -21,6 +22,11 @@ from training.strategy.lora import lora_strategy | |||
21 | from training.optimization import get_scheduler | 22 | from training.optimization import get_scheduler |
22 | from training.util import save_args | 23 | from training.util import save_args |
23 | 24 | ||
25 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py | ||
26 | UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] | ||
27 | TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] | ||
28 | |||
29 | |||
24 | logger = get_logger(__name__) | 30 | logger = get_logger(__name__) |
25 | 31 | ||
26 | 32 | ||
@@ -176,6 +182,54 @@ def parse_args(): | |||
176 | help="Number of updates steps to accumulate before performing a backward/update pass.", | 182 | help="Number of updates steps to accumulate before performing a backward/update pass.", |
177 | ) | 183 | ) |
178 | parser.add_argument( | 184 | parser.add_argument( |
185 | "--lora_r", | ||
186 | type=int, | ||
187 | default=8, | ||
188 | help="Lora rank, only used if use_lora is True" | ||
189 | ) | ||
190 | parser.add_argument( | ||
191 | "--lora_alpha", | ||
192 | type=int, | ||
193 | default=32, | ||
194 | help="Lora alpha, only used if use_lora is True" | ||
195 | ) | ||
196 | parser.add_argument( | ||
197 | "--lora_dropout", | ||
198 | type=float, | ||
199 | default=0.0, | ||
200 | help="Lora dropout, only used if use_lora is True" | ||
201 | ) | ||
202 | parser.add_argument( | ||
203 | "--lora_bias", | ||
204 | type=str, | ||
205 | default="none", | ||
206 | help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora is True", | ||
207 | ) | ||
208 | parser.add_argument( | ||
209 | "--lora_text_encoder_r", | ||
210 | type=int, | ||
211 | default=8, | ||
212 | help="Lora rank for text encoder, only used if `use_lora` and `train_text_encoder` are True", | ||
213 | ) | ||
214 | parser.add_argument( | ||
215 | "--lora_text_encoder_alpha", | ||
216 | type=int, | ||
217 | default=32, | ||
218 | help="Lora alpha for text encoder, only used if `use_lora` and `train_text_encoder` are True", | ||
219 | ) | ||
220 | parser.add_argument( | ||
221 | "--lora_text_encoder_dropout", | ||
222 | type=float, | ||
223 | default=0.0, | ||
224 | help="Lora dropout for text encoder, only used if `use_lora` and `train_text_encoder` are True", | ||
225 | ) | ||
226 | parser.add_argument( | ||
227 | "--lora_text_encoder_bias", | ||
228 | type=str, | ||
229 | default="none", | ||
230 | help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True", | ||
231 | ) | ||
232 | parser.add_argument( | ||
179 | "--find_lr", | 233 | "--find_lr", |
180 | action="store_true", | 234 | action="store_true", |
181 | help="Automatically find a learning rate (no training).", | 235 | help="Automatically find a learning rate (no training).", |
@@ -424,13 +478,30 @@ def main(): | |||
424 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 478 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
425 | args.pretrained_model_name_or_path) | 479 | args.pretrained_model_name_or_path) |
426 | 480 | ||
481 | unet_config = LoraConfig( | ||
482 | r=args.lora_r, | ||
483 | lora_alpha=args.lora_alpha, | ||
484 | target_modules=UNET_TARGET_MODULES, | ||
485 | lora_dropout=args.lora_dropout, | ||
486 | bias=args.lora_bias, | ||
487 | ) | ||
488 | unet = LoraModel(unet_config, unet) | ||
489 | |||
490 | text_encoder_config = LoraConfig( | ||
491 | r=args.lora_text_encoder_r, | ||
492 | lora_alpha=args.lora_text_encoder_alpha, | ||
493 | target_modules=TEXT_ENCODER_TARGET_MODULES, | ||
494 | lora_dropout=args.lora_text_encoder_dropout, | ||
495 | bias=args.lora_text_encoder_bias, | ||
496 | ) | ||
497 | text_encoder = LoraModel(text_encoder_config, text_encoder) | ||
498 | |||
427 | vae.enable_slicing() | 499 | vae.enable_slicing() |
428 | vae.set_use_memory_efficient_attention_xformers(True) | 500 | vae.set_use_memory_efficient_attention_xformers(True) |
429 | unet.enable_xformers_memory_efficient_attention() | 501 | unet.enable_xformers_memory_efficient_attention() |
430 | 502 | ||
431 | if args.gradient_checkpointing: | 503 | if args.gradient_checkpointing: |
432 | unet.enable_gradient_checkpointing() | 504 | unet.enable_gradient_checkpointing() |
433 | text_encoder.gradient_checkpointing_enable() | ||
434 | 505 | ||
435 | if args.embeddings_dir is not None: | 506 | if args.embeddings_dir is not None: |
436 | embeddings_dir = Path(args.embeddings_dir) | 507 | embeddings_dir = Path(args.embeddings_dir) |
diff --git a/training/functional.py b/training/functional.py index a5b339d..ee73ab2 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -34,7 +34,6 @@ def const(result=None): | |||
34 | 34 | ||
35 | @dataclass | 35 | @dataclass |
36 | class TrainingCallbacks(): | 36 | class TrainingCallbacks(): |
37 | on_prepare: Callable[[], None] = const() | ||
38 | on_accum_model: Callable[[], torch.nn.Module] = const(None) | 37 | on_accum_model: Callable[[], torch.nn.Module] = const(None) |
39 | on_log: Callable[[], dict[str, Any]] = const({}) | 38 | on_log: Callable[[], dict[str, Any]] = const({}) |
40 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
@@ -620,10 +619,8 @@ def train( | |||
620 | kwargs.update(extra) | 619 | kwargs.update(extra) |
621 | 620 | ||
622 | vae.to(accelerator.device, dtype=dtype) | 621 | vae.to(accelerator.device, dtype=dtype) |
623 | 622 | vae.requires_grad_(False) | |
624 | for model in (unet, text_encoder, vae): | 623 | vae.eval() |
625 | model.requires_grad_(False) | ||
626 | model.eval() | ||
627 | 624 | ||
628 | callbacks = strategy.callbacks( | 625 | callbacks = strategy.callbacks( |
629 | accelerator=accelerator, | 626 | accelerator=accelerator, |
@@ -636,8 +633,6 @@ def train( | |||
636 | **kwargs, | 633 | **kwargs, |
637 | ) | 634 | ) |
638 | 635 | ||
639 | callbacks.on_prepare() | ||
640 | |||
641 | loss_step_ = partial( | 636 | loss_step_ = partial( |
642 | loss_step, | 637 | loss_step, |
643 | vae, | 638 | vae, |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 28fccff..9808027 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -74,6 +74,7 @@ def dreambooth_strategy_callbacks( | |||
74 | power=ema_power, | 74 | power=ema_power, |
75 | max_value=ema_max_decay, | 75 | max_value=ema_max_decay, |
76 | ) | 76 | ) |
77 | ema_unet.to(accelerator.device) | ||
77 | else: | 78 | else: |
78 | ema_unet = None | 79 | ema_unet = None |
79 | 80 | ||
@@ -86,14 +87,6 @@ def dreambooth_strategy_callbacks( | |||
86 | def on_accum_model(): | 87 | def on_accum_model(): |
87 | return unet | 88 | return unet |
88 | 89 | ||
89 | def on_prepare(): | ||
90 | unet.requires_grad_(True) | ||
91 | text_encoder.text_model.encoder.requires_grad_(True) | ||
92 | text_encoder.text_model.final_layer_norm.requires_grad_(True) | ||
93 | |||
94 | if ema_unet is not None: | ||
95 | ema_unet.to(accelerator.device) | ||
96 | |||
97 | @contextmanager | 90 | @contextmanager |
98 | def on_train(epoch: int): | 91 | def on_train(epoch: int): |
99 | tokenizer.train() | 92 | tokenizer.train() |
@@ -181,7 +174,6 @@ def dreambooth_strategy_callbacks( | |||
181 | torch.cuda.empty_cache() | 174 | torch.cuda.empty_cache() |
182 | 175 | ||
183 | return TrainingCallbacks( | 176 | return TrainingCallbacks( |
184 | on_prepare=on_prepare, | ||
185 | on_accum_model=on_accum_model, | 177 | on_accum_model=on_accum_model, |
186 | on_train=on_train, | 178 | on_train=on_train, |
187 | on_eval=on_eval, | 179 | on_eval=on_eval, |
@@ -203,7 +195,12 @@ def dreambooth_prepare( | |||
203 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 195 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
204 | **kwargs | 196 | **kwargs |
205 | ): | 197 | ): |
206 | return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},) | 198 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( |
199 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) | ||
200 | |||
201 | text_encoder.text_model.embeddings.requires_grad_(False) | ||
202 | |||
203 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} | ||
207 | 204 | ||
208 | 205 | ||
209 | dreambooth_strategy = TrainingStrategy( | 206 | dreambooth_strategy = TrainingStrategy( |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 1c8fad6..3971eae 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -10,18 +10,12 @@ from torch.utils.data import DataLoader | |||
10 | from accelerate import Accelerator | 10 | from accelerate import Accelerator |
11 | from transformers import CLIPTextModel | 11 | from transformers import CLIPTextModel |
12 | from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler | 12 | from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler |
13 | from peft import LoraConfig, LoraModel, get_peft_model_state_dict | 13 | from peft import get_peft_model_state_dict |
14 | from peft.tuners.lora import mark_only_lora_as_trainable | ||
15 | 14 | ||
16 | from models.clip.tokenizer import MultiCLIPTokenizer | 15 | from models.clip.tokenizer import MultiCLIPTokenizer |
17 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples | 16 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples |
18 | 17 | ||
19 | 18 | ||
20 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py | ||
21 | UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] | ||
22 | TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] | ||
23 | |||
24 | |||
25 | def lora_strategy_callbacks( | 19 | def lora_strategy_callbacks( |
26 | accelerator: Accelerator, | 20 | accelerator: Accelerator, |
27 | unet: UNet2DConditionModel, | 21 | unet: UNet2DConditionModel, |
@@ -61,10 +55,6 @@ def lora_strategy_callbacks( | |||
61 | image_size=sample_image_size, | 55 | image_size=sample_image_size, |
62 | ) | 56 | ) |
63 | 57 | ||
64 | def on_prepare(): | ||
65 | mark_only_lora_as_trainable(unet.model, unet.peft_config.bias) | ||
66 | mark_only_lora_as_trainable(text_encoder.model, text_encoder.peft_config.bias) | ||
67 | |||
68 | def on_accum_model(): | 58 | def on_accum_model(): |
69 | return unet | 59 | return unet |
70 | 60 | ||
@@ -93,15 +83,15 @@ def lora_strategy_callbacks( | |||
93 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) | 83 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) |
94 | 84 | ||
95 | lora_config = {} | 85 | lora_config = {} |
96 | state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet)) | 86 | state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_)) |
97 | lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True) | 87 | lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) |
98 | 88 | ||
99 | text_encoder_state_dict = get_peft_model_state_dict( | 89 | text_encoder_state_dict = get_peft_model_state_dict( |
100 | text_encoder, state_dict=accelerator.get_state_dict(text_encoder) | 90 | text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_) |
101 | ) | 91 | ) |
102 | text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} | 92 | text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} |
103 | state_dict.update(text_encoder_state_dict) | 93 | state_dict.update(text_encoder_state_dict) |
104 | lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True) | 94 | lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) |
105 | 95 | ||
106 | accelerator.print(state_dict) | 96 | accelerator.print(state_dict) |
107 | accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt") | 97 | accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt") |
@@ -111,11 +101,16 @@ def lora_strategy_callbacks( | |||
111 | 101 | ||
112 | @torch.no_grad() | 102 | @torch.no_grad() |
113 | def on_sample(step): | 103 | def on_sample(step): |
104 | vae_dtype = vae.dtype | ||
105 | vae.to(dtype=text_encoder.dtype) | ||
106 | |||
114 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 107 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
115 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 108 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
116 | 109 | ||
117 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 110 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) |
118 | 111 | ||
112 | vae.to(dtype=vae_dtype) | ||
113 | |||
119 | del unet_ | 114 | del unet_ |
120 | del text_encoder_ | 115 | del text_encoder_ |
121 | 116 | ||
@@ -123,7 +118,6 @@ def lora_strategy_callbacks( | |||
123 | torch.cuda.empty_cache() | 118 | torch.cuda.empty_cache() |
124 | 119 | ||
125 | return TrainingCallbacks( | 120 | return TrainingCallbacks( |
126 | on_prepare=on_prepare, | ||
127 | on_accum_model=on_accum_model, | 121 | on_accum_model=on_accum_model, |
128 | on_train=on_train, | 122 | on_train=on_train, |
129 | on_eval=on_eval, | 123 | on_eval=on_eval, |
@@ -147,28 +141,7 @@ def lora_prepare( | |||
147 | lora_bias: str = "none", | 141 | lora_bias: str = "none", |
148 | **kwargs | 142 | **kwargs |
149 | ): | 143 | ): |
150 | unet_config = LoraConfig( | 144 | return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},) |
151 | r=lora_rank, | ||
152 | lora_alpha=lora_alpha, | ||
153 | target_modules=UNET_TARGET_MODULES, | ||
154 | lora_dropout=lora_dropout, | ||
155 | bias=lora_bias, | ||
156 | ) | ||
157 | unet = LoraModel(unet_config, unet) | ||
158 | |||
159 | text_encoder_config = LoraConfig( | ||
160 | r=lora_rank, | ||
161 | lora_alpha=lora_alpha, | ||
162 | target_modules=TEXT_ENCODER_TARGET_MODULES, | ||
163 | lora_dropout=lora_dropout, | ||
164 | bias=lora_bias, | ||
165 | ) | ||
166 | text_encoder = LoraModel(text_encoder_config, text_encoder) | ||
167 | |||
168 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | ||
169 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) | ||
170 | |||
171 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} | ||
172 | 145 | ||
173 | 146 | ||
174 | lora_strategy = TrainingStrategy( | 147 | lora_strategy = TrainingStrategy( |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 2038e34..10bc6d7 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -78,6 +78,7 @@ def textual_inversion_strategy_callbacks( | |||
78 | power=ema_power, | 78 | power=ema_power, |
79 | max_value=ema_max_decay, | 79 | max_value=ema_max_decay, |
80 | ) | 80 | ) |
81 | ema_embeddings.to(accelerator.device) | ||
81 | else: | 82 | else: |
82 | ema_embeddings = None | 83 | ema_embeddings = None |
83 | 84 | ||
@@ -92,15 +93,6 @@ def textual_inversion_strategy_callbacks( | |||
92 | def on_accum_model(): | 93 | def on_accum_model(): |
93 | return text_encoder.text_model.embeddings.temp_token_embedding | 94 | return text_encoder.text_model.embeddings.temp_token_embedding |
94 | 95 | ||
95 | def on_prepare(): | ||
96 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True) | ||
97 | |||
98 | if ema_embeddings is not None: | ||
99 | ema_embeddings.to(accelerator.device) | ||
100 | |||
101 | if gradient_checkpointing: | ||
102 | unet.train() | ||
103 | |||
104 | @contextmanager | 96 | @contextmanager |
105 | def on_train(epoch: int): | 97 | def on_train(epoch: int): |
106 | tokenizer.train() | 98 | tokenizer.train() |
@@ -177,7 +169,6 @@ def textual_inversion_strategy_callbacks( | |||
177 | torch.cuda.empty_cache() | 169 | torch.cuda.empty_cache() |
178 | 170 | ||
179 | return TrainingCallbacks( | 171 | return TrainingCallbacks( |
180 | on_prepare=on_prepare, | ||
181 | on_accum_model=on_accum_model, | 172 | on_accum_model=on_accum_model, |
182 | on_train=on_train, | 173 | on_train=on_train, |
183 | on_eval=on_eval, | 174 | on_eval=on_eval, |
@@ -197,6 +188,7 @@ def textual_inversion_prepare( | |||
197 | train_dataloader: DataLoader, | 188 | train_dataloader: DataLoader, |
198 | val_dataloader: Optional[DataLoader], | 189 | val_dataloader: Optional[DataLoader], |
199 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 190 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
191 | gradient_checkpointing: bool = False, | ||
200 | **kwargs | 192 | **kwargs |
201 | ): | 193 | ): |
202 | weight_dtype = torch.float32 | 194 | weight_dtype = torch.float32 |
@@ -207,7 +199,17 @@ def textual_inversion_prepare( | |||
207 | 199 | ||
208 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 200 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( |
209 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler) | 201 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler) |
202 | |||
210 | unet.to(accelerator.device, dtype=weight_dtype) | 203 | unet.to(accelerator.device, dtype=weight_dtype) |
204 | unet.requires_grad_(False) | ||
205 | unet.eval() | ||
206 | if gradient_checkpointing: | ||
207 | unet.train() | ||
208 | |||
209 | text_encoder.text_model.encoder.requires_grad_(False) | ||
210 | text_encoder.text_model.final_layer_norm.requires_grad_(False) | ||
211 | text_encoder.eval() | ||
212 | |||
211 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} | 213 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} |
212 | 214 | ||
213 | 215 | ||