diff options
Diffstat (limited to 'models/clip')
| -rw-r--r-- | models/clip/embeddings.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 60c1b20..840f8ae 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -2,7 +2,6 @@ from typing import Union, Optional | |||
| 2 | from pathlib import Path | 2 | from pathlib import Path |
| 3 | 3 | ||
| 4 | import torch | 4 | import torch |
| 5 | import torch.nn as nn | ||
| 6 | 5 | ||
| 7 | from safetensors import safe_open | 6 | from safetensors import safe_open |
| 8 | from safetensors.torch import save_file | 7 | from safetensors.torch import save_file |
| @@ -64,6 +63,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 64 | 63 | ||
| 65 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 64 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
| 66 | 65 | ||
| 66 | self.token_embedding.mark_trainable(token_ids) | ||
| 67 | self.token_embedding.weight.data[token_ids] = initializer | 67 | self.token_embedding.weight.data[token_ids] = initializer |
| 68 | 68 | ||
| 69 | def load_embed(self, input_ids: list[int], filename: Path): | 69 | def load_embed(self, input_ids: list[int], filename: Path): |
