From 03303d3bddba5a27a123babdf90863e27501e6f8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 13 Dec 2022 23:09:25 +0100
Subject: Unified loading of TI embeddings

---
 common.py            | 36 ++++++++++++++++++++++++++++++++++++
 dreambooth.py        | 26 +++++---------------------
 infer.py             | 34 ++--------------------------------
 textual_inversion.py | 23 +++++++++++------------
 4 files changed, 54 insertions(+), 65 deletions(-)
 create mode 100644 common.py

diff --git a/common.py b/common.py
new file mode 100644
index 0000000..8d6b55d
--- /dev/null
+++ b/common.py
@@ -0,0 +1,36 @@
+from pathlib import Path
+import torch
+
+from transformers import CLIPTextModel, CLIPTokenizer
+
+
+def load_text_embedding(embeddings, token_id, file):
+    data = torch.load(file, map_location="cpu")
+
+    assert len(data.keys()) == 1, 'embedding data has multiple terms in it'
+
+    emb = next(iter(data.values()))
+    if len(emb.shape) == 1:
+        emb = emb.unsqueeze(0)
+
+    embeddings[token_id] = emb
+
+
+def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path):
+    if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+        return 0
+
+    files = [file for file in embeddings_dir.iterdir() if file.is_file()]
+
+    tokens = [file.stem for file in files]
+    added = tokenizer.add_tokens(tokens)
+    token_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+    text_encoder.resize_token_embeddings(len(tokenizer))
+
+    token_embeds = text_encoder.get_input_embeddings().weight.data
+
+    for (token_id, file) in zip(token_ids, files):
+        load_text_embedding(token_embeds, token_id, file)
+
+    return added
diff --git a/dreambooth.py b/dreambooth.py
index 5521b21..3f45754 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -21,6 +21,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
+from common import load_text_embeddings
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from pipelines.util import set_use_memory_efficient_attention_xformers
 from data.csv import CSVDataModule
@@ -125,7 +126,7 @@ def parse_args():
     parser.add_argument(
         "--embeddings_dir",
         type=str,
-        default="embeddings_ti",
+        default=None,
         help="The embeddings directory where Textual Inversion embeddings are stored.",
     )
     parser.add_argument(
@@ -578,8 +579,6 @@ def main():
     basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
-    embeddings_dir = Path(args.embeddings_dir)
-
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
         logging_dir=f"{basepath}",
@@ -629,6 +628,9 @@ def main():
     # Freeze text_encoder and vae
     vae.requires_grad_(False)
 
+    if args.embeddings_dir is not None:
+        load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir))
+
     if len(args.placeholder_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
         initializer_token_ids = torch.stack([
@@ -645,24 +647,6 @@ def main():
         text_encoder.resize_token_embeddings(len(tokenizer))
 
         token_embeds = text_encoder.get_input_embeddings().weight.data
-
-        print(f"Token ID mappings:")
-        for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-            embedding_file = embeddings_dir.joinpath(f"{token}.bin")
-            embedding_source = "init"
-
-            if embedding_file.exists() and embedding_file.is_file():
-                embedding_data = torch.load(embedding_file, map_location="cpu")
-
-                emb = next(iter(embedding_data.values()))
-                if len(emb.shape) == 1:
-                    emb = emb.unsqueeze(0)
-
-                token_embeds[token_id] = emb
-                embedding_source = "file"
-
-            print(f"- {token_id} {token} ({embedding_source})")
-
         original_token_embeds = token_embeds.clone().to(accelerator.device)
 
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
diff --git a/infer.py b/infer.py
index f607041..1fd11e2 100644
--- a/infer.py
+++ b/infer.py
@@ -24,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from common import load_text_embeddings
 
 torch.backends.cuda.matmul.allow_tf32 = True
 
@@ -180,37 +181,6 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
-    print(f"Loading Textual Inversion embeddings")
-
-    embeddings_dir = Path(embeddings_dir)
-    embeddings_dir.mkdir(parents=True, exist_ok=True)
-
-    placeholder_tokens = [file.stem for file in embeddings_dir.iterdir() if file.is_file()]
-    tokenizer.add_tokens(placeholder_tokens)
-
-    text_encoder.resize_token_embeddings(len(tokenizer))
-
-    token_embeds = text_encoder.get_input_embeddings().weight.data
-
-    for file in embeddings_dir.iterdir():
-        if file.is_file():
-            placeholder_token = file.stem
-            placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
-
-            data = torch.load(file, map_location="cpu")
-
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-
-            token_embeds[placeholder_token_id] = emb
-
-            print(f"Loaded {placeholder_token}")
-
-
 def create_pipeline(model, ti_embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
@@ -220,7 +190,7 @@ def create_pipeline(model, ti_embeddings_dir, dtype):
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
     scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
 
-    load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir)
+    load_text_embeddings(tokenizer, text_encoder, Path(ti_embeddings_dir))
 
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
diff --git a/textual_inversion.py b/textual_inversion.py
index fd4a313..6d8fd77 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -22,6 +22,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
+from common import load_text_embeddings, load_text_embedding
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from pipelines.util import set_use_memory_efficient_attention_xformers
 from data.csv import CSVDataModule
@@ -104,6 +105,12 @@ def parse_args():
         default="output/text-inversion",
         help="The output directory where the model predictions and checkpoints will be written.",
     )
+    parser.add_argument(
+        "--embeddings_dir",
+        type=str,
+        default=None,
+        help="The embeddings directory where Textual Inversion embeddings are stored.",
+    )
     parser.add_argument(
         "--seed",
         type=int,
@@ -551,6 +558,9 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
+    if args.embeddings_dir is not None:
+        load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir))
+
     # Convert the initializer_token, placeholder_token to ids
     initializer_token_ids = torch.stack([
         torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
@@ -562,10 +572,6 @@ def main():
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
-    print(f"Token ID mappings:")
-    for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-        print(f"- {token_id} {token}")
-
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
 
@@ -576,14 +582,7 @@ def main():
         resumepath = Path(args.resume_from).joinpath("checkpoints")
 
         for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-            embedding_file = resumepath.joinpath(f"{token}_{args.global_step}_end.bin")
-            embedding_data = torch.load(embedding_file, map_location="cpu")
-
-            emb = next(iter(embedding_data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-
-            token_embeds[token_id] = emb
+            load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin"))
 
     original_token_embeds = token_embeds.clone().to(accelerator.device)
--
cgit v1.2.3-54-g00ecf