diff options
author | Volpeon <git@volpeon.ink> | 2023-01-14 22:03:01 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-14 22:03:01 +0100 |
commit | fc11c86142915d6c3935d28a3321b3ae91b613ef (patch) | |
tree | 5d2c84b1ff32e779db868da1248ed24a97cde3c2 /train_ti.py | |
parent | WIP: Modularization ("free(): invalid pointer" my ass) (diff) | |
download | textual-inversion-diff-fc11c86142915d6c3935d28a3321b3ae91b613ef.tar.gz textual-inversion-diff-fc11c86142915d6c3935d28a3321b3ae91b613ef.tar.bz2 textual-inversion-diff-fc11c86142915d6c3935d28a3321b3ae91b613ef.zip |
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index deed84c..a4e2dde 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -512,7 +512,7 @@ class TextualInversionCheckpointer(Checkpointer): | |||
512 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") | 512 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") |
513 | ) | 513 | ) |
514 | 514 | ||
515 | @torch.inference_mode() | 515 | @torch.no_grad() |
516 | def save_samples(self, step): | 516 | def save_samples(self, step): |
517 | ema_context = self.ema_embeddings.apply_temporary( | 517 | ema_context = self.ema_embeddings.apply_temporary( |
518 | self.text_encoder.text_model.embeddings.temp_token_embedding.parameters() | 518 | self.text_encoder.text_model.embeddings.temp_token_embedding.parameters() |
@@ -808,7 +808,6 @@ def main(): | |||
808 | optimizer=optimizer, | 808 | optimizer=optimizer, |
809 | lr_scheduler=lr_scheduler, | 809 | lr_scheduler=lr_scheduler, |
810 | model=text_encoder, | 810 | model=text_encoder, |
811 | checkpointer=checkpointer, | ||
812 | train_dataloader=train_dataloader, | 811 | train_dataloader=train_dataloader, |
813 | val_dataloader=val_dataloader, | 812 | val_dataloader=val_dataloader, |
814 | loss_step=loss_step_, | 813 | loss_step=loss_step_, |
@@ -819,7 +818,9 @@ def main(): | |||
819 | on_log=on_log, | 818 | on_log=on_log, |
820 | on_train=on_train, | 819 | on_train=on_train, |
821 | on_after_optimize=on_after_optimize, | 820 | on_after_optimize=on_after_optimize, |
822 | on_eval=on_eval | 821 | on_eval=on_eval, |
822 | on_sample=checkpointer.save_samples, | ||
823 | on_checkpoint=checkpointer.checkpoint, | ||
823 | ) | 824 | ) |
824 | 825 | ||
825 | 826 | ||