Diffstat (limited to 'train_ti.py')

 train_ti.py | 41 +++++++++++++++++------------------------
 1 file changed, 17 insertions(+), 24 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index f622299..9aab00c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -280,7 +280,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=6/7
+        default=7/8
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -464,30 +464,19 @@ class Checkpointer(CheckpointerBase):
     def __init__(
         self,
         weight_dtype,
-        datamodule,
-        accelerator,
-        vae,
-        unet,
-        tokenizer,
-        text_encoder,
-        ema_embeddings,
+        accelerator: Accelerator,
+        vae: AutoencoderKL,
+        unet: UNet2DConditionModel,
+        tokenizer: MultiCLIPTokenizer,
+        text_encoder: CLIPTextModel,
+        ema_embeddings: EMAModel,
         scheduler,
         placeholder_token,
         new_ids,
-        output_dir: Path,
-        sample_image_size,
-        sample_batches,
-        sample_batch_size,
-        seed
+        *args,
+        **kwargs
     ):
-        super().__init__(
-            datamodule=datamodule,
-            output_dir=output_dir,
-            sample_image_size=sample_image_size,
-            seed=seed or torch.random.seed(),
-            sample_batches=sample_batches,
-            sample_batch_size=sample_batch_size
-        )
+        super().__init__(*args, **kwargs)

         self.weight_dtype = weight_dtype
         self.accelerator = accelerator
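The Checkpointer constructor gains type annotations and forwards the base class's sampling arguments through *args/**kwargs instead of naming each one. Note one behavioral shift: the old seed fallback (seed or torch.random.seed()) no longer happens in the subclass, so it must now be CheckpointerBase's responsibility. A hypothetical call site under the new signature (the keyword names after the comment are assumed to match what CheckpointerBase accepts):

    checkpointer = Checkpointer(
        weight_dtype=weight_dtype,
        accelerator=accelerator,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        ema_embeddings=ema_embeddings,
        scheduler=checkpoint_scheduler,
        placeholder_token=args.placeholder_token,
        new_ids=new_ids,
        # Everything below is forwarded untouched to CheckpointerBase:
        datamodule=datamodule,
        output_dir=output_dir,
        sample_image_size=args.sample_image_size,
        sample_batches=args.sample_batches,
        sample_batch_size=args.sample_batch_size,
        seed=args.seed,
    )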
@@ -829,7 +818,9 @@ def main():
     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
     unet.to(accelerator.device, dtype=weight_dtype)
-    ema_embeddings.to(accelerator.device)
+
+    if args.use_ema:
+        ema_embeddings.to(accelerator.device)

     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
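Moving ema_embeddings to the device is now guarded by --use_ema, consistent with the rest of the EMA handling: the object is presumably None when EMA is disabled, so the unconditional .to() would crash. A sketch of the assumed setup earlier in main():

    ema_embeddings = None
    if args.use_ema:
        # Construction details elided; what matters is that the object
        # only exists when --use_ema is set, so every use must be guarded.
        ema_embeddings = make_ema_embeddings()  # hypothetical helper

    if args.use_ema:
        ema_embeddings.to(accelerator.device)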
@@ -854,13 +845,15 @@
         tokenizer.train()
         yield
     finally:
-        tokenizer.eval()
+        pass

     @contextmanager
     def on_eval():
         try:
+            tokenizer.eval()
+
             ema_context = ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema is not None and eval else nullcontext()
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext()

             with ema_context:
                 yield
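Two fixes in the train/eval context managers. First, on_train() no longer flips the tokenizer back to eval mode in its finally block; tokenizer.eval() now runs at the start of on_eval(), so train-mode tokenizer behavior persists across training steps. Second, the old EMA condition was effectively always true: for a boolean --use_ema flag, args.use_ema is not None never fails, and the bare name eval is the truthy builtin function, so EMA weights were applied even with the flag off. The corrected condition applies them only when --use_ema is set. A hypothetical sketch of how the loop consumes these managers (training_step/validation_step are illustrative names, not from this script):

    with on_train():
        loss = training_step(batch)        # tokenizer stays in train mode

    with on_eval():
        # tokenizer.eval() plus temporarily-applied EMA weights, if enabled
        val_loss = validation_step(batch)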