From 2dfd1790078753f19ca8c585ac77079f3114f3a9 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 16 Jan 2023 21:47:06 +0100
Subject: Training update

---
 train_dreambooth.py             |  6 +++++-
 train_ti.py                     |  8 +++++---
 training/functional.py          |  4 +---
 training/strategy/dreambooth.py | 10 +++++++---
 training/strategy/ti.py         | 13 +++++++------
 5 files changed, 25 insertions(+), 16 deletions(-)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 6511f9b..d722e68 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -552,6 +552,9 @@ def main():
         prior_loss_weight=args.prior_loss_weight,
     )

+    checkpoint_output_dir = output_dir.joinpath("model")
+    sample_output_dir = output_dir.joinpath(f"samples")
+
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -620,7 +623,8 @@ def main():
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
-        output_dir=output_dir,
+        sample_output_dir=sample_output_dir,
+        checkpoint_output_dir=checkpoint_output_dir,
         train_text_encoder_epochs=args.train_text_encoder_epochs,
         max_grad_norm=args.max_grad_norm,
         use_ema=args.use_ema,
diff --git a/train_ti.py b/train_ti.py
index e696577..e7aeb23 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -572,6 +572,8 @@ def main():
         callbacks_fn=textual_inversion_strategy
     )

+    checkpoint_output_dir = output_dir.joinpath("checkpoints")
+
     for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
         range(len(args.placeholder_tokens)),
         args.placeholder_tokens,
@@ -579,8 +581,7 @@ def main():
         args.num_vectors,
         args.train_data_template
     ):
-        cur_subdir = output_dir.joinpath(placeholder_token)
-        cur_subdir.mkdir(parents=True, exist_ok=True)
+        sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token}")

         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
@@ -655,7 +656,8 @@ def main():
             # --
             tokenizer=tokenizer,
             sample_scheduler=sample_scheduler,
-            output_dir=cur_subdir,
+            sample_output_dir=sample_output_dir,
+            checkpoint_output_dir=checkpoint_output_dir,
             placeholder_tokens=[placeholder_token],
             placeholder_token_ids=placeholder_token_ids,
             learning_rate=args.learning_rate,
diff --git a/training/functional.py b/training/functional.py
index 1548784..3d27380 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -87,8 +87,6 @@ def save_samples(
 ):
     print(f"Saving samples for step {step}...")

-    samples_path = output_dir.joinpath("samples")
-
     grid_cols = min(batch_size, 4)
     grid_rows = (num_batches * batch_size) // grid_cols

@@ -120,7 +118,7 @@ def save_samples(
     for pool, data, gen in datasets:
         all_samples = []

-        file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
+        file_path = output_dir.joinpath(pool, f"step_{step}.jpg")
         file_path.parent.mkdir(parents=True, exist_ok=True)

         batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index aeaa828..93c81cb 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -27,7 +27,8 @@ def dreambooth_strategy(
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
-    output_dir: Path,
+    sample_output_dir: Path,
+    checkpoint_output_dir: Path,
     seed: int,
     train_text_encoder_epochs: int,
     max_grad_norm: float = 1.0,
@@ -47,6 +48,9 @@ def dreambooth_strategy(
             "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
         )

+    sample_output_dir.mkdir(parents=True, exist_ok=True)
+    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
+
     weight_dtype = torch.float32
     if accelerator.state.mixed_precision == "fp16":
         weight_dtype = torch.float16
@@ -64,7 +68,7 @@ def dreambooth_strategy(
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
         dtype=weight_dtype,
-        output_dir=output_dir,
+        output_dir=sample_output_dir,
         seed=seed,
         batch_size=sample_batch_size,
         num_batches=sample_num_batches,
@@ -156,7 +160,7 @@ def dreambooth_strategy(
             tokenizer=tokenizer,
             scheduler=sample_scheduler,
         )
-        pipeline.save_pretrained(output_dir.joinpath("model"))
+        pipeline.save_pretrained(checkpoint_output_dir)

         del unet_
         del text_encoder_
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 9d39e15..00f3529 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -27,7 +27,8 @@ def textual_inversion_strategy(
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
-    output_dir: Path,
+    sample_output_dir: Path,
+    checkpoint_output_dir: Path,
     seed: int,
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
@@ -47,6 +48,9 @@ def textual_inversion_strategy(
     sample_guidance_scale: float = 7.5,
     sample_image_size: Optional[int] = None,
 ):
+    sample_output_dir.mkdir(parents=True, exist_ok=True)
+    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
+
     weight_dtype = torch.float32
     if accelerator.state.mixed_precision == "fp16":
         weight_dtype = torch.float16
@@ -64,7 +68,7 @@ def textual_inversion_strategy(
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
         dtype=weight_dtype,
-        output_dir=output_dir,
+        output_dir=sample_output_dir,
         seed=seed,
         batch_size=sample_batch_size,
         num_batches=sample_num_batches,
@@ -135,14 +139,11 @@ def textual_inversion_strategy(
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")

-        checkpoints_path = output_dir.joinpath("checkpoints")
-        checkpoints_path.mkdir(parents=True, exist_ok=True)
-
         with ema_context():
             for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
                     ids,
-                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                    checkpoint_output_dir.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
                 )

     @torch.no_grad()
-- 
cgit v1.2.3-70-g09d2