diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-14 22:03:01 +0100 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-14 22:03:01 +0100 | 
| commit | fc11c86142915d6c3935d28a3321b3ae91b613ef (patch) | |
| tree | 5d2c84b1ff32e779db868da1248ed24a97cde3c2 /train_ti.py | |
| parent | WIP: Modularization ("free(): invalid pointer" my ass) (diff) | |
| download | textual-inversion-diff-fc11c86142915d6c3935d28a3321b3ae91b613ef.tar.gz textual-inversion-diff-fc11c86142915d6c3935d28a3321b3ae91b613ef.tar.bz2 textual-inversion-diff-fc11c86142915d6c3935d28a3321b3ae91b613ef.zip | |
Update
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 7 | 
1 files changed, 4 insertions, 3 deletions
| diff --git a/train_ti.py b/train_ti.py index deed84c..a4e2dde 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -512,7 +512,7 @@ class TextualInversionCheckpointer(Checkpointer): | |||
| 512 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") | 512 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") | 
| 513 | ) | 513 | ) | 
| 514 | 514 | ||
| 515 | @torch.inference_mode() | 515 | @torch.no_grad() | 
| 516 | def save_samples(self, step): | 516 | def save_samples(self, step): | 
| 517 | ema_context = self.ema_embeddings.apply_temporary( | 517 | ema_context = self.ema_embeddings.apply_temporary( | 
| 518 | self.text_encoder.text_model.embeddings.temp_token_embedding.parameters() | 518 | self.text_encoder.text_model.embeddings.temp_token_embedding.parameters() | 
| @@ -808,7 +808,6 @@ def main(): | |||
| 808 | optimizer=optimizer, | 808 | optimizer=optimizer, | 
| 809 | lr_scheduler=lr_scheduler, | 809 | lr_scheduler=lr_scheduler, | 
| 810 | model=text_encoder, | 810 | model=text_encoder, | 
| 811 | checkpointer=checkpointer, | ||
| 812 | train_dataloader=train_dataloader, | 811 | train_dataloader=train_dataloader, | 
| 813 | val_dataloader=val_dataloader, | 812 | val_dataloader=val_dataloader, | 
| 814 | loss_step=loss_step_, | 813 | loss_step=loss_step_, | 
| @@ -819,7 +818,9 @@ def main(): | |||
| 819 | on_log=on_log, | 818 | on_log=on_log, | 
| 820 | on_train=on_train, | 819 | on_train=on_train, | 
| 821 | on_after_optimize=on_after_optimize, | 820 | on_after_optimize=on_after_optimize, | 
| 822 | on_eval=on_eval | 821 | on_eval=on_eval, | 
| 822 | on_sample=checkpointer.save_samples, | ||
| 823 | on_checkpoint=checkpointer.checkpoint, | ||
| 823 | ) | 824 | ) | 
| 824 | 825 | ||
| 825 | 826 | ||
