From d488f66c78e444d03c4ef8a957b82f8b239379d0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 15 Apr 2023 13:31:24 +0200 Subject: Fix --- models/clip/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'models/clip') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 60c1b20..840f8ae 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -2,7 +2,6 @@ from typing import Union, Optional from pathlib import Path import torch -import torch.nn as nn from safetensors import safe_open from safetensors.torch import save_file @@ -64,6 +63,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): token_ids = torch.tensor(token_ids, dtype=torch.long) + self.token_embedding.mark_trainable(token_ids) self.token_embedding.weight.data[token_ids] = initializer def load_embed(self, input_ids: list[int], filename: Path): -- cgit v1.2.3-70-g09d2