From b33ac00de283fe45edba689990dc96a5de93cd1e Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 13 Dec 2022 09:40:34 +0100
Subject: Add support for resume in Textual Inversion

---
 dreambooth.py | 49 +++++++++++++++++++++++--------------------------
 1 file changed, 23 insertions(+), 26 deletions(-)

(limited to 'dreambooth.py')

diff --git a/dreambooth.py b/dreambooth.py
index 3110c6d..9a6f70a 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -13,7 +13,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, PNDMScheduler, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
 from PIL import Image
@@ -204,7 +204,7 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_epochs",
         type=int,
-        default=20,
+        default=10,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -558,11 +558,11 @@ class Checkpointer:
 def main():
     args = parse_args()
 
-    # if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
-    #     raise ValueError(
-    #         "Gradient accumulation is not supported when training the text encoder in distributed training. "
-    #         "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
-    #     )
+    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+        raise ValueError(
+            "Gradient accumulation is not supported when training the text encoder in distributed training. "
+            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+        )
 
     instance_identifier = args.instance_identifier
 
@@ -645,9 +645,9 @@ def main():
 
         print(f"Token ID mappings:")
         for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-            print(f"- {token_id} {token}")
-
             embedding_file = embeddings_dir.joinpath(f"{token}.bin")
+            embedding_source = "init"
+
             if embedding_file.exists() and embedding_file.is_file():
                 embedding_data = torch.load(embedding_file, map_location="cpu")
 
@@ -656,8 +656,11 @@ def main():
                 emb = emb.unsqueeze(0)
 
                 token_embeds[token_id] = emb
+                embedding_source = "file"
 
-        original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
+            print(f"- {token_id} {token} ({embedding_source})")
+
+        original_token_embeds = token_embeds.clone().to(accelerator.device)
 
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
         for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
@@ -946,7 +949,7 @@ def main():
         sample_checkpoint = False
 
         for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(itertools.chain(unet, text_encoder)):
+            with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
@@ -997,16 +1000,6 @@ def main():
 
                 accelerator.backward(loss)
 
-                if not args.train_text_encoder:
-                    # Keep the token embeddings fixed except the newly added
-                    # embeddings for the concept, as we only want to optimize the concept embeddings
-                    if accelerator.num_processes > 1:
-                        token_embeds = text_encoder.module.get_input_embeddings().weight
-                    else:
-                        token_embeds = text_encoder.get_input_embeddings().weight
-
-                    token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
-
                 if accelerator.sync_gradients:
                     params_to_clip = (
                         itertools.chain(unet.parameters(), text_encoder.parameters())
@@ -1022,6 +1015,12 @@ def main():
                     ema_unet.step(unet)
                 optimizer.zero_grad(set_to_none=True)
 
+                if not args.train_text_encoder:
+                    # Let's make sure we don't update any embedding weights besides the newly added token
+                    with torch.no_grad():
+                        text_encoder.get_input_embeddings(
+                        ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
+
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
@@ -1032,9 +1031,6 @@ def main():
 
                 global_step += 1
 
-                if global_step % args.sample_frequency == 0:
-                    sample_checkpoint = True
-
                 logs = {
                     "train/loss": avg_loss.avg.item(),
                     "train/acc": avg_acc.avg.item(),
@@ -1117,8 +1113,9 @@ def main():
                         f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                     max_acc_val = avg_acc_val.avg.item()
 
-        if sample_checkpoint and accelerator.is_main_process:
-            checkpointer.save_samples(global_step, args.sample_steps)
+        if accelerator.is_main_process:
+            if epoch % args.sample_frequency == 0:
+                checkpointer.save_samples(global_step, args.sample_steps)
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
--
cgit v1.2.3-54-g00ecf
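Note: what follows is a minimal, standalone Python sketch (not part of the patch above) of the pattern this commit relies on: placeholder-token embeddings are resumed from per-token .bin files when they exist, and after each optimizer step every embedding row except the newly added tokens is restored from a snapshot, so only the placeholder embeddings are effectively trained. The checkpoint name, the embeddings directory, the placeholder token, and the .bin layout are assumptions made for illustration and may differ from the repository's actual conventions.

# Standalone sketch, not part of the commit. Assumed: an "embeddings" directory
# holding one {token}.bin file per placeholder token, saved as {token: tensor}.
from pathlib import Path

import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

placeholder_tokens = ["<concept>"]   # hypothetical placeholder token
tokenizer.add_tokens(placeholder_tokens)
text_encoder.resize_token_embeddings(len(tokenizer))
placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)

embeddings_dir = Path("embeddings")  # assumed location of saved embeddings
token_embeds = text_encoder.get_input_embeddings().weight.data

# Resume: if a previously saved embedding exists for a token, load it over the
# freshly initialized row; otherwise keep the initializer-based value.
for token_id, token in zip(placeholder_token_ids, placeholder_tokens):
    embedding_file = embeddings_dir / f"{token}.bin"
    source = "init"
    if embedding_file.is_file():
        data = torch.load(embedding_file, map_location="cpu")
        emb = next(iter(data.values()))  # assumed layout: {token: tensor}
        token_embeds[token_id] = emb.reshape(-1)
        source = "file"
    print(f"- {token_id} {token} ({source})")

# Snapshot the embedding matrix and remember which rows must stay fixed.
original_token_embeds = token_embeds.clone()
index_fixed_tokens = torch.tensor(
    [i for i in range(len(tokenizer)) if i not in placeholder_token_ids]
)

# Inside the training loop, after optimizer.step() and optimizer.zero_grad():
# overwrite the fixed rows so only the placeholder embeddings keep their updates.
with torch.no_grad():
    text_encoder.get_input_embeddings().weight[index_fixed_tokens] = \
        original_token_embeds[index_fixed_tokens]

Restoring the fixed rows after the optimizer step means the optimizer may still compute updates for those rows, but they are immediately overwritten, which is simpler than masking gradients before the step.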