diff options
| author | Volpeon <git@volpeon.ink> | 2023-06-24 21:00:29 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-24 21:00:29 +0200 |
| commit | 12b9aca96a36dd77a6b2b99bbc1743d87a7ce733 (patch) | |
| tree | b0fcf8ad1d26c40d784ddc154622f6d01ecac082 /models | |
| parent | New loss scaling (diff) | |
| download | textual-inversion-diff-12b9aca96a36dd77a6b2b99bbc1743d87a7ce733.tar.gz textual-inversion-diff-12b9aca96a36dd77a6b2b99bbc1743d87a7ce733.tar.bz2 textual-inversion-diff-12b9aca96a36dd77a6b2b99bbc1743d87a7ce733.zip | |
Update
Diffstat (limited to 'models')
| -rw-r--r-- | models/clip/embeddings.py | 4 | ||||
| -rw-r--r-- | models/sparse.py | 11 |
2 files changed, 10 insertions, 5 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 8c3c6d4..afb7430 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -79,8 +79,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 79 | def save_embed(self, input_ids: list[int], filename: Path): | 79 | def save_embed(self, input_ids: list[int], filename: Path): |
| 80 | save_file({"embed": self.get_embed(input_ids)}, filename) | 80 | save_file({"embed": self.get_embed(input_ids)}, filename) |
| 81 | 81 | ||
| 82 | def persist(self): | 82 | def persist(self, clear=False): |
| 83 | self.token_embedding.persist() | 83 | self.token_embedding.persist(clear) |
| 84 | 84 | ||
| 85 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 85 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
| 86 | if isinstance(input_ids, list): | 86 | if isinstance(input_ids, list): |
diff --git a/models/sparse.py b/models/sparse.py index e5897c9..55c9837 100644 --- a/models/sparse.py +++ b/models/sparse.py | |||
| @@ -89,10 +89,15 @@ class SparseEmbedding(nn.Embedding): | |||
| 89 | 89 | ||
| 90 | return weights | 90 | return weights |
| 91 | 91 | ||
| 92 | def persist(self): | 92 | def persist(self, clear=False): |
| 93 | self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0])) | 93 | self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0])) |
| 94 | self.trainable_ids[:] = -1 | 94 | |
| 95 | self.trainable = nn.ParameterList() | 95 | if clear: |
| 96 | self.trainable_ids[:] = -1 | ||
| 97 | self.trainable = nn.ParameterList() | ||
| 98 | else: | ||
| 99 | for param in self.trainable: | ||
| 100 | param.zero_() | ||
| 96 | 101 | ||
| 97 | def reset_parameters(self): | 102 | def reset_parameters(self): |
| 98 | nn.Embedding.reset_parameters(self) | 103 | nn.Embedding.reset_parameters(self) |
