diff options
author | Volpeon <git@volpeon.ink> | 2023-01-05 10:19:38 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-05 10:19:38 +0100 |
commit | 6c64f769043c8212b1a5778e857af691a828798d (patch) | |
tree | fe4cdf2a4e28e86e31bb7ccd8885c0a42c8632dc /models/clip/embeddings.py | |
parent | Update (diff) | |
download | textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.tar.gz textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.tar.bz2 textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.zip |
Various cleanups
Diffstat (limited to 'models/clip/embeddings.py')
-rw-r--r-- | models/clip/embeddings.py | 5 |
1 files changed, 5 insertions, 0 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 1280ebd..fb639f1 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -53,6 +53,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
54 | 54 | ||
55 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): | 55 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): |
56 | init_ratio = 1.0 | ||
57 | |||
56 | if isinstance(token_ids, int): | 58 | if isinstance(token_ids, int): |
57 | token_ids = [token_ids] | 59 | token_ids = [token_ids] |
58 | 60 | ||
@@ -63,6 +65,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
63 | initializer = [initializer] | 65 | initializer = [initializer] |
64 | 66 | ||
65 | if isinstance(initializer, list): | 67 | if isinstance(initializer, list): |
68 | init_ratio = len(initializer) / len(token_ids) | ||
66 | initializer = (initializer * len(token_ids))[:len(token_ids)] | 69 | initializer = (initializer * len(token_ids))[:len(token_ids)] |
67 | 70 | ||
68 | with torch.no_grad(): | 71 | with torch.no_grad(): |
@@ -76,6 +79,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
76 | dtype=self.temp_token_embedding.weight.dtype, | 79 | dtype=self.temp_token_embedding.weight.dtype, |
77 | ) | 80 | ) |
78 | 81 | ||
82 | return init_ratio | ||
83 | |||
79 | def load_embed(self, input_ids: list[int], filename: Path): | 84 | def load_embed(self, input_ids: list[int], filename: Path): |
80 | with safe_open(filename, framework="pt", device="cpu") as file: | 85 | with safe_open(filename, framework="pt", device="cpu") as file: |
81 | self.add_embed(input_ids, file.get_tensor("embed")) | 86 | self.add_embed(input_ids, file.get_tensor("embed")) |