From 8c068963d4b67c6b894e720288e5863dade8d6e6 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 31 Dec 2022 13:09:04 +0100
Subject: Fixes

---
 models/clip/embeddings.py | 2 +-
 train_ti.py               | 3 +--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 7d63ffb..f82873e 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -74,7 +74,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
-            input_ids = torch.tensor(input_ids)
+            input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device)
 
         mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device))
 
diff --git a/train_ti.py b/train_ti.py
index 69d15ea..3a5cfed 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -439,11 +439,10 @@ class Checkpointer(CheckpointerBase):
         for new_token in self.new_tokens:
             text_encoder.text_model.embeddings.save_embed(
                 new_token.multi_ids,
-                f"{slugify(new_token.token)}_{step}_{postfix}.bin"
+                checkpoints_path.joinpath(f"{slugify(new_token.token)}_{step}_{postfix}.bin")
             )
 
         del text_encoder
-        del learned_embeds
 
     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
--
cgit v1.2.3-54-g00ecf
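
Note (not part of the patch): the embeddings.py hunk guards against a device mismatch when get_embed() receives a plain Python list while the embedding table lives on an accelerator. A minimal sketch of that situation follows; the nn.Embedding is a stand-in for ManagedCLIPTextEmbeddings.token_embedding, and all sizes and ids here are illustrative.

import torch
import torch.nn as nn

embedding = nn.Embedding(49408, 768)
if torch.cuda.is_available():
    embedding = embedding.to("cuda")  # weights may live on the GPU

input_ids = [320, 1125, 267]  # get_embed() can receive a plain Python list

# torch.tensor(input_ids) alone produces a CPU tensor; indexing a GPU-resident
# embedding with it raises a device-mismatch error. Creating the tensor on the
# embedding weight's device avoids that, matching the patched line:
ids = torch.tensor(input_ids, device=embedding.weight.device)
embeds = embedding(ids)
print(embeds.shape)  # torch.Size([3, 768])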