summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/clip/embeddings.py2
-rw-r--r--models/lora.py8
2 files changed, 5 insertions, 5 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 60c1b20..840f8ae 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -2,7 +2,6 @@ from typing import Union, Optional
2from pathlib import Path 2from pathlib import Path
3 3
4import torch 4import torch
5import torch.nn as nn
6 5
7from safetensors import safe_open 6from safetensors import safe_open
8from safetensors.torch import save_file 7from safetensors.torch import save_file
@@ -64,6 +63,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
64 63
65 token_ids = torch.tensor(token_ids, dtype=torch.long) 64 token_ids = torch.tensor(token_ids, dtype=torch.long)
66 65
66 self.token_embedding.mark_trainable(token_ids)
67 self.token_embedding.weight.data[token_ids] = initializer 67 self.token_embedding.weight.data[token_ids] = initializer
68 68
69 def load_embed(self, input_ids: list[int], filename: Path): 69 def load_embed(self, input_ids: list[int], filename: Path):
diff --git a/models/lora.py b/models/lora.py
index c0f74a6..98d4d2c 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -83,11 +83,11 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
83 if new_ids.shape[0] == 0: 83 if new_ids.shape[0] == 0:
84 return 84 return
85 85
86 n = self.trainable_ids.shape[0] 86 n1 = self.lora_A.shape[1]
87 self.trainable_ids[new_ids] = torch.arange(n, n + new_ids.shape[0]) 87 n2 = n1 + new_ids.shape[0]
88 self.trainable_ids[new_ids] = torch.arange(n1, n2)
88 89
89 lora_A = nn.Parameter(self.weight.new_zeros((self.trainable_ids.shape[0], 0))) 90 lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2)))
90 lora_A.data[:n] = self.lora_A.data
91 self.lora_A = lora_A 91 self.lora_A = lora_A
92 92
93 def reset_parameters(self): 93 def reset_parameters(self):