From 2dfd1790078753f19ca8c585ac77079f3114f3a9 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 21:47:06 +0100 Subject: Training update --- training/strategy/ti.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'training/strategy/ti.py') diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 9d39e15..00f3529 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -27,7 +27,8 @@ def textual_inversion_strategy( sample_scheduler: DPMSolverMultistepScheduler, train_dataloader: DataLoader, val_dataloader: Optional[DataLoader], - output_dir: Path, + sample_output_dir: Path, + checkpoint_output_dir: Path, seed: int, placeholder_tokens: list[str], placeholder_token_ids: list[list[int]], @@ -47,6 +48,9 @@ def textual_inversion_strategy( sample_guidance_scale: float = 7.5, sample_image_size: Optional[int] = None, ): + sample_output_dir.mkdir(parents=True, exist_ok=True) + checkpoint_output_dir.mkdir(parents=True, exist_ok=True) + weight_dtype = torch.float32 if accelerator.state.mixed_precision == "fp16": weight_dtype = torch.float16 @@ -64,7 +68,7 @@ def textual_inversion_strategy( train_dataloader=train_dataloader, val_dataloader=val_dataloader, dtype=weight_dtype, - output_dir=output_dir, + output_dir=sample_output_dir, seed=seed, batch_size=sample_batch_size, num_batches=sample_num_batches, @@ -135,14 +139,11 @@ def textual_inversion_strategy( def on_checkpoint(step, postfix): print(f"Saving checkpoint for step {step}...") - checkpoints_path = output_dir.joinpath("checkpoints") - checkpoints_path.mkdir(parents=True, exist_ok=True) - with ema_context(): for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): text_encoder.text_model.embeddings.save_embed( ids, - checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") + checkpoint_output_dir.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") ) @torch.no_grad() -- cgit v1.2.3-54-g00ecf