From 59bf501198d7ff6c0c03c45e92adef14069d5ac6 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 15 Jan 2023 12:33:52 +0100
Subject: Update

---
 training/strategy/ti.py | 54 ++++++++++++++++++++++++++--------------------------
 1 file changed, 26 insertions(+), 28 deletions(-)

(limited to 'training/strategy/ti.py')

diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6f8384f..753dce0 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -27,7 +27,6 @@ def textual_inversion_strategy(
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
-    dtype: torch.dtype,
     output_dir: Path,
     seed: int,
     placeholder_tokens: list[str],
@@ -48,6 +47,12 @@ def textual_inversion_strategy(
     sample_guidance_scale: float = 7.5,
     sample_image_size: Optional[int] = None,
 ):
+    weight_dtype = torch.float32
+    if accelerator.state.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.state.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     save_samples_ = partial(
         save_samples,
         accelerator=accelerator,
@@ -58,7 +63,7 @@ def textual_inversion_strategy(
         sample_scheduler=sample_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        dtype=dtype,
+        dtype=weight_dtype,
         output_dir=output_dir,
         seed=seed,
         batch_size=sample_batch_size,
@@ -78,6 +83,17 @@ def textual_inversion_strategy(
     else:
         ema_embeddings = None
 
+    def ema_context():
+        if use_ema:
+            return ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+            )
+        else:
+            return nullcontext()
+
+    def on_model():
+        return text_encoder
+
     def on_prepare():
         text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
 
@@ -89,24 +105,15 @@ def textual_inversion_strategy(
 
     @contextmanager
     def on_train(epoch: int):
-        try:
-            tokenizer.train()
-            yield
-        finally:
-            pass
+        tokenizer.train()
+        yield
 
     @contextmanager
     def on_eval():
-        try:
-            tokenizer.eval()
+        tokenizer.eval()
 
-            ema_context = ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if use_ema else nullcontext()
-
-            with ema_context:
-                yield
-        finally:
-            pass
+        with ema_context():
+            yield
 
     @torch.no_grad()
     def on_after_optimize(lr: float):
@@ -131,13 +138,7 @@ def textual_inversion_strategy(
         checkpoints_path = output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)
 
-        text_encoder = accelerator.unwrap_model(text_encoder)
-
-        ema_context = ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-        ) if ema_embeddings is not None else nullcontext()
-
-        with ema_context:
+        with ema_context():
             for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
                     ids,
@@ -146,15 +147,12 @@ def textual_inversion_strategy(
 
     @torch.no_grad()
     def on_sample(step):
-        ema_context = ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-        ) if ema_embeddings is not None else nullcontext()
-
-        with ema_context:
+        with ema_context():
             save_samples_(step=step)
 
     return TrainingCallbacks(
         on_prepare=on_prepare,
+        on_model=on_model,
         on_train=on_train,
         on_eval=on_eval,
         on_after_optimize=on_after_optimize,
--
cgit v1.2.3-54-g00ecf
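
For reference, a minimal, self-contained sketch of the two patterns this patch introduces: deriving the weight dtype from the accelerator's mixed-precision setting instead of a passed-in dtype argument, and folding the three inline "ema_context = ... if ... else nullcontext()" expressions into a single ema_context() helper. The names DummyEMA, pick_weight_dtype, and make_ema_context are illustrative stand-ins, not code from the repository.

    # Sketch only; assumes an EMA wrapper exposing an apply_temporary()
    # context manager, as the patch's ema_embeddings object does.
    from contextlib import contextmanager, nullcontext

    import torch


    class DummyEMA:
        """Stand-in for an EMA wrapper with an apply_temporary() context manager."""

        @contextmanager
        def apply_temporary(self, parameters):
            # A real EMA implementation would swap averaged weights in here
            # and restore the originals afterwards.
            yield


    def pick_weight_dtype(mixed_precision: str) -> torch.dtype:
        # Mirrors the patch: fp32 by default, fp16/bf16 when requested.
        if mixed_precision == "fp16":
            return torch.float16
        if mixed_precision == "bf16":
            return torch.bfloat16
        return torch.float32


    def make_ema_context(use_ema: bool, ema, parameters):
        # One helper replaces the repeated inline conditionals: callers just
        # write `with ema_context():` whether or not EMA is enabled.
        def ema_context():
            if use_ema:
                return ema.apply_temporary(parameters)
            return nullcontext()

        return ema_context


    if __name__ == "__main__":
        print(pick_weight_dtype("bf16"))  # torch.bfloat16
        ema_context = make_ema_context(True, DummyEMA(), [])
        with ema_context():
            pass  # evaluation, checkpointing, or sampling would run under EMA weights here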