Diffstat (limited to 'training')
-rw-r--r--  training/strategy/lora.py  37
1 file changed, 36 insertions, 1 deletion
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 209785a..d51a2f3 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -14,6 +14,8 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch
 from peft import get_peft_model_state_dict
 from safetensors.torch import save_file
 
+from slugify import slugify
+
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
@@ -30,6 +32,11 @@ def lora_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
+    placeholder_tokens: list[str],
+    placeholder_token_ids: list[list[int]],
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
     max_grad_norm: float = 1.0,
     sample_batch_size: int = 1,
     sample_num_batches: int = 1,
@@ -77,6 +84,22 @@ def lora_strategy_callbacks(
             max_grad_norm
         )
 
+        if use_emb_decay:
+            return torch.stack([
+                p
+                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                if p.grad is not None
+            ])
+
+    @torch.no_grad()
+    def on_after_optimize(w, lr: float):
+        if use_emb_decay:
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         if postfix != "end":
@@ -87,6 +110,12 @@ def lora_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
+        for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+            text_encoder_.text_model.embeddings.save_embed(
+                ids,
+                checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+            )
+
         lora_config = {}
         state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
         lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
@@ -126,6 +155,7 @@ def lora_strategy_callbacks(
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
@@ -141,7 +171,12 @@ def lora_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+
+    text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True)
+
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
 
 
 lora_strategy = TrainingStrategy(
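
The new `on_after_optimize` hook applies a soft norm decay to the trained embeddings: `on_before_optimize` returns the stack of placeholder-embedding parameters that received gradients, and (assuming the loop in `training/functional.py` forwards that return value as `w`, as the hook signatures suggest) each row is then nudged along its own direction so its L2 norm drifts toward `emb_decay_target`. A minimal standalone sketch of just the decay step; `emb_decay_step` is a hypothetical name, not part of the codebase:

```python
import torch

def emb_decay_step(
    w: torch.Tensor,
    lr: float,
    emb_decay: float = 1e-2,
    emb_decay_target: float = 0.4,
) -> None:
    """In-place norm decay mirroring the on_after_optimize hunk above."""
    lambda_ = emb_decay * lr
    if lambda_ != 0:
        # Per-row L2 norm of each embedding vector.
        norm = w.norm(dim=-1, keepdim=True)
        # Scale each unit direction (w / norm) by how far the norm is from
        # the target: rows above the target shrink, rows below it grow.
        w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))

# Hypothetical usage: two 4-dim embeddings with arbitrary norms.
w = torch.randn(2, 4)
for _ in range(1000):
    emb_decay_step(w, lr=1.0)
print(w.norm(dim=-1))  # both norms end up close to 0.4
```

Because `lambda_ = emb_decay * lr`, the pull toward the target norm scales with the scheduler's current learning rate and vanishes when the learning rate does.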

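The per-token `.bin` files written in `on_checkpoint` are named through `slugify`, which normalizes the placeholder token into a filesystem-safe string. A small sketch with a made-up token:

```python
from slugify import slugify  # python-slugify, as imported by the diff

token, step, postfix = "<cat-toy>", 500, "end"
# Angle brackets are stripped, so the file would land at e.g.
# checkpoint_output_dir / "cat-toy_500_end.bin"
print(f"{slugify(token)}_{step}_{postfix}.bin")
```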