author     Volpeon <git@volpeon.ink>  2023-04-01 12:35:43 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-01 12:35:43 +0200
commit     01eee0cb24f52ca78761b78917959e1c247eae94 (patch)
tree       914c0d3f5b888a4c344b30a861639c8e3d5259dd /models
parent     Update (diff)
download   textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.tar.gz
           textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.tar.bz2
           textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.zip
Add support for Adafactor, add TI initializer noise
Diffstat (limited to 'models')
-rw-r--r--  models/clip/embeddings.py  10
1 file changed, 9 insertions(+), 1 deletion(-)
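The diffstat is limited to 'models', so the Adafactor support named in the commit subject is not visible in this diff; it belongs to the training code. As a rough sketch only of how an Adafactor optimizer is typically constructed with the transformers library (the variable names and hyperparameter values below are illustrative assumptions, not taken from this repository):

import torch
from transformers.optimization import Adafactor

# Stand-in for the trainable textual-inversion embedding parameters.
params = [torch.nn.Parameter(torch.zeros(768))]

optimizer = Adafactor(
    params,
    lr=1e-3,                 # explicit learning rate (illustrative value)
    scale_parameter=False,   # disable parameter-scaled step sizes
    relative_step=False,     # required when passing a manual lr
    warmup_init=False,
    weight_decay=0.0,
)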
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 4166dc6..9abd1bb 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -52,7 +52,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
-    def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
+    def add_embed(
+        self,
+        token_ids: Union[int, list[int]],
+        initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None,
+        initializer_noise: float = 0.0,
+    ):
         if isinstance(token_ids, int):
             token_ids = [token_ids]
 
@@ -73,6 +78,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             dtype=self.temp_token_embedding.weight.dtype,
         )
 
+        if initializer_noise != 0:
+            initializer += torch.randn_like(initializer) * initializer_noise
+
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
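Taken together, the change lets a newly added textual-inversion token start from an existing embedding and then be perturbed with Gaussian noise, so the new vector does not begin exactly on top of its initializer. A standalone sketch of the same operation, using illustrative names and values rather than this repository's API:

import torch

embedding_dim = 768
token_embedding = torch.nn.Embedding(49408, embedding_dim)  # CLIP-sized vocab, illustrative

source_id = torch.tensor([320])                      # hypothetical id of the initializer token
initializer = token_embedding(source_id).detach()    # copy its embedding as the starting point

initializer_noise = 0.05                             # illustrative noise scale
if initializer_noise != 0:
    # Same operation as the added lines above: element-wise Gaussian noise scaled by the factor.
    initializer = initializer + torch.randn_like(initializer) * initializer_noise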