From b8df3dd5330845ff9f9f9af187a09ef0dbfc1c20 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 6 Jan 2023 17:34:23 +0100
Subject: Update

---
 train_ti.py | 41 +++++++++++++++++------------------------
 1 file changed, 17 insertions(+), 24 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index f622299..9aab00c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -280,7 +280,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=6/7
+        default=7/8
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -464,30 +464,19 @@ class Checkpointer(CheckpointerBase):
     def __init__(
         self,
         weight_dtype,
-        datamodule,
-        accelerator,
-        vae,
-        unet,
-        tokenizer,
-        text_encoder,
-        ema_embeddings,
+        accelerator: Accelerator,
+        vae: AutoencoderKL,
+        unet: UNet2DConditionModel,
+        tokenizer: MultiCLIPTokenizer,
+        text_encoder: CLIPTextModel,
+        ema_embeddings: EMAModel,
         scheduler,
         placeholder_token,
         new_ids,
-        output_dir: Path,
-        sample_image_size,
-        sample_batches,
-        sample_batch_size,
-        seed
+        *args,
+        **kwargs
     ):
-        super().__init__(
-            datamodule=datamodule,
-            output_dir=output_dir,
-            sample_image_size=sample_image_size,
-            seed=seed or torch.random.seed(),
-            sample_batches=sample_batches,
-            sample_batch_size=sample_batch_size
-        )
+        super().__init__(*args, **kwargs)
 
         self.weight_dtype = weight_dtype
         self.accelerator = accelerator
@@ -829,7 +818,9 @@ def main():
     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
     unet.to(accelerator.device, dtype=weight_dtype)
-    ema_embeddings.to(accelerator.device)
+
+    if args.use_ema:
+        ema_embeddings.to(accelerator.device)
 
     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
@@ -854,13 +845,15 @@ def main():
             tokenizer.train()
             yield
         finally:
-            tokenizer.eval()
+            pass
 
     @contextmanager
     def on_eval():
         try:
+            tokenizer.eval()
+
             ema_context = ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema is not None and eval else nullcontext()
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext()
 
             with ema_context:
                 yield
-- 
cgit v1.2.3-54-g00ecf
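
Note on the last hunk, as a minimal sketch rather than the script's actual code: evaluation should swap the EMA weights in only when --use_ema is set and be a no-op otherwise. The previous gate `args.use_ema is not None and eval` was always truthy, because `eval` there referred to the Python builtin. The names SimpleEMA and on_eval_context below are hypothetical stand-ins for the script's EMAModel; only the apply_temporary() shape is mimicked.

from contextlib import contextmanager, nullcontext

import torch


class SimpleEMA:
    """Hypothetical stand-in for an EMA helper: keeps an exponential
    moving average of a fixed list of parameters."""

    def __init__(self, parameters, decay=0.9999):
        self.decay = decay
        self.shadow = [p.detach().clone() for p in parameters]

    @torch.no_grad()
    def step(self, parameters):
        # shadow <- decay * shadow + (1 - decay) * current
        for s, p in zip(self.shadow, parameters):
            s.mul_(self.decay).add_(p, alpha=1 - self.decay)

    @contextmanager
    def apply_temporary(self, parameters):
        # Swap the EMA weights in for the duration of the block, then restore.
        parameters = list(parameters)
        backup = [p.detach().clone() for p in parameters]
        try:
            for p, s in zip(parameters, self.shadow):
                p.data.copy_(s)
            yield
        finally:
            for p, b in zip(parameters, backup):
                p.data.copy_(b)


def on_eval_context(use_ema: bool, ema: SimpleEMA, parameters):
    # Only apply EMA weights when EMA is actually enabled; otherwise a no-op,
    # mirroring the corrected `if args.use_ema else nullcontext()` gate.
    return ema.apply_temporary(parameters) if use_ema else nullcontext()


Wrapping validation or sampling in `with on_eval_context(use_ema, ema, embedding.parameters()):` then evaluates with the averaged weights when EMA is enabled and leaves the live weights untouched otherwise, which is the behavior the fixed condition is after.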