-rw-r--r--   models/clip/embeddings.py    2
-rw-r--r--   models/lora.py               8
-rw-r--r--   training/strategy/ti.py     19

3 files changed, 5 insertions, 24 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 60c1b20..840f8ae 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -2,7 +2,6 @@ from typing import Union, Optional
 from pathlib import Path
 
 import torch
-import torch.nn as nn
 
 from safetensors import safe_open
 from safetensors.torch import save_file
@@ -64,6 +63,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
+        self.token_embedding.mark_trainable(token_ids)
         self.token_embedding.weight.data[token_ids] = initializer
 
     def load_embed(self, input_ids: list[int], filename: Path):
diff --git a/models/lora.py b/models/lora.py
index c0f74a6..98d4d2c 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -83,11 +83,11 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         if new_ids.shape[0] == 0:
             return
 
-        n = self.trainable_ids.shape[0]
-        self.trainable_ids[new_ids] = torch.arange(n, n + new_ids.shape[0])
+        n1 = self.lora_A.shape[1]
+        n2 = n1 + new_ids.shape[0]
+        self.trainable_ids[new_ids] = torch.arange(n1, n2)
 
-        lora_A = nn.Parameter(self.weight.new_zeros((self.trainable_ids.shape[0], 0)))
-        lora_A.data[:n] = self.lora_A.data
+        lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2)))
         self.lora_A = lora_A
 
     def reset_parameters(self):
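Note: the hunk above rewrites the id bookkeeping in LoraEmbedding (presumably mark_trainable, which the embeddings.py change now calls before writing initial weights). A minimal, self-contained sketch of that bookkeeping follows; the class, the derivation of new_ids, and the example values are assumptions for illustration, while trainable_ids, lora_A, r, and mark_trainable are names taken from the diff.

import torch
import torch.nn as nn


class TrainableIdEmbedding(nn.Embedding):
    """Illustrative sketch only; not the repo's LoraEmbedding."""

    def __init__(self, num_embeddings: int, embedding_dim: int, r: int = 8):
        super().__init__(num_embeddings, embedding_dim)
        self.r = r
        # Maps a token id to its column in lora_A, or -1 if not yet trainable.
        self.register_buffer(
            "trainable_ids", torch.full((num_embeddings,), -1, dtype=torch.long)
        )
        self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0)))

    def mark_trainable(self, input_ids: torch.Tensor):
        # Only ids without an assigned column need new slots (assumption).
        new_ids = input_ids[self.trainable_ids[input_ids] == -1]
        if new_ids.shape[0] == 0:
            return

        # Existing column count is read from lora_A itself, as in the new code.
        n1 = self.lora_A.shape[1]
        n2 = n1 + new_ids.shape[0]
        self.trainable_ids[new_ids] = torch.arange(n1, n2)

        # Re-allocate lora_A with one zero-initialized column per trainable id.
        self.lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2)))


emb = TrainableIdEmbedding(10, 4, r=2)
emb.mark_trainable(torch.tensor([3, 7]))
print(emb.trainable_ids.tolist())  # ids 3 and 7 now map to columns 0 and 1
print(emb.lora_A.shape)            # torch.Size([2, 2])

Under this scheme trainable_ids maps every vocabulary id to a column of lora_A (or -1), and the change above switches the column count to come from lora_A.shape[1] rather than from trainable_ids.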
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 49236c6..f0b84b5 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,28 +104,10 @@ def textual_inversion_strategy_callbacks(
         yield
 
     @torch.no_grad()
-    def on_before_optimize(epoch: int):
-        if use_emb_decay:
-            params = [
-                p
-                for p in text_encoder.text_model.embeddings.token_embedding.parameters()
-                if p.grad is not None
-            ]
-            return torch.stack(params) if len(params) != 0 else None
-
-    @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters())
 
-        if use_emb_decay and w is not None:
-            lr = lrs["emb"] if "emb" in lrs else lrs["0"]
-            lambda_ = emb_decay * lr
-
-            if lambda_ != 0:
-                norm = w[:, :].norm(dim=-1, keepdim=True)
-                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
-
     def on_log():
         if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
@@ -166,7 +148,6 @@ def textual_inversion_strategy_callbacks(
     return TrainingCallbacks(
         on_train=on_train,
         on_eval=on_eval,
-        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
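Note: the removed on_before_optimize hook collected the gradient-carrying embedding weights, and the removed block in on_after_optimize decayed each embedding's L2 norm toward emb_decay_target, scaled by emb_decay and the current learning rate. For reference, a standalone sketch of that decay update (function name, shapes, and example values are assumptions; the formula is taken from the removed lines):

import torch


def emb_norm_decay(w: torch.Tensor, lr: float, emb_decay: float, emb_decay_target: float):
    """Pull each row's L2 norm toward emb_decay_target, in place."""
    lambda_ = emb_decay * lr
    if lambda_ != 0:
        norm = w.norm(dim=-1, keepdim=True)
        # Move each row along its own direction by lambda_ * (target - |w|).
        w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))


w = torch.randn(4, 768)  # e.g. four trainable token embeddings
emb_norm_decay(w, lr=1e-2, emb_decay=1e-2, emb_decay_target=0.4)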
