From d952d467d31786f4a85cc4cb009934cd4ebbba71 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 7 Apr 2023 09:09:46 +0200
Subject: Update

---
 training/strategy/lora.py | 37 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 36 insertions(+), 1 deletion(-)

(limited to 'training/strategy')

diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 209785a..d51a2f3 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -14,6 +14,8 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch
 from peft import get_peft_model_state_dict
 from safetensors.torch import save_file

+from slugify import slugify
+
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples

@@ -30,6 +32,11 @@ def lora_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
+    placeholder_tokens: list[str],
+    placeholder_token_ids: list[list[int]],
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
     max_grad_norm: float = 1.0,
     sample_batch_size: int = 1,
     sample_num_batches: int = 1,
@@ -77,6 +84,22 @@
             max_grad_norm
         )

+        if use_emb_decay:
+            return torch.stack([
+                p
+                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                if p.grad is not None
+            ])
+
+    @torch.no_grad()
+    def on_after_optimize(w, lr: float):
+        if use_emb_decay:
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         if postfix != "end":
@@ -87,6 +110,12 @@
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)

+        for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+            text_encoder_.text_model.embeddings.save_embed(
+                ids,
+                checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+            )
+
         lora_config = {}
         state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
         lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
@@ -126,6 +155,7 @@
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
@@ -141,7 +171,12 @@
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+
+    text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True)
+
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}


 lora_strategy = TrainingStrategy(
-- 
cgit v1.2.3-70-g09d2
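
Note on the embedding decay added in on_after_optimize: each optimizer step, the hook rescales the stacked placeholder embeddings so their L2 norms drift toward emb_decay_target, at a rate scaled by emb_decay and the current learning rate. A minimal standalone sketch of the same arithmetic, using a synthetic tensor and the default parameter values from the signature above (the shape and lr are stand-ins, not values from the patch):

    import torch

    emb_decay = 1e-2        # default from lora_strategy_callbacks
    emb_decay_target = 0.4  # target L2 norm per embedding vector
    lr = 1e-4               # stand-in learning rate

    # Stand-in for the stacked placeholder embeddings that
    # on_before_optimize returns (num_embeddings x embedding_dim).
    w = torch.randn(8, 768)

    lambda_ = emb_decay * lr
    if lambda_ != 0:
        norm = w.norm(dim=-1, keepdim=True)
        # (w / norm) is each vector's unit direction and
        # (emb_decay_target - norm) its signed norm error, so vectors
        # longer than the target shrink and shorter ones grow.
        w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))

    print(w.norm(dim=-1))  # norms nudged slightly toward 0.4

The norm.clamp_min(1e-12) guard keeps the division stable if an embedding ever collapses to (near) zero norm.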
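Note on the checkpoint naming in on_checkpoint: the new slugify import suggests the python-slugify package, which turns a placeholder token into a filesystem-safe file stem. Assuming that package, a token such as "<my token>" at step 1000 with postfix "end" would be written as my-token_1000_end.bin (the token value here is purely illustrative):

    from slugify import slugify

    token, step, postfix = "<my token>", 1000, "end"
    print(f"{slugify(token)}_{step}_{postfix}.bin")  # -> my-token_1000_end.bin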
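Note on the callback pairing: training/functional.py is not part of this diff, so the exact call sites are an assumption. The contract implied by the patch is that whatever on_before_optimize returns (the stacked embedding parameters that received gradients, or None when use_emb_decay is off) is handed back to on_after_optimize after the optimizer step. A hypothetical wiring of one training step, with every name below a stand-in rather than the repo's actual code:

    import torch

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    def on_before_optimize(lr: float):
        # In the patch: clip grads, then return the stacked embeddings (or None).
        return torch.randn(2, 4)

    @torch.no_grad()
    def on_after_optimize(w, lr: float):
        pass  # in the patch: decay embedding norms toward emb_decay_target

    lr = optimizer.param_groups[0]["lr"]
    w = on_before_optimize(lr)   # runs just before the optimizer step
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    if w is not None:
        on_after_optimize(w, lr) # receives on_before_optimize's return value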