author     Volpeon <git@volpeon.ink>  2023-01-06 17:34:23 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-06 17:34:23 +0100
commit     b8df3dd5330845ff9f9f9af187a09ef0dbfc1c20
tree       cd56ce0b92c38a31160d28c6665b7c378f7403dd /train_ti.py
parent     Use context manager for EMA, on_train/eval hooks
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 41
1 file changed, 17 insertions(+), 24 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index f622299..9aab00c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -280,7 +280,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=6/7
+        default=7/8
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -464,30 +464,19 @@ class Checkpointer(CheckpointerBase):
     def __init__(
         self,
         weight_dtype,
-        datamodule,
-        accelerator,
-        vae,
-        unet,
-        tokenizer,
-        text_encoder,
-        ema_embeddings,
+        accelerator: Accelerator,
+        vae: AutoencoderKL,
+        unet: UNet2DConditionModel,
+        tokenizer: MultiCLIPTokenizer,
+        text_encoder: CLIPTextModel,
+        ema_embeddings: EMAModel,
         scheduler,
         placeholder_token,
         new_ids,
-        output_dir: Path,
-        sample_image_size,
-        sample_batches,
-        sample_batch_size,
-        seed
+        *args,
+        **kwargs
     ):
-        super().__init__(
-            datamodule=datamodule,
-            output_dir=output_dir,
-            sample_image_size=sample_image_size,
-            seed=seed or torch.random.seed(),
-            sample_batches=sample_batches,
-            sample_batch_size=sample_batch_size
-        )
+        super().__init__(*args, **kwargs)
 
         self.weight_dtype = weight_dtype
         self.accelerator = accelerator
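
This hunk adds type annotations and stops enumerating the sampling options in Checkpointer.__init__: output_dir, sample_image_size, sample_batches, sample_batch_size and seed are now forwarded to CheckpointerBase through *args/**kwargs. A minimal sketch of that forwarding pattern follows; CheckpointerBase's real fields live elsewhere in the repo, and the ones shown here are placeholders:

# Sketch of the *args/**kwargs forwarding introduced by this hunk.
class CheckpointerBase:
    def __init__(self, output_dir, sample_batch_size=1, seed=None):
        self.output_dir = output_dir
        self.sample_batch_size = sample_batch_size
        self.seed = seed

class Checkpointer(CheckpointerBase):
    def __init__(self, weight_dtype, accelerator, *args, **kwargs):
        super().__init__(*args, **kwargs)   # sampling options handled by the base
        self.weight_dtype = weight_dtype
        self.accelerator = accelerator

ckpt = Checkpointer("fp16", accelerator=None,
                    output_dir="output/ti", sample_batch_size=4, seed=42)
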
@@ -829,7 +818,9 @@ def main():
     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
     unet.to(accelerator.device, dtype=weight_dtype)
-    ema_embeddings.to(accelerator.device)
+
+    if args.use_ema:
+        ema_embeddings.to(accelerator.device)
 
     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
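
The device move for the EMA embeddings is now guarded by args.use_ema, which matters when the EMA object is only created with --use_ema enabled. A small, self-contained sketch of that pattern follows; the flag definition is illustrative and not copied from train_ti.py:

import argparse

# When --use_ema is off, the EMA object is assumed never to be created,
# so the device move has to be guarded exactly as in the hunk above.
parser = argparse.ArgumentParser()
parser.add_argument("--use_ema", action="store_true")
args = parser.parse_args([])            # EMA disabled in this example

ema_embeddings = None                   # only built when args.use_ema is set

if args.use_ema:
    ema_embeddings.to("cuda")           # guard keeps this from hitting None
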
@@ -854,13 +845,15 @@ def main():
             tokenizer.train()
             yield
         finally:
-            tokenizer.eval()
+            pass
 
     @contextmanager
     def on_eval():
         try:
+            tokenizer.eval()
+
             ema_context = ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema is not None and eval else nullcontext()
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext()
 
             with ema_context:
                 yield
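
In the new hook layout, on_train no longer flips the tokenizer back to eval mode; on_eval does that itself and, when EMA is enabled, evaluates with the EMA-averaged embeddings via apply_temporary. The old guard also referenced the built-in eval (always truthy) and compared args.use_ema against None, so it effectively never fell back to nullcontext(); the plain `if args.use_ema` fixes that. Below is a hedged sketch of what an apply_temporary-style context manager is assumed to do; it is not the repository's EMAModel:

# Sketch: swap EMA-averaged weights into the live parameters for the duration
# of the block, then restore the originals afterwards.
from contextlib import contextmanager

import torch

class EMASketch:
    def __init__(self, parameters):
        self.shadow = [p.detach().clone() for p in parameters]

    @contextmanager
    def apply_temporary(self, parameters):
        parameters = list(parameters)
        backup = [p.detach().clone() for p in parameters]
        try:
            for param, avg in zip(parameters, self.shadow):
                param.data.copy_(avg)      # evaluate with the averaged weights
            yield
        finally:
            for param, orig in zip(parameters, backup):
                param.data.copy_(orig)     # training weights come back afterwards

embedding = torch.nn.Embedding(4, 8)
ema = EMASketch(embedding.parameters())
with ema.apply_temporary(embedding.parameters()):
    pass  # sampling / validation would run here
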