From d488f66c78e444d03c4ef8a957b82f8b239379d0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 15 Apr 2023 13:31:24 +0200 Subject: Fix --- models/clip/embeddings.py | 2 +- models/lora.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'models') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 60c1b20..840f8ae 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -2,7 +2,6 @@ from typing import Union, Optional from pathlib import Path import torch -import torch.nn as nn from safetensors import safe_open from safetensors.torch import save_file @@ -64,6 +63,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): token_ids = torch.tensor(token_ids, dtype=torch.long) + self.token_embedding.mark_trainable(token_ids) self.token_embedding.weight.data[token_ids] = initializer def load_embed(self, input_ids: list[int], filename: Path): diff --git a/models/lora.py b/models/lora.py index c0f74a6..98d4d2c 100644 --- a/models/lora.py +++ b/models/lora.py @@ -83,11 +83,11 @@ class LoraEmbedding(nn.Embedding, LoraLayer): if new_ids.shape[0] == 0: return - n = self.trainable_ids.shape[0] - self.trainable_ids[new_ids] = torch.arange(n, n + new_ids.shape[0]) + n1 = self.lora_A.shape[1] + n2 = n1 + new_ids.shape[0] + self.trainable_ids[new_ids] = torch.arange(n1, n2) - lora_A = nn.Parameter(self.weight.new_zeros((self.trainable_ids.shape[0], 0))) - lora_A.data[:n] = self.lora_A.data + lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2))) self.lora_A = lora_A def reset_parameters(self): -- cgit v1.2.3-70-g09d2