summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-16 10:37:04 +0200
committerVolpeon <git@volpeon.ink>2023-04-16 10:37:04 +0200
commit15d1a15d1010509c8a2a6dd1ffa47b81e7bc0b78 (patch)
tree70a43a29a3807a380242327dc00f16c6e712dd45 /models
parentUpdate (diff)
downloadtextual-inversion-diff-15d1a15d1010509c8a2a6dd1ffa47b81e7bc0b78.tar.gz
textual-inversion-diff-15d1a15d1010509c8a2a6dd1ffa47b81e7bc0b78.tar.bz2
textual-inversion-diff-15d1a15d1010509c8a2a6dd1ffa47b81e7bc0b78.zip
Fix
Diffstat (limited to 'models')
-rw-r--r--models/clip/embeddings.py2
-rw-r--r--models/lora.py42
2 files changed, 21 insertions, 23 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 4444cf9..d02ccc3 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -64,7 +64,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
64 token_ids = torch.tensor(token_ids, dtype=torch.long) 64 token_ids = torch.tensor(token_ids, dtype=torch.long)
65 65
66 self.token_embedding.mark_trainable(token_ids) 66 self.token_embedding.mark_trainable(token_ids)
67 self.token_embedding.weight.data[token_ids] = initializer 67 self.token_embedding.weight[token_ids].data = initializer
68 68
69 def load_embed(self, input_ids: list[int], filename: Path): 69 def load_embed(self, input_ids: list[int], filename: Path):
70 with safe_open(filename, framework="pt", device="cpu") as file: 70 with safe_open(filename, framework="pt", device="cpu") as file:
diff --git a/models/lora.py b/models/lora.py
index b7fa58f..a8197a5 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -42,7 +42,7 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
42 self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights 42 self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
43 ) 43 )
44 44
45 self.register_buffer('trainable_ids', torch.zeros(num_embeddings, device=self.weight.device, dtype=torch.long)) 45 self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long))
46 self.trainable_ids -= 1 46 self.trainable_ids -= 1
47 47
48 if r > 0: 48 if r > 0:
@@ -76,7 +76,7 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
76 76
77 return new_emb 77 return new_emb
78 78
79 def mark_trainable(self, input_ids): 79 def mark_trainable(self, input_ids: torch.LongTensor):
80 trainable_ids = self.trainable_ids[input_ids] 80 trainable_ids = self.trainable_ids[input_ids]
81 new_ids = input_ids[trainable_ids == -1] 81 new_ids = input_ids[trainable_ids == -1]
82 82
@@ -87,15 +87,7 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
87 n2 = n1 + new_ids.shape[0] 87 n2 = n1 + new_ids.shape[0]
88 self.trainable_ids[new_ids] = torch.arange(n1, n2) 88 self.trainable_ids[new_ids] = torch.arange(n1, n2)
89 for _ in new_ids: 89 for _ in new_ids:
90 self.lora_A.append(self.weight.new_zeros(self.r)) 90 self.lora_A.append(self.weight.new_zeros(self.r, requires_grad=True))
91
92 def persist(self):
93 if self.r > 0:
94 weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
95 if weights is not None:
96 self.weight[mask].data += weights
97 self.trainable_ids[:] = -1
98 self.lora_A = nn.ParameterList()
99 91
100 def get_weights(self, input_ids: torch.Tensor): 92 def get_weights(self, input_ids: torch.Tensor):
101 trainable_ids = self.trainable_ids[input_ids] 93 trainable_ids = self.trainable_ids[input_ids]
@@ -104,16 +96,25 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
104 96
105 elems = [self.lora_A[id] for id in trainable_ids] 97 elems = [self.lora_A[id] for id in trainable_ids]
106 98
107 if len(elems) == 0: 99 if len(elems) != 0:
108 return None, mask 100 weights = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
109 101 else:
110 weights = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling 102 weights = self.weight.new_zeros(self.embedding_dim)
111 103
112 return weights, mask 104 return weights, mask
113 105
106 def persist(self):
107 if self.r > 0:
108 weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
109 if weights is not None:
110 self.weight[mask].data += weights
111 self.trainable_ids[:] = -1
112 self.lora_A = nn.ParameterList()
113
114 def reset_parameters(self): 114 def reset_parameters(self):
115 nn.Embedding.reset_parameters(self) 115 nn.Embedding.reset_parameters(self)
116 if hasattr(self, 'lora_A'): 116 if hasattr(self, 'lora_A'):
117 self.trainable_ids[:] = -1
117 self.lora_A = nn.ParameterList() 118 self.lora_A = nn.ParameterList()
118 nn.init.zeros_(self.lora_B.weight) 119 nn.init.zeros_(self.lora_B.weight)
119 120
@@ -122,8 +123,7 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
122 if self.merge_weights and self.merged: 123 if self.merge_weights and self.merged:
123 if self.r > 0: 124 if self.r > 0:
124 weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0])) 125 weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
125 if weights is not None: 126 self.weight[mask].data -= weights
126 self.weight[mask].data -= weights
127 self.merged = False 127 self.merged = False
128 128
129 def eval(self): 129 def eval(self):
@@ -131,16 +131,14 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
131 if self.merge_weights and not self.merged: 131 if self.merge_weights and not self.merged:
132 if self.r > 0: 132 if self.r > 0:
133 weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0])) 133 weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
134 if weights is not None: 134 self.weight[mask].data += weights
135 self.weight[mask].data += weights
136 self.merged = True 135 self.merged = True
137 136
138 def forward(self, input_ids: torch.Tensor): 137 def forward(self, input_ids: torch.LongTensor):
139 result = nn.Embedding.forward(self, input_ids) 138 result = nn.Embedding.forward(self, input_ids)
140 139
141 if self.r > 0 and not self.merged: 140 if self.r > 0 and not self.merged:
142 weights, mask = self.get_weights(input_ids) 141 weights, mask = self.get_weights(input_ids)
143 if weights is not None: 142 result[mask] += weights
144 result[mask] += weights
145 143
146 return result 144 return result