diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 3 |
1 files changed, 1 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py index 69d15ea..3a5cfed 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -439,11 +439,10 @@ class Checkpointer(CheckpointerBase): | |||
439 | for new_token in self.new_tokens: | 439 | for new_token in self.new_tokens: |
440 | text_encoder.text_model.embeddings.save_embed( | 440 | text_encoder.text_model.embeddings.save_embed( |
441 | new_token.multi_ids, | 441 | new_token.multi_ids, |
442 | f"{slugify(new_token.token)}_{step}_{postfix}.bin" | 442 | checkpoints_path.joinpath(f"{slugify(new_token.token)}_{step}_{postfix}.bin") |
443 | ) | 443 | ) |
444 | 444 | ||
445 | del text_encoder | 445 | del text_encoder |
446 | del learned_embeds | ||
447 | 446 | ||
448 | @torch.no_grad() | 447 | @torch.no_grad() |
449 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 448 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): |