summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-14 22:03:01 +0100
committerVolpeon <git@volpeon.ink>2023-01-14 22:03:01 +0100
commitfc11c86142915d6c3935d28a3321b3ae91b613ef (patch)
tree5d2c84b1ff32e779db868da1248ed24a97cde3c2 /train_ti.py
parentWIP: Modularization ("free(): invalid pointer" my ass) (diff)
downloadtextual-inversion-diff-fc11c86142915d6c3935d28a3321b3ae91b613ef.tar.gz
textual-inversion-diff-fc11c86142915d6c3935d28a3321b3ae91b613ef.tar.bz2
textual-inversion-diff-fc11c86142915d6c3935d28a3321b3ae91b613ef.zip
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py
index deed84c..a4e2dde 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -512,7 +512,7 @@ class TextualInversionCheckpointer(Checkpointer):
512 checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") 512 checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
513 ) 513 )
514 514
515 @torch.inference_mode() 515 @torch.no_grad()
516 def save_samples(self, step): 516 def save_samples(self, step):
517 ema_context = self.ema_embeddings.apply_temporary( 517 ema_context = self.ema_embeddings.apply_temporary(
518 self.text_encoder.text_model.embeddings.temp_token_embedding.parameters() 518 self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
@@ -808,7 +808,6 @@ def main():
808 optimizer=optimizer, 808 optimizer=optimizer,
809 lr_scheduler=lr_scheduler, 809 lr_scheduler=lr_scheduler,
810 model=text_encoder, 810 model=text_encoder,
811 checkpointer=checkpointer,
812 train_dataloader=train_dataloader, 811 train_dataloader=train_dataloader,
813 val_dataloader=val_dataloader, 812 val_dataloader=val_dataloader,
814 loss_step=loss_step_, 813 loss_step=loss_step_,
@@ -819,7 +818,9 @@ def main():
819 on_log=on_log, 818 on_log=on_log,
820 on_train=on_train, 819 on_train=on_train,
821 on_after_optimize=on_after_optimize, 820 on_after_optimize=on_after_optimize,
822 on_eval=on_eval 821 on_eval=on_eval,
822 on_sample=checkpointer.save_samples,
823 on_checkpoint=checkpointer.checkpoint,
823 ) 824 )
824 825
825 826