From b33ac00de283fe45edba689990dc96a5de93cd1e Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 13 Dec 2022 09:40:34 +0100 Subject: Add support for resume in Textual Inversion --- dreambooth.py | 49 ++++----- infer.py | 2 +- .../stable_diffusion/vlpn_stable_diffusion.py | 42 ++++---- textual_inversion.py | 119 ++++++++++----------- 4 files changed, 101 insertions(+), 111 deletions(-) diff --git a/dreambooth.py b/dreambooth.py index 3110c6d..9a6f70a 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -13,7 +13,7 @@ import torch.utils.checkpoint from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed -from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, PNDMScheduler, UNet2DConditionModel +from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup from diffusers.training_utils import EMAModel from PIL import Image @@ -204,7 +204,7 @@ def parse_args(): parser.add_argument( "--lr_warmup_epochs", type=int, - default=20, + default=10, help="Number of steps for the warmup in the lr scheduler." ) parser.add_argument( @@ -558,11 +558,11 @@ class Checkpointer: def main(): args = parse_args() - # if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: - # raise ValueError( - # "Gradient accumulation is not supported when training the text encoder in distributed training. " - # "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." - # ) + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) instance_identifier = args.instance_identifier @@ -645,9 +645,9 @@ def main(): print(f"Token ID mappings:") for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): - print(f"- {token_id} {token}") - embedding_file = embeddings_dir.joinpath(f"{token}.bin") + embedding_source = "init" + if embedding_file.exists() and embedding_file.is_file(): embedding_data = torch.load(embedding_file, map_location="cpu") @@ -656,8 +656,11 @@ def main(): emb = emb.unsqueeze(0) token_embeds[token_id] = emb + embedding_source = "file" - original_token_embeds = token_embeds.detach().clone().to(accelerator.device) + print(f"- {token_id} {token} ({embedding_source})") + + original_token_embeds = token_embeds.clone().to(accelerator.device) initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): @@ -946,7 +949,7 @@ def main(): sample_checkpoint = False for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(itertools.chain(unet, text_encoder)): + with accelerator.accumulate(unet): # Convert images to latent space latents = vae.encode(batch["pixel_values"]).latent_dist.sample() latents = latents * 0.18215 @@ -997,16 +1000,6 @@ def main(): accelerator.backward(loss) - if not args.train_text_encoder: - # Keep the token embeddings fixed except the newly added - # embeddings for the concept, as we only want to optimize the concept embeddings - if accelerator.num_processes > 1: - token_embeds = text_encoder.module.get_input_embeddings().weight - else: - token_embeds = text_encoder.get_input_embeddings().weight - - token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] - if accelerator.sync_gradients: params_to_clip = ( itertools.chain(unet.parameters(), text_encoder.parameters()) @@ -1022,6 +1015,12 @@ def main(): ema_unet.step(unet) optimizer.zero_grad(set_to_none=True) + if not args.train_text_encoder: + # Let's make sure we don't update any embedding weights besides the newly added token + with torch.no_grad(): + text_encoder.get_input_embeddings( + ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens] + avg_loss.update(loss.detach_(), bsz) avg_acc.update(acc.detach_(), bsz) @@ -1032,9 +1031,6 @@ def main(): global_step += 1 - if global_step % args.sample_frequency == 0: - sample_checkpoint = True - logs = { "train/loss": avg_loss.avg.item(), "train/acc": avg_acc.avg.item(), @@ -1117,8 +1113,9 @@ def main(): f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") max_acc_val = avg_acc_val.avg.item() - if sample_checkpoint and accelerator.is_main_process: - checkpointer.save_samples(global_step, args.sample_steps) + if accelerator.is_main_process: + if epoch % args.sample_frequency == 0: + checkpointer.save_samples(global_step, args.sample_steps) # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: diff --git a/infer.py b/infer.py index 5bd926a..f607041 100644 --- a/infer.py +++ b/infer.py @@ -31,7 +31,7 @@ torch.backends.cudnn.benchmark = True default_args = { - "model": None, + "model": "stabilityai/stable-diffusion-2-1", "precision": "fp32", "ti_embeddings_dir": "embeddings_ti", "output_dir": "output/inference", diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 141b9a7..707b639 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -421,25 +421,29 @@ class VlpnStableDiffusion(DiffusionPipeline): extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 8. Post-processing image = self.decode_latents(latents) diff --git a/textual_inversion.py b/textual_inversion.py index a9c3326..11babd8 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -170,7 +170,7 @@ def parse_args(): parser.add_argument( "--lr_scheduler", type=str, - default="one_cycle", + default="constant_with_warmup", help=( 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' ' "constant", "constant_with_warmup", "one_cycle"]' @@ -231,14 +231,14 @@ def parse_args(): parser.add_argument( "--checkpoint_frequency", type=int, - default=500, - help="How often to save a checkpoint and sample image", + default=5, + help="How often to save a checkpoint and sample image (in epochs)", ) parser.add_argument( "--sample_frequency", type=int, - default=100, - help="How often to save a checkpoint and sample image", + default=1, + help="How often to save a checkpoint and sample image (in epochs)", ) parser.add_argument( "--sample_image_size", @@ -294,10 +294,9 @@ def parse_args(): help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)" ) parser.add_argument( - "--resume_checkpoint", - type=str, - default=None, - help="Path to a specific checkpoint to resume training from (ie, logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)." + "--global_step", + type=int, + default=0, ) parser.add_argument( "--config", @@ -512,19 +511,10 @@ def main(): if len(args.placeholder_token) != 0: instance_identifier = instance_identifier.format(args.placeholder_token[0]) - global_step_offset = 0 - if args.resume_from is not None: - basepath = Path(args.resume_from) - print("Resuming state from %s" % args.resume_from) - with open(basepath.joinpath("resume.json"), 'r') as f: - state = json.load(f) - global_step_offset = state["args"].get("global_step", 0) - - print("We've trained %d steps so far" % global_step_offset) - else: - now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) - basepath.mkdir(parents=True, exist_ok=True) + global_step_offset = args.global_step + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) + basepath.mkdir(parents=True, exist_ok=True) accelerator = Accelerator( log_with=LoggerType.TENSORBOARD, @@ -557,6 +547,7 @@ def main(): set_use_memory_efficient_attention_xformers(vae, True) if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable() print(f"Adding text embeddings: {args.placeholder_token}") @@ -577,14 +568,25 @@ def main(): # Initialise the newly added placeholder token with the embeddings of the initializer token token_embeds = text_encoder.get_input_embeddings().weight.data - original_token_embeds = token_embeds.detach().clone().to(accelerator.device) - if args.resume_checkpoint is not None: - token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token] - else: - initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) - for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): - token_embeds[token_id] = embeddings + if args.resume_from: + resumepath = Path(args.resume_from).joinpath("checkpoints") + + for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): + embedding_file = resumepath.joinpath(f"{token}_{args.global_step}_end.bin") + embedding_data = torch.load(embedding_file, map_location="cpu") + + emb = next(iter(embedding_data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + + token_embeds[token_id] = emb + + original_token_embeds = token_embeds.clone().to(accelerator.device) + + initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) + for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): + token_embeds[token_id] = embeddings index_fixed_tokens = torch.arange(len(tokenizer)) index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] @@ -891,21 +893,16 @@ def main(): accelerator.backward(loss) - # Keep the token embeddings fixed except the newly added - # embeddings for the concept, as we only want to optimize the concept embeddings - if accelerator.num_processes > 1: - token_embeds = text_encoder.module.get_input_embeddings().weight - else: - token_embeds = text_encoder.get_input_embeddings().weight - - # Get the index for tokens that we want to freeze - token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] - optimizer.step() if not accelerator.optimizer_step_was_skipped: lr_scheduler.step() optimizer.zero_grad(set_to_none=True) + # Let's make sure we don't update any embedding weights besides the newly added token + with torch.no_grad(): + text_encoder.get_input_embeddings( + ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens] + loss = loss.detach().item() train_loss += loss @@ -916,19 +913,6 @@ def main(): global_step += 1 - if global_step % args.sample_frequency == 0: - sample_checkpoint = True - - if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: - local_progress_bar.clear() - global_progress_bar.clear() - - checkpointer.checkpoint(global_step + global_step_offset, "training") - save_args(basepath, args, { - "global_step": global_step + global_step_offset, - "resume_checkpoint": f"{basepath}/checkpoints/last.bin" - }) - logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]} accelerator.log(logs, step=global_step) @@ -992,24 +976,30 @@ def main(): local_progress_bar.clear() global_progress_bar.clear() - if min_val_loss > val_loss: - accelerator.print( - f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") - checkpointer.checkpoint(global_step + global_step_offset, "milestone") - min_val_loss = val_loss + if accelerator.is_main_process: + if min_val_loss > val_loss: + accelerator.print( + f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") + checkpointer.checkpoint(global_step + global_step_offset, "milestone") + min_val_loss = val_loss + + if epoch % args.checkpoint_frequency == 0: + checkpointer.checkpoint(global_step + global_step_offset, "training") + save_args(basepath, args, { + "global_step": global_step + global_step_offset + }) - if sample_checkpoint and accelerator.is_main_process: - checkpointer.save_samples( - global_step + global_step_offset, - args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) + if epoch % args.sample_frequency == 0: + checkpointer.save_samples( + global_step + global_step_offset, + args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: print("Finished! Saving final checkpoint and resume state.") checkpointer.checkpoint(global_step + global_step_offset, "end") save_args(basepath, args, { - "global_step": global_step + global_step_offset, - "resume_checkpoint": f"{basepath}/checkpoints/last.bin" + "global_step": global_step + global_step_offset }) accelerator.end_training() @@ -1018,8 +1008,7 @@ def main(): print("Interrupted, saving checkpoint and resume state...") checkpointer.checkpoint(global_step + global_step_offset, "end") save_args(basepath, args, { - "global_step": global_step + global_step_offset, - "resume_checkpoint": f"{basepath}/checkpoints/last.bin" + "global_step": global_step + global_step_offset }) accelerator.end_training() quit() -- cgit v1.2.3-70-g09d2