From 19a013ba9efaad53b7fc0eef647671e7143efc2a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 4 Oct 2022 21:16:08 +0200 Subject: Inference: Add support for embeddings --- infer.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/infer.py b/infer.py index 3dc0f32..3487e5a 100644 --- a/infer.py +++ b/infer.py @@ -23,6 +23,7 @@ default_args = { "model": None, "scheduler": "euler_a", "precision": "bf16", + "embeddings_dir": "embeddings", "output_dir": "output/inference", "config": None, } @@ -71,6 +72,10 @@ def create_args_parser(): type=str, choices=["fp32", "fp16", "bf16"], ) + parser.add_argument( + "--embeddings_dir", + type=str, + ) parser.add_argument( "--output_dir", type=str, @@ -162,7 +167,28 @@ def save_args(basepath, args, extra={}): json.dump(info, f, indent=4) -def create_pipeline(model, scheduler, dtype): +def load_embeddings(tokenizer, text_encoder, embeddings_dir): + embeddings_dir = Path(embeddings_dir) + embeddings_dir.mkdir(parents=True, exist_ok=True) + + token_embeds = text_encoder.get_input_embeddings().weight.data + + for file in embeddings_dir.iterdir(): + placeholder_token = file.stem + placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token) + + data = torch.load(file, map_location="cpu") + + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + + token_embeds[placeholder_token_id] = emb + + +def create_pipeline(model, scheduler, embeddings_dir, dtype): print("Loading Stable Diffusion pipeline...") tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype) @@ -170,6 +196,8 @@ def create_pipeline(model, scheduler, dtype): vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype) unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype) + load_embeddings(tokenizer, text_encoder, embeddings_dir) + if scheduler == "plms": scheduler = PNDMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True @@ -290,7 +318,7 @@ def main(): output_dir = Path(args.output_dir) dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] - pipeline = create_pipeline(args.model, args.scheduler, dtype) + pipeline = create_pipeline(args.model, args.scheduler, args.embeddings_dir, dtype) cmd_parser = create_cmd_parser() cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) cmd_prompt.cmdloop() -- cgit v1.2.3-70-g09d2