From 03303d3bddba5a27a123babdf90863e27501e6f8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 13 Dec 2022 23:09:25 +0100
Subject: Unified loading of TI embeddings

---
 infer.py | 34 ++--------------------------------
 1 file changed, 2 insertions(+), 32 deletions(-)

(limited to 'infer.py')

diff --git a/infer.py b/infer.py
index f607041..1fd11e2 100644
--- a/infer.py
+++ b/infer.py
@@ -24,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from common import load_text_embeddings
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -180,37 +181,6 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
-    print(f"Loading Textual Inversion embeddings")
-
-    embeddings_dir = Path(embeddings_dir)
-    embeddings_dir.mkdir(parents=True, exist_ok=True)
-
-    placeholder_tokens = [file.stem for file in embeddings_dir.iterdir() if file.is_file()]
-    tokenizer.add_tokens(placeholder_tokens)
-
-    text_encoder.resize_token_embeddings(len(tokenizer))
-
-    token_embeds = text_encoder.get_input_embeddings().weight.data
-
-    for file in embeddings_dir.iterdir():
-        if file.is_file():
-            placeholder_token = file.stem
-            placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
-
-            data = torch.load(file, map_location="cpu")
-
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-
-            token_embeds[placeholder_token_id] = emb
-
-            print(f"Loaded {placeholder_token}")
-
-
 def create_pipeline(model, ti_embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
@@ -220,7 +190,7 @@ def create_pipeline(model, ti_embeddings_dir, dtype):
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
     scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
 
-    load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir)
+    load_text_embeddings(tokenizer, text_encoder, Path(ti_embeddings_dir))
 
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
-- 
cgit v1.2.3-54-g00ecf