From 03303d3bddba5a27a123babdf90863e27501e6f8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 13 Dec 2022 23:09:25 +0100
Subject: Unified loading of TI embeddings

---
 textual_inversion.py | 23 +++++++++++------------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/textual_inversion.py b/textual_inversion.py
index fd4a313..6d8fd77 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -22,6 +22,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
+from common import load_text_embeddings, load_text_embedding
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from pipelines.util import set_use_memory_efficient_attention_xformers
 from data.csv import CSVDataModule
@@ -104,6 +105,12 @@ def parse_args():
         default="output/text-inversion",
         help="The output directory where the model predictions and checkpoints will be written.",
     )
+    parser.add_argument(
+        "--embeddings_dir",
+        type=str,
+        default=None,
+        help="The embeddings directory where Textual Inversion embeddings are stored.",
+    )
     parser.add_argument(
         "--seed",
         type=int,
@@ -551,6 +558,9 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
+    if args.embeddings_dir is not None:
+        load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir))
+
     # Convert the initializer_token, placeholder_token to ids
     initializer_token_ids = torch.stack([
         torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
@@ -562,10 +572,6 @@ def main():
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
-    print(f"Token ID mappings:")
-    for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-        print(f"- {token_id} {token}")
-
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
 
@@ -576,14 +582,7 @@ def main():
         resumepath = Path(args.resume_from).joinpath("checkpoints")
 
         for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-            embedding_file = resumepath.joinpath(f"{token}_{args.global_step}_end.bin")
-            embedding_data = torch.load(embedding_file, map_location="cpu")
-
-            emb = next(iter(embedding_data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-
-            token_embeds[token_id] = emb
+            load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin"))
 
         original_token_embeds = token_embeds.clone().to(accelerator.device)
-- 
cgit v1.2.3-54-g00ecf
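
Note: the patch calls two helpers from common.py that are not shown here. Below is a minimal sketch of what they might look like. load_text_embedding is reconstructed directly from the inline logic this commit removes; load_text_embeddings, including its directory iteration, token naming, and file-suffix handling, is an assumption inferred from the call site load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir)), not the repository's actual implementation.

    from pathlib import Path

    import torch


    def load_text_embedding(embeddings, token_id, file):
        # Mirrors the inline code removed by this commit: load the saved
        # embedding on the CPU, take the first tensor in the checkpoint
        # dict, and ensure it is 2-D before writing it into the matrix.
        data = torch.load(file, map_location="cpu")

        emb = next(iter(data.values()))
        if len(emb.shape) == 1:
            emb = emb.unsqueeze(0)

        embeddings[token_id] = emb


    def load_text_embeddings(tokenizer, text_encoder, embeddings_dir: Path):
        # Assumed directory loader: register one placeholder token per
        # .bin file (named after the file stem), resize the encoder's
        # embedding matrix, then reuse load_text_embedding per token.
        files = [f for f in embeddings_dir.iterdir() if f.is_file() and f.suffix == ".bin"]
        tokens = [f.stem for f in files]

        tokenizer.add_tokens(tokens)
        text_encoder.resize_token_embeddings(len(tokenizer))

        token_embeds = text_encoder.get_input_embeddings().weight.data

        for file, token in zip(files, tokens):
            token_id = tokenizer.convert_tokens_to_ids(token)
            load_text_embedding(token_embeds, token_id, file)

With helpers along these lines, the resume path and the new --embeddings_dir path share one loading routine, which is the point of the commit: the per-token shape handling lives in exactly one place instead of being duplicated across scripts.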