| author | Volpeon <git@volpeon.ink> | 2022-12-31 13:09:04 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-31 13:09:04 +0100 |
| commit | 8c068963d4b67c6b894e720288e5863dade8d6e6 (patch) | |
| tree | 823bf9852244e5adfe6a4f6fe5fcd87e8441e685 | |
| parent | Added multi-vector embeddings (diff) | |
Fixes
-rw-r--r-- | models/clip/embeddings.py | 2 |
-rw-r--r-- | train_ti.py               | 3 |
2 files changed, 2 insertions, 3 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 7d63ffb..f82873e 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -74,7 +74,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
-            input_ids = torch.tensor(input_ids)
+            input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device)
 
         mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device))
 
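The embeddings.py change fixes a device mismatch in `ManagedCLIPTextEmbeddings.get_embed`: when the method receives a plain Python list, `torch.tensor(input_ids)` used to be created on the CPU even when the embedding weights had been moved to the GPU, so the later lookup and `torch.isin` comparison could fail. A minimal sketch of the corrected pattern, using a stand-in module (`EmbedLookup` and its fields are illustrative, not the repository's class):

```python
import torch
import torch.nn as nn


class EmbedLookup(nn.Module):
    def __init__(self, num_tokens: int, dim: int, temp_token_ids: list[int]):
        super().__init__()
        self.token_embedding = nn.Embedding(num_tokens, dim)
        self.temp_token_ids = temp_token_ids

    def get_embed(self, input_ids):
        if isinstance(input_ids, list):
            # Create the tensor on the same device as the embedding weights;
            # a bare torch.tensor(input_ids) would default to the CPU.
            input_ids = torch.tensor(
                input_ids, dtype=torch.long, device=self.token_embedding.weight.device
            )
        # Mark which positions hold temporary (trainable) token ids.
        mask = torch.isin(
            input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device)
        )
        return self.token_embedding(input_ids), mask


lookup = EmbedLookup(100, 8, temp_token_ids=[3, 4])
embeds, mask = lookup.get_embed([1, 2, 3])
print(mask)  # tensor([False, False,  True])
```

The same code works unchanged after `lookup.to("cuda")`, which is the situation the fix targets.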
diff --git a/train_ti.py b/train_ti.py
index 69d15ea..3a5cfed 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -439,11 +439,10 @@ class Checkpointer(CheckpointerBase):
         for new_token in self.new_tokens:
             text_encoder.text_model.embeddings.save_embed(
                 new_token.multi_ids,
-                f"{slugify(new_token.token)}_{step}_{postfix}.bin"
+                checkpoints_path.joinpath(f"{slugify(new_token.token)}_{step}_{postfix}.bin")
             )
 
         del text_encoder
-        del learned_embeds
 
     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
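The train_ti.py change does two things: `save_embed` now receives a path rooted at the checkpoint directory instead of a bare filename, so the .bin files land next to the other checkpoint artifacts rather than in the current working directory, and the leftover `del learned_embeds` is dropped, since that variable no longer exists after the multi-vector embedding change. A rough illustration of the path handling, with placeholder values standing in for `checkpoints_path`, the token, `step`, and `postfix`:

```python
from pathlib import Path

from slugify import slugify  # python-slugify, also used by train_ti.py

# Placeholder values standing in for the Checkpointer's state at this point.
checkpoints_path = Path("output/run-1/checkpoints")
token, step, postfix = "<my-style>", 500, "end"

# Joining onto checkpoints_path keeps the embedding file with the rest of the
# checkpoint; before the fix the bare filename ended up in the working directory.
filename = checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
print(filename)  # output/run-1/checkpoints/my-style_500_end.bin
```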