summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-31 13:09:04 +0100
committerVolpeon <git@volpeon.ink>2022-12-31 13:09:04 +0100
commit8c068963d4b67c6b894e720288e5863dade8d6e6 (patch)
tree823bf9852244e5adfe6a4f6fe5fcd87e8441e685
parentAdded multi-vector embeddings (diff)
downloadtextual-inversion-diff-8c068963d4b67c6b894e720288e5863dade8d6e6.tar.gz
textual-inversion-diff-8c068963d4b67c6b894e720288e5863dade8d6e6.tar.bz2
textual-inversion-diff-8c068963d4b67c6b894e720288e5863dade8d6e6.zip
Fixes
-rw-r--r--models/clip/embeddings.py2
-rw-r--r--train_ti.py3
2 files changed, 2 insertions, 3 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 7d63ffb..f82873e 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -74,7 +74,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
74 74
75 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): 75 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
76 if isinstance(input_ids, list): 76 if isinstance(input_ids, list):
77 input_ids = torch.tensor(input_ids) 77 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device)
78 78
79 mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device)) 79 mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device))
80 80
diff --git a/train_ti.py b/train_ti.py
index 69d15ea..3a5cfed 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -439,11 +439,10 @@ class Checkpointer(CheckpointerBase):
439 for new_token in self.new_tokens: 439 for new_token in self.new_tokens:
440 text_encoder.text_model.embeddings.save_embed( 440 text_encoder.text_model.embeddings.save_embed(
441 new_token.multi_ids, 441 new_token.multi_ids,
442 f"{slugify(new_token.token)}_{step}_{postfix}.bin" 442 checkpoints_path.joinpath(f"{slugify(new_token.token)}_{step}_{postfix}.bin")
443 ) 443 )
444 444
445 del text_encoder 445 del text_encoder
446 del learned_embeds
447 446
448 @torch.no_grad() 447 @torch.no_grad()
449 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): 448 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):