author     Volpeon <git@volpeon.ink>  2023-04-07 09:09:46 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-07 09:09:46 +0200
commit     d952d467d31786f4a85cc4cb009934cd4ebbba71 (patch)
tree       68c3f597a86ef3b98734d80cc783aa1f42fe1a41  /training
parent     Update (diff)
download   textual-inversion-diff-d952d467d31786f4a85cc4cb009934cd4ebbba71.tar.gz
           textual-inversion-diff-d952d467d31786f4a85cc4cb009934cd4ebbba71.tar.bz2
           textual-inversion-diff-d952d467d31786f4a85cc4cb009934cd4ebbba71.zip
Update
Diffstat (limited to 'training')
-rw-r--r--  training/strategy/lora.py  37
1 file changed, 36 insertions(+), 1 deletion(-)
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 209785a..d51a2f3 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -14,6 +14,8 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch
 from peft import get_peft_model_state_dict
 from safetensors.torch import save_file
 
+from slugify import slugify
+
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
@@ -30,6 +32,11 @@ def lora_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
+    placeholder_tokens: list[str],
+    placeholder_token_ids: list[list[int]],
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
     max_grad_norm: float = 1.0,
     sample_batch_size: int = 1,
     sample_num_batches: int = 1,
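
The new keyword arguments thread the textual-inversion placeholders and the embedding-decay settings into the LoRA callbacks. A minimal sketch of how a caller might supply them; the token strings, ids, and literal values below are illustrative assumptions, not taken from this repository:

    strategy_kwargs = dict(
        placeholder_tokens=["<token-a>", "<token-b>"],    # learned placeholder strings (assumed examples)
        placeholder_token_ids=[[49408], [49409, 49410]],  # one id list per placeholder token
        use_emb_decay=True,      # enables the norm decay applied after each optimizer step
        emb_decay_target=0.4,    # target L2 norm for each placeholder embedding vector
        emb_decay=1e-2,          # decay coefficient, further scaled by the learning rate
    )
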
@@ -77,6 +84,22 @@ def lora_strategy_callbacks(
             max_grad_norm
         )
 
+        if use_emb_decay:
+            return torch.stack([
+                p
+                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                if p.grad is not None
+            ])
+
+    @torch.no_grad()
+    def on_after_optimize(w, lr: float):
+        if use_emb_decay:
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         if postfix != "end":
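
The added on_after_optimize hook pulls each placeholder embedding back toward a target norm after every optimizer step, where `w` is the tensor returned by on_before_optimize (the stacked placeholder embeddings). A self-contained sketch of that update rule on a dummy weight matrix; the tensor shape, learning rate, and values are illustrative assumptions:

    import torch

    emb_decay, emb_decay_target, lr = 1e-2, 0.4, 1e-4

    w = torch.randn(3, 768)              # stand-in for the stacked placeholder embeddings
    lambda_ = emb_decay * lr             # decay strength scales with the current learning rate
    if lambda_ != 0:
        norm = w.norm(dim=-1, keepdim=True)   # per-row L2 norm
        # Move each row along its own direction so its norm drifts toward the target:
        # rows longer than emb_decay_target shrink, shorter rows grow.
        w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
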
@@ -87,6 +110,12 @@ def lora_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
+        for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+            text_encoder_.text_model.embeddings.save_embed(
+                ids,
+                checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+            )
+
         lora_config = {}
         state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
         lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
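
Each placeholder embedding is now written to its own checkpoint file, with slugify turning the raw token string into a filesystem-safe name. A small sketch of the resulting path pattern; the directory, token, and step are illustrative assumptions:

    from pathlib import Path
    from slugify import slugify

    checkpoint_output_dir = Path("checkpoints")
    token, step, postfix = "<my token>", 1000, "end"

    path = checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
    print(path)  # checkpoints/my-token_1000_end.bin
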
@@ -126,6 +155,7 @@ def lora_strategy_callbacks(
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
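
Registering on_after_optimize alongside on_before_optimize implies a handshake in the shared training loop: whatever on_before_optimize returns (here, the stacked placeholder embeddings) is handed to on_after_optimize once the optimizer has stepped. A hedged, self-contained sketch of that assumed control flow; the actual loop lives in training.functional and may differ:

    import torch

    def on_before_optimize():
        # stand-in: returns the stacked placeholder embeddings when use_emb_decay is set
        return torch.randn(2, 768)

    def on_after_optimize(w, lr: float):
        # stand-in for the norm-decay step shown above
        print("decay on", tuple(w.shape), "with lr", lr)

    w = on_before_optimize()   # gradient clipping would also happen here
    # optimizer.step() / zero_grad() would run between the two callbacks
    if w is not None:
        on_after_optimize(w, lr=1e-4)
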
@@ -141,7 +171,12 @@ def lora_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+
+    text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True)
+
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
 
 
 lora_strategy = TrainingStrategy(
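
lora_prepare now unpacks the objects returned by accelerator.prepare and explicitly re-enables gradients on the token override embedding parameters, rather than returning the prepared tuple directly. Assuming that attribute behaves like an nn.ParameterList (it is iterated as a collection of parameters earlier in the diff), requires_grad_ flips the flag on every parameter it holds; a minimal sketch with a stand-in list, where names and shapes are illustrative assumptions:

    import torch
    from torch import nn

    params = nn.ParameterList(
        [nn.Parameter(torch.randn(768), requires_grad=False) for _ in range(2)]
    )
    print([p.requires_grad for p in params])  # [False, False]

    params.requires_grad_(True)   # nn.Module.requires_grad_ applies to every held parameter
    print([p.requires_grad for p in params])  # [True, True]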