diff options
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/lora.py | 37 |
1 files changed, 36 insertions, 1 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 209785a..d51a2f3 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -14,6 +14,8 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch | |||
14 | from peft import get_peft_model_state_dict | 14 | from peft import get_peft_model_state_dict |
15 | from safetensors.torch import save_file | 15 | from safetensors.torch import save_file |
16 | 16 | ||
17 | from slugify import slugify | ||
18 | |||
17 | from models.clip.tokenizer import MultiCLIPTokenizer | 19 | from models.clip.tokenizer import MultiCLIPTokenizer |
18 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples | 20 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples |
19 | 21 | ||
@@ -30,6 +32,11 @@ def lora_strategy_callbacks( | |||
30 | sample_output_dir: Path, | 32 | sample_output_dir: Path, |
31 | checkpoint_output_dir: Path, | 33 | checkpoint_output_dir: Path, |
32 | seed: int, | 34 | seed: int, |
35 | placeholder_tokens: list[str], | ||
36 | placeholder_token_ids: list[list[int]], | ||
37 | use_emb_decay: bool = False, | ||
38 | emb_decay_target: float = 0.4, | ||
39 | emb_decay: float = 1e-2, | ||
33 | max_grad_norm: float = 1.0, | 40 | max_grad_norm: float = 1.0, |
34 | sample_batch_size: int = 1, | 41 | sample_batch_size: int = 1, |
35 | sample_num_batches: int = 1, | 42 | sample_num_batches: int = 1, |
@@ -77,6 +84,22 @@ def lora_strategy_callbacks( | |||
77 | max_grad_norm | 84 | max_grad_norm |
78 | ) | 85 | ) |
79 | 86 | ||
87 | if use_emb_decay: | ||
88 | return torch.stack([ | ||
89 | p | ||
90 | for p in text_encoder.text_model.embeddings.token_override_embedding.params | ||
91 | if p.grad is not None | ||
92 | ]) | ||
93 | |||
94 | @torch.no_grad() | ||
95 | def on_after_optimize(w, lr: float): | ||
96 | if use_emb_decay: | ||
97 | lambda_ = emb_decay * lr | ||
98 | |||
99 | if lambda_ != 0: | ||
100 | norm = w[:, :].norm(dim=-1, keepdim=True) | ||
101 | w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | ||
102 | |||
80 | @torch.no_grad() | 103 | @torch.no_grad() |
81 | def on_checkpoint(step, postfix): | 104 | def on_checkpoint(step, postfix): |
82 | if postfix != "end": | 105 | if postfix != "end": |
@@ -87,6 +110,12 @@ def lora_strategy_callbacks( | |||
87 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) | 110 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) |
88 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) | 111 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) |
89 | 112 | ||
113 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): | ||
114 | text_encoder_.text_model.embeddings.save_embed( | ||
115 | ids, | ||
116 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" | ||
117 | ) | ||
118 | |||
90 | lora_config = {} | 119 | lora_config = {} |
91 | state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_)) | 120 | state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_)) |
92 | lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) | 121 | lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) |
@@ -126,6 +155,7 @@ def lora_strategy_callbacks( | |||
126 | on_train=on_train, | 155 | on_train=on_train, |
127 | on_eval=on_eval, | 156 | on_eval=on_eval, |
128 | on_before_optimize=on_before_optimize, | 157 | on_before_optimize=on_before_optimize, |
158 | on_after_optimize=on_after_optimize, | ||
129 | on_checkpoint=on_checkpoint, | 159 | on_checkpoint=on_checkpoint, |
130 | on_sample=on_sample, | 160 | on_sample=on_sample, |
131 | ) | 161 | ) |
@@ -141,7 +171,12 @@ def lora_prepare( | |||
141 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 171 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
142 | **kwargs | 172 | **kwargs |
143 | ): | 173 | ): |
144 | return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},) | 174 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( |
175 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) | ||
176 | |||
177 | text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True) | ||
178 | |||
179 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} | ||
145 | 180 | ||
146 | 181 | ||
147 | lora_strategy = TrainingStrategy( | 182 | lora_strategy = TrainingStrategy( |