Diffstat (limited to 'models')

 models/clip/embeddings.py | 2 +-
 models/lora.py            | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 60c1b20..840f8ae 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -2,7 +2,6 @@ from typing import Union, Optional
 from pathlib import Path
 
 import torch
-import torch.nn as nn
 
 from safetensors import safe_open
 from safetensors.torch import save_file
@@ -64,6 +63,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
+        self.token_embedding.mark_trainable(token_ids)
         self.token_embedding.weight.data[token_ids] = initializer
 
     def load_embed(self, input_ids: list[int], filename: Path):
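Context for the embeddings.py hunk: the now-unused torch.nn import is dropped, and the token embedding is told which newly added ids are trainable before their rows are overwritten with the initializer. Below is a minimal sketch of that call order; the PatchedEmbedding stub and the toy sizes are hypothetical, and only the mark_trainable-then-assign order and the attribute names come from the diff (in the repository, token_embedding is presumably the LoraEmbedding from models/lora.py, shown next).

# Minimal sketch, assuming a hypothetical PatchedEmbedding stub and toy sizes.
import torch
import torch.nn as nn


class PatchedEmbedding(nn.Embedding):
    """Stand-in for a token embedding that tracks which ids are trainable."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__(num_embeddings, embedding_dim)
        self.marked: list[int] = []

    def mark_trainable(self, token_ids: torch.Tensor) -> None:
        # The real layer would grow its LoRA matrices here (see models/lora.py
        # below); this stub only records which ids were flagged.
        self.marked.extend(token_ids.tolist())


token_embedding = PatchedEmbedding(100, 16)
token_ids = torch.tensor([98, 99], dtype=torch.long)
initializer = torch.randn(2, 16)

token_embedding.mark_trainable(token_ids)             # call added by this commit
token_embedding.weight.data[token_ids] = initializer  # existing initialization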
diff --git a/models/lora.py b/models/lora.py
index c0f74a6..98d4d2c 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -83,11 +83,11 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         if new_ids.shape[0] == 0:
             return
 
-        n = self.trainable_ids.shape[0]
-        self.trainable_ids[new_ids] = torch.arange(n, n + new_ids.shape[0])
+        n1 = self.lora_A.shape[1]
+        n2 = n1 + new_ids.shape[0]
+        self.trainable_ids[new_ids] = torch.arange(n1, n2)
 
-        lora_A = nn.Parameter(self.weight.new_zeros((self.trainable_ids.shape[0], 0)))
-        lora_A.data[:n] = self.lora_A.data
+        lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2)))
         self.lora_A = lora_A
 
     def reset_parameters(self):
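For context, a self-contained toy reconstruction of the bookkeeping the new lora.py lines perform: trainable_ids maps an embedding id to a column of lora_A, and lora_A is re-allocated with one column per newly marked id. The class below is a sketch, not the repository's implementation; the method name mark_trainable, the rank r, and the use of -1 for "not yet trainable" are assumptions, while the n1/n2 arithmetic and the (r, n2) allocation mirror the diff.

# Toy sketch under the assumptions stated above.
import torch
import torch.nn as nn


class ToyLoraEmbedding(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, r: int = 4):
        super().__init__(num_embeddings, embedding_dim)
        self.r = r
        # embedding id -> column index into lora_A; -1 means "not trainable" (assumed)
        self.register_buffer(
            "trainable_ids", torch.full((num_embeddings,), -1, dtype=torch.long)
        )
        self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0)))

    def mark_trainable(self, new_ids: torch.Tensor) -> None:
        # keep only ids not tracked yet (assumed; the hunk starts after this filtering)
        new_ids = new_ids[self.trainable_ids[new_ids] == -1]
        if new_ids.shape[0] == 0:
            return

        # append one lora_A column per newly trainable id
        n1 = self.lora_A.shape[1]
        n2 = n1 + new_ids.shape[0]
        self.trainable_ids[new_ids] = torch.arange(n1, n2)

        # as in the new code, lora_A is re-allocated as zeros of shape (r, n2);
        # the old `lora_A.data[:n] = self.lora_A.data` copy was dropped, so
        # previously allocated columns start from zero again
        self.lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2)))


emb = ToyLoraEmbedding(100, 16, r=2)
emb.mark_trainable(torch.tensor([98, 99]))
print(emb.trainable_ids[98:])  # tensor([0, 1]): ids 98/99 own lora_A columns 0 and 1
print(emb.lora_A.shape)        # torch.Size([2, 2])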