From fa42b1656b6d55f2e405ca540519b1ac64df9411 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 15 Apr 2023 13:47:01 +0200 Subject: Fix --- models/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'models') diff --git a/models/lora.py b/models/lora.py index 98d4d2c..89c4b2e 100644 --- a/models/lora.py +++ b/models/lora.py @@ -78,7 +78,7 @@ class LoraEmbedding(nn.Embedding, LoraLayer): def mark_trainable(self, input_ids): trainable_ids = self.trainable_ids[input_ids] - new_ids = trainable_ids[trainable_ids == -1] + new_ids = input_ids[trainable_ids == -1] if new_ids.shape[0] == 0: return -- cgit v1.2.3-70-g09d2