author     Volpeon <git@volpeon.ink>  2023-04-15 13:31:24 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-15 13:31:24 +0200
commit     d488f66c78e444d03c4ef8a957b82f8b239379d0 (patch)
tree       864b2fe8d03b0cdfc3437622a0dcd5a1ede60e16
parent     TI via LoRA (diff)
download   textual-inversion-diff-d488f66c78e444d03c4ef8a957b82f8b239379d0.tar.gz
           textual-inversion-diff-d488f66c78e444d03c4ef8a957b82f8b239379d0.tar.bz2
           textual-inversion-diff-d488f66c78e444d03c4ef8a957b82f8b239379d0.zip
Fix
-rw-r--r--  models/clip/embeddings.py |  2
-rw-r--r--  models/lora.py            |  8
-rw-r--r--  training/strategy/ti.py   | 19
3 files changed, 5 insertions(+), 24 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 60c1b20..840f8ae 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -2,7 +2,6 @@ from typing import Union, Optional
 from pathlib import Path
 
 import torch
-import torch.nn as nn
 
 from safetensors import safe_open
 from safetensors.torch import save_file
@@ -64,6 +63,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
+        self.token_embedding.mark_trainable(token_ids)
         self.token_embedding.weight.data[token_ids] = initializer
 
     def load_embed(self, input_ids: list[int], filename: Path):
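
For orientation, a minimal sketch of how the patched method is likely meant to read after this change; the method name add_embed and its signature are assumptions, only the three statements around the hunk come from the diff. The new token ids are now registered as trainable with the LoRA-backed embedding before their rows are overwritten with the initializer.

# Hypothetical reconstruction; name and signature are assumed, not shown in the diff.
def add_embed(self, token_ids: list[int], initializer: torch.Tensor):
    token_ids = torch.tensor(token_ids, dtype=torch.long)

    # New in this commit: mark the rows as trainable first, then seed them.
    self.token_embedding.mark_trainable(token_ids)
    self.token_embedding.weight.data[token_ids] = initializer
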
diff --git a/models/lora.py b/models/lora.py
index c0f74a6..98d4d2c 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -83,11 +83,11 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         if new_ids.shape[0] == 0:
             return
 
-        n = self.trainable_ids.shape[0]
-        self.trainable_ids[new_ids] = torch.arange(n, n + new_ids.shape[0])
+        n1 = self.lora_A.shape[1]
+        n2 = n1 + new_ids.shape[0]
+        self.trainable_ids[new_ids] = torch.arange(n1, n2)
 
-        lora_A = nn.Parameter(self.weight.new_zeros((self.trainable_ids.shape[0], 0)))
-        lora_A.data[:n] = self.lora_A.data
+        lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2)))
         self.lora_A = lora_A
 
     def reset_parameters(self):
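
The lora.py change fixes how new trainable ids are assigned LoRA columns: the next free column index now comes from lora_A.shape[1] rather than trainable_ids.shape[0], and lora_A is re-allocated as (r, n2) so it holds one column per trainable id. A self-contained toy sketch of that bookkeeping, with class name, sizes, and the duplicate-id guard invented for illustration:

import torch
import torch.nn as nn

# Toy stand-in for LoraEmbedding, only to illustrate the corrected mark_trainable.
class TinyLoraEmbedding(nn.Module):
    def __init__(self, vocab_size=10, emb_dim=8, r=4):
        super().__init__()
        self.r = r
        self.weight = nn.Parameter(torch.zeros(vocab_size, emb_dim))
        # token id -> column of lora_A; -1 means the row is frozen
        self.register_buffer("trainable_ids", torch.full((vocab_size,), -1, dtype=torch.long))
        self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0)))  # no trainable columns yet

    def mark_trainable(self, new_ids: torch.Tensor):
        new_ids = new_ids[self.trainable_ids[new_ids] == -1]  # assumption: skip ids already marked
        if new_ids.shape[0] == 0:
            return
        n1 = self.lora_A.shape[1]        # columns already allocated
        n2 = n1 + new_ids.shape[0]       # columns needed after this call
        self.trainable_ids[new_ids] = torch.arange(n1, n2)
        # Re-allocate with one column per trainable id, as in the patch.
        self.lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2)))

emb = TinyLoraEmbedding()
emb.mark_trainable(torch.tensor([3, 7]))
print(emb.trainable_ids.tolist())  # [-1, -1, -1, 0, -1, -1, -1, 1, -1, -1]
print(emb.lora_A.shape)            # torch.Size([4, 2])
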
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 49236c6..f0b84b5 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,28 +104,10 @@ def textual_inversion_strategy_callbacks(
         yield
 
     @torch.no_grad()
-    def on_before_optimize(epoch: int):
-        if use_emb_decay:
-            params = [
-                p
-                for p in text_encoder.text_model.embeddings.token_embedding.parameters()
-                if p.grad is not None
-            ]
-            return torch.stack(params) if len(params) != 0 else None
-
-    @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters())
 
-        if use_emb_decay and w is not None:
-            lr = lrs["emb"] if "emb" in lrs else lrs["0"]
-            lambda_ = emb_decay * lr
-
-            if lambda_ != 0:
-                norm = w[:, :].norm(dim=-1, keepdim=True)
-                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
-
     def on_log():
         if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
@@ -166,7 +148,6 @@ def textual_inversion_strategy_callbacks(
     return TrainingCallbacks(
         on_train=on_train,
         on_eval=on_eval,
-        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
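
For reference, the logic this commit removes from the textual-inversion strategy applied a norm-targeting decay to the stacked trainable embedding rows after each optimizer step. A standalone sketch of that update, using the variable names from the deleted lines and invented toy shapes:

import torch

# Toy values; in the strategy, w held the trainable token-embedding rows and
# lr came from the optimizer's learning rates.
emb_decay, emb_decay_target, lr = 1e-2, 0.4, 1e-3
w = torch.randn(3, 768)

lambda_ = emb_decay * lr
if lambda_ != 0:
    norm = w.norm(dim=-1, keepdim=True)
    # Nudge each row's norm toward emb_decay_target, scaled by lambda_.
    w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
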