From 6d46bf79bd7710cea799fbfe27c12d06d12cd53f Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 27 Apr 2023 07:47:59 +0200
Subject: Update

---
 infer.py | 22 +++++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

(limited to 'infer.py')

diff --git a/infer.py b/infer.py
index 4648c0a..7346de9 100644
--- a/infer.py
+++ b/infer.py
@@ -35,6 +35,7 @@ from models.clip.embeddings import patch_managed_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from util.files import load_config, load_embeddings_from_dir
+from util.ti import load_embeddings
 
 torch.backends.cuda.matmul.allow_tf32 = True
 
@@ -229,7 +230,7 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def load_embeddings(pipeline, embeddings_dir):
+def load_embeddings_dir(pipeline, embeddings_dir):
     added_tokens, added_ids = load_embeddings_from_dir(
         pipeline.tokenizer,
         pipeline.text_encoder.text_model.embeddings,
@@ -258,6 +259,9 @@ def load_lora(pipeline, path):
     text_encoder_lora_ds = {
         k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
     }
+    ti_lora_ds = {
+        k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k
+    }
 
     unet_config = LoraConfig(**lora_config["peft_config"])
     pipeline.unet = LoraModel(unet_config, pipeline.unet)
@@ -268,6 +272,18 @@ def load_lora(pipeline, path):
         pipeline.text_encoder = LoraModel(text_encoder_config, pipeline.text_encoder)
         set_peft_model_state_dict(pipeline.text_encoder, text_encoder_lora_ds)
 
+    tokens = [k for k, _ in ti_lora_ds.items()]
+    token_embeddings = [v for _, v in ti_lora_ds.items()]
+
+    added_tokens, added_ids = load_embeddings(
+        tokenizer=pipeline.tokenizer,
+        embeddings=pipeline.text_encoder.text_model.embeddings,
+        tokens=tokens,
+        token_embeddings=token_embeddings,
+    )
+    pipeline.text_encoder.text_model.embeddings.persist()
+    print(f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}")
+
     return
 
 
@@ -435,7 +451,7 @@ class CmdParse(cmd.Cmd):
             return True
 
         if elements[0] == 'reload_embeddings':
-            load_embeddings(self.pipeline, self.ti_embeddings_dir)
+            load_embeddings_dir(self.pipeline, self.ti_embeddings_dir)
             return
 
         try:
@@ -475,7 +491,7 @@ def main():
 
     pipeline = create_pipeline(args.model, dtype)
 
-    load_embeddings(pipeline, args.ti_embeddings_dir)
+    load_embeddings_dir(pipeline, args.ti_embeddings_dir)
     load_lora(pipeline, args.lora_embedding)
     # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
--
cgit v1.2.3-54-g00ecf