diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-14 22:03:01 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-14 22:03:01 +0100 |
| commit | fc11c86142915d6c3935d28a3321b3ae91b613ef (patch) | |
| tree | 5d2c84b1ff32e779db868da1248ed24a97cde3c2 /trainer | |
| parent | WIP: Modularization ("free(): invalid pointer" my ass) (diff) | |
| download | textual-inversion-diff-fc11c86142915d6c3935d28a3321b3ae91b613ef.tar.gz textual-inversion-diff-fc11c86142915d6c3935d28a3321b3ae91b613ef.tar.bz2 textual-inversion-diff-fc11c86142915d6c3935d28a3321b3ae91b613ef.zip | |
Update
Diffstat (limited to 'trainer')
| -rw-r--r-- | trainer/base.py | 2 | ||||
| -rw-r--r-- | trainer/ti.py | 2 |
2 files changed, 2 insertions, 2 deletions
diff --git a/trainer/base.py b/trainer/base.py index e700dd6..1f85e71 100644 --- a/trainer/base.py +++ b/trainer/base.py | |||
| @@ -74,7 +74,7 @@ class Checkpointer(): | |||
| 74 | def checkpoint(self, step: int, postfix: str): | 74 | def checkpoint(self, step: int, postfix: str): |
| 75 | pass | 75 | pass |
| 76 | 76 | ||
| 77 | @torch.inference_mode() | 77 | @torch.no_grad() |
| 78 | def save_samples(self, step: int): | 78 | def save_samples(self, step: int): |
| 79 | print(f"Saving samples for step {step}...") | 79 | print(f"Saving samples for step {step}...") |
| 80 | 80 | ||
diff --git a/trainer/ti.py b/trainer/ti.py index 15cf747..388acd3 100644 --- a/trainer/ti.py +++ b/trainer/ti.py | |||
| @@ -42,7 +42,7 @@ class TextualInversionCheckpointer(Checkpointer): | |||
| 42 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") | 42 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") |
| 43 | ) | 43 | ) |
| 44 | 44 | ||
| 45 | @torch.inference_mode() | 45 | @torch.no_grad() |
| 46 | def save_samples(self, step): | 46 | def save_samples(self, step): |
| 47 | ema_context = self.ema_embeddings.apply_temporary( | 47 | ema_context = self.ema_embeddings.apply_temporary( |
| 48 | self.text_encoder.text_model.embeddings.temp_token_embedding.parameters() | 48 | self.text_encoder.text_model.embeddings.temp_token_embedding.parameters() |
