From dd02ace41f69541044e9db106feaa76bf02da8f6 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 12 Dec 2022 08:05:06 +0100
Subject: Dreambooth: Support loading Textual Inversion embeddings

---
 dreambooth.py | 36 +++++++++++++++++++++++++-----------
 1 file changed, 25 insertions(+), 11 deletions(-)

(limited to 'dreambooth.py')

diff --git a/dreambooth.py b/dreambooth.py
index 675320b..3110c6d 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -117,6 +117,12 @@ def parse_args():
         default="output/dreambooth",
         help="The output directory where the model predictions and checkpoints will be written.",
     )
+    parser.add_argument(
+        "--embeddings_dir",
+        type=str,
+        default="embeddings_ti",
+        help="The embeddings directory where Textual Inversion embeddings are stored.",
+    )
     parser.add_argument(
         "--seed",
         type=int,
@@ -521,7 +527,7 @@ class Checkpointer:
                         negative_prompt=nprompt,
                         height=self.sample_image_size,
                         width=self.sample_image_size,
-                        latents_or_image=latents[:len(prompt)] if latents is not None else None,
+                        image=latents[:len(prompt)] if latents is not None else None,
                         generator=generator if latents is not None else None,
                         guidance_scale=guidance_scale,
                         eta=eta,
@@ -567,6 +573,8 @@ def main():
     basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
+    embeddings_dir = Path(args.embeddings_dir)
+
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
         logging_dir=f"{basepath}",
@@ -630,15 +638,25 @@ def main():
 
         placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
+        # Resize the token embeddings as we are adding new special tokens to the tokenizer
+        text_encoder.resize_token_embeddings(len(tokenizer))
+
+        token_embeds = text_encoder.get_input_embeddings().weight.data
+
         print(f"Token ID mappings:")
         for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
             print(f"- {token_id} {token}")
 
-        # Resize the token embeddings as we are adding new special tokens to the tokenizer
-        text_encoder.resize_token_embeddings(len(tokenizer))
+            embedding_file = embeddings_dir.joinpath(f"{token}.bin")
+            if embedding_file.exists() and embedding_file.is_file():
+                embedding_data = torch.load(embedding_file, map_location="cpu")
+
+                emb = next(iter(embedding_data.values()))
+                if len(emb.shape) == 1:
+                    emb = emb.unsqueeze(0)
+
+                token_embeds[token_id] = emb
 
-        # Initialise the newly added placeholder token with the embeddings of the initializer token
-        token_embeds = text_encoder.get_input_embeddings().weight.data
         original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
 
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
@@ -959,8 +977,6 @@ def main():
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-                del timesteps, noise, latents, noisy_latents, encoder_hidden_states
-
                 if args.num_class_images != 0:
                     # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                     model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
@@ -977,6 +993,8 @@ def main():
                 else:
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
+                acc = (model_pred == latents).float().mean()
+
                 accelerator.backward(loss)
 
                 if not args.train_text_encoder:
@@ -1004,8 +1022,6 @@ def main():
                         ema_unet.step(unet)
                     optimizer.zero_grad(set_to_none=True)
 
-                acc = (model_pred == latents).float().mean()
-
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
@@ -1069,8 +1085,6 @@ def main():
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-                del timesteps, noise, latents, noisy_latents, encoder_hidden_states
-
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                 acc = (model_pred == latents).float().mean()
 
--
cgit v1.2.3-54-g00ecf
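
Note on the change above (not part of the patch): the new --embeddings_dir option lets previously trained Textual Inversion embeddings initialise the placeholder tokens before Dreambooth training starts. The sketch below shows the same loading logic in isolation; the base-model name, the placeholder token and the directory are illustrative assumptions, and it relies on the file layout the patch assumes (a .bin file holding a dict with a single token-to-tensor entry).

    from pathlib import Path

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    pretrained = "runwayml/stable-diffusion-v1-5"   # assumed base model, not from the patch
    tokenizer = CLIPTokenizer.from_pretrained(pretrained, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained, subfolder="text_encoder")

    placeholder_token = "<my-token>"        # hypothetical placeholder token
    embeddings_dir = Path("embeddings_ti")  # default value of the new --embeddings_dir flag

    # Register the placeholder token and grow the embedding matrix to match.
    tokenizer.add_tokens([placeholder_token])
    text_encoder.resize_token_embeddings(len(tokenizer))
    token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
    token_embeds = text_encoder.get_input_embeddings().weight.data

    embedding_file = embeddings_dir.joinpath(f"{placeholder_token}.bin")
    if embedding_file.exists() and embedding_file.is_file():
        # Read the first (and assumed only) tensor stored in the file, as the patch does.
        embedding_data = torch.load(embedding_file, map_location="cpu")
        emb = next(iter(embedding_data.values()))
        if len(emb.shape) == 1:
            emb = emb.unsqueeze(0)
        # Copy the trained embedding into the freshly added row.
        token_embeds[token_id] = emb

This is also why the patch moves resize_token_embeddings() ahead of the token loop: the new embedding rows have to exist before the loaded vectors can overwrite them.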