From dd02ace41f69541044e9db106feaa76bf02da8f6 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 12 Dec 2022 08:05:06 +0100
Subject: Dreambooth: Support loading Textual Inversion embeddings

---
 dreambooth.py                                 | 36 ++++++++++++++-------
 infer.py                                      |  2 +-
 .../stable_diffusion/vlpn_stable_diffusion.py | 13 +++++---
 textual_inversion.py                          | 41 ++++++++++++++--------
 4 files changed, 59 insertions(+), 33 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index 675320b..3110c6d 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -117,6 +117,12 @@ def parse_args():
         default="output/dreambooth",
         help="The output directory where the model predictions and checkpoints will be written.",
     )
+    parser.add_argument(
+        "--embeddings_dir",
+        type=str,
+        default="embeddings_ti",
+        help="The embeddings directory where Textual Inversion embeddings are stored.",
+    )
     parser.add_argument(
         "--seed",
         type=int,
@@ -521,7 +527,7 @@ class Checkpointer:
                         negative_prompt=nprompt,
                         height=self.sample_image_size,
                         width=self.sample_image_size,
-                        latents_or_image=latents[:len(prompt)] if latents is not None else None,
+                        image=latents[:len(prompt)] if latents is not None else None,
                         generator=generator if latents is not None else None,
                         guidance_scale=guidance_scale,
                         eta=eta,
@@ -567,6 +573,8 @@ def main():
     basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
+    embeddings_dir = Path(args.embeddings_dir)
+
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
         logging_dir=f"{basepath}",
@@ -630,15 +638,25 @@ def main():
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
+    # Resize the token embeddings as we are adding new special tokens to the tokenizer
+    text_encoder.resize_token_embeddings(len(tokenizer))
+
+    token_embeds = text_encoder.get_input_embeddings().weight.data
+
     print(f"Token ID mappings:")
     for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
         print(f"- {token_id} {token}")
 
-    # Resize the token embeddings as we are adding new special tokens to the tokenizer
-    text_encoder.resize_token_embeddings(len(tokenizer))
+        embedding_file = embeddings_dir.joinpath(f"{token}.bin")
+        if embedding_file.exists() and embedding_file.is_file():
+            embedding_data = torch.load(embedding_file, map_location="cpu")
+
+            emb = next(iter(embedding_data.values()))
+            if len(emb.shape) == 1:
+                emb = emb.unsqueeze(0)
+
+            token_embeds[token_id] = emb
 
-    # Initialise the newly added placeholder token with the embeddings of the initializer token
-    token_embeds = text_encoder.get_input_embeddings().weight.data
     original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
 
     initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
@@ -959,8 +977,6 @@ def main():
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-                del timesteps, noise, latents, noisy_latents, encoder_hidden_states
-
                 if args.num_class_images != 0:
                     # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
@@ -977,6 +993,8 @@ def main():
                 else:
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
+                acc = (model_pred == latents).float().mean()
+
                 accelerator.backward(loss)
 
                 if not args.train_text_encoder:
@@ -1004,8 +1022,6 @@ def main():
                     ema_unet.step(unet)
                 optimizer.zero_grad(set_to_none=True)
 
-                acc = (model_pred == latents).float().mean()
-
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
@@ -1069,8 +1085,6 @@ def main():
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-                del timesteps, noise, latents, noisy_latents, encoder_hidden_states
-
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                 acc = (model_pred == latents).float().mean()
diff --git a/infer.py b/infer.py
index e3fa9e5..5bd926a 100644
--- a/infer.py
+++ b/infer.py
@@ -291,7 +291,7 @@ def generate(output_dir, pipeline, args):
             num_inference_steps=args.steps,
             guidance_scale=args.guidance_scale,
             generator=generator,
-            latents_or_image=init_image,
+            image=init_image,
             strength=args.image_noise,
         ).images
 
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 78a34d5..141b9a7 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -27,7 +27,9 @@ from models.clip.prompt import PromptProcessor
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-def preprocess(image, w, h):
+def preprocess(image):
+    w, h = image.size
+    w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8
     image = image.resize((w, h), resample=PIL.Image.LANCZOS)
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
@@ -310,7 +312,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        latents_or_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
+        image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -373,7 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         batch_size = len(prompt)
         device = self.execution_device
         do_classifier_free_guidance = guidance_scale > 1.0
-        latents_are_image = isinstance(latents_or_image, PIL.Image.Image)
+        latents_are_image = isinstance(image, PIL.Image.Image)
 
         # 3. Encode input prompt
         text_embeddings = self.encode_prompt(
@@ -391,9 +393,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 5. Prepare latent variables
         num_channels_latents = self.unet.in_channels
         if latents_are_image:
+            image = preprocess(image)
             latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
             latents = self.prepare_latents_from_image(
-                latents_or_image,
+                image,
                 latent_timestep,
                 batch_size,
                 num_images_per_prompt,
@@ -411,7 +414,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 text_embeddings.dtype,
                 device,
                 generator,
-                latents_or_image,
+                image,
             )
 
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
diff --git a/textual_inversion.py b/textual_inversion.py
index da7c747..a9c3326 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -107,7 +107,7 @@ def parse_args():
     parser.add_argument(
         "--resolution",
         type=int,
-        default=512,
+        default=768,
         help=(
             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
             " resolution"
@@ -118,6 +118,12 @@ def parse_args():
         action="store_true",
         help="Whether to center crop images before resizing to resolution"
     )
+    parser.add_argument(
+        "--tag_dropout",
+        type=float,
+        default=0.1,
+        help="Tag dropout probability.",
+    )
     parser.add_argument(
         "--dataloader_num_workers",
         type=int,
@@ -171,9 +177,9 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--lr_warmup_steps",
+        "--lr_warmup_epochs",
         type=int,
-        default=300,
-        help="Number of steps for the warmup in the lr scheduler."
+        default=10,
+        help="Number of epochs for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -237,7 +243,7 @@ def parse_args():
     parser.add_argument(
         "--sample_image_size",
         type=int,
-        default=512,
+        default=768,
         help="Size of sample images",
     )
     parser.add_argument(
@@ -267,7 +273,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
        type=int,
-        default=30,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -399,28 +405,28 @@ class Checkpointer:
         checkpoints_path = self.output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)
 
-        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
             # Save a checkpoint
-            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
             learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
             filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
             torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
-        del unwrapped
+        del text_encoder
         del learned_embeds
 
     @torch.no_grad()
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         # Save a sample image
         pipeline = VlpnStableDiffusion(
-            text_encoder=unwrapped,
+            text_encoder=text_encoder,
             vae=self.vae,
             unet=self.unet,
             tokenizer=self.tokenizer,
@@ -471,7 +477,7 @@ class Checkpointer:
                         negative_prompt=nprompt,
                         height=self.sample_image_size,
                         width=self.sample_image_size,
-                        latents_or_image=latents[:len(prompt)] if latents is not None else None,
+                        image=latents[:len(prompt)] if latents is not None else None,
                         generator=generator if latents is not None else None,
                         guidance_scale=guidance_scale,
                         eta=eta,
@@ -489,7 +495,7 @@ class Checkpointer:
         del all_samples
         del image_grid
 
-        del unwrapped
+        del text_encoder
         del pipeline
         del generator
         del stable_latents
@@ -662,6 +668,7 @@ def main():
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
+        dropout=args.tag_dropout,
         center_crop=args.center_crop,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
@@ -720,6 +727,8 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
+    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
+
     if args.lr_scheduler == "one_cycle":
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
@@ -728,7 +737,7 @@ def main():
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
             num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
+                ((args.max_train_steps - args.lr_warmup_epochs * num_update_steps_per_epoch) / num_update_steps_per_epoch))),
@@ -737,7 +746,7 @@ def main():
     else:
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
         )
-- 
cgit v1.2.3-70-g09d2
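
The core of the dreambooth.py change, restated as a minimal standalone sketch. The helper name and signature here are hypothetical, not part of the patch; the sketch assumes each <token>.bin file uses the format written by Checkpointer.checkpoint() in textual_inversion.py, i.e. a dict mapping the placeholder token to its embedding tensor:

    from pathlib import Path

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer


    def load_textual_inversion_embeddings(
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        placeholder_tokens: list[str],
        embeddings_dir: Path,
    ) -> None:
        # Register the placeholder tokens, then grow the embedding matrix
        # before touching any of its rows (same order as in the patch).
        tokenizer.add_tokens(placeholder_tokens)
        token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)
        text_encoder.resize_token_embeddings(len(tokenizer))
        token_embeds = text_encoder.get_input_embeddings().weight.data

        for token_id, token in zip(token_ids, placeholder_tokens):
            embedding_file = embeddings_dir.joinpath(f"{token}.bin")
            if not embedding_file.is_file():
                # No saved embedding: the token keeps its fresh initialization.
                continue

            # Checkpoint files hold a single entry: {placeholder_token: tensor}.
            embedding_data = torch.load(embedding_file, map_location="cpu")
            emb = next(iter(embedding_data.values()))
            if len(emb.shape) == 1:
                emb = emb.unsqueeze(0)  # (dim,) -> (1, dim); broadcasts on assignment

            token_embeds[token_id] = emb

Tokens without a matching file keep their initializer-based values, which is why the patch can leave the existing initializer-token logic untouched.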
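The latents_or_image -> image rename changes the pipeline's public API, and preprocess() now derives the size from the init image itself, trimming both sides down to a multiple of 8 to match the VAE's 8x downsampling. A hypothetical img2img call under these assumptions, where `pipeline` is taken to be a VlpnStableDiffusion instance:

    from PIL import Image

    # A 517x771 input is trimmed to 512x768 by the pipeline's preprocess()
    init_image = Image.open("init.png").convert("RGB")

    result = pipeline(
        prompt="a photo of <token>",
        image=init_image,        # was latents_or_image before this patch
        strength=0.7,            # infer.py passes args.image_noise here
        num_inference_steps=30,
        guidance_scale=7.5,
    ).images[0]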
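The --lr_warmup_steps -> --lr_warmup_epochs switch in textual_inversion.py converts warmup into scheduler steps once, before the scheduler branch. The extra factor of gradient_accumulation_steps mirrors num_training_steps above, which is scaled by it as well (the scheduler appears to be stepped per micro-batch rather than per optimizer update). A worked example with made-up numbers:

    lr_warmup_epochs = 10
    num_update_steps_per_epoch = 120   # optimizer updates per epoch
    gradient_accumulation_steps = 4    # scheduler steps per optimizer update

    warmup_steps = lr_warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
    assert warmup_steps == 4800        # 10 epochs' worth of scheduler steps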