Diffstat (limited to 'infer.py')
-rw-r--r--  infer.py | 22 +++++++++++++++++++---
1 file changed, 19 insertions(+), 3 deletions(-)
diff --git a/infer.py b/infer.py
index 4648c0a..7346de9 100644
--- a/infer.py
+++ b/infer.py
@@ -35,6 +35,7 @@ from models.clip.embeddings import patch_managed_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from util.files import load_config, load_embeddings_from_dir
+from util.ti import load_embeddings
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -229,7 +230,7 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def load_embeddings(pipeline, embeddings_dir):
+def load_embeddings_dir(pipeline, embeddings_dir):
     added_tokens, added_ids = load_embeddings_from_dir(
         pipeline.tokenizer,
         pipeline.text_encoder.text_model.embeddings,
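The rename above is presumably forced by the new import: without it, the module-level definition would shadow util.ti.load_embeddings for the rest of the file, and the call added in load_lora() below would resolve to the wrong function. A minimal standalone illustration of the collision, using only names from this diff:

    from util.ti import load_embeddings

    def load_embeddings(pipeline, embeddings_dir):  # shadows the import
        ...

    # Any later call to load_embeddings(...) in this module now resolves to
    # the local definition above, not to util.ti.load_embeddings.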
@@ -258,6 +259,9 @@ def load_lora(pipeline, path):
     text_encoder_lora_ds = {
         k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
     }
+    ti_lora_ds = {
+        k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k
+    }
 
     unet_config = LoraConfig(**lora_config["peft_config"])
     pipeline.unet = LoraModel(unet_config, pipeline.unet)
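The new ti_lora_ds dict relies on a key-naming convention in the saved checkpoint: LoRA tensors carry a "text_encoder_" (or unet) prefix, while textual-inversion vectors are stored under "ti_" plus the token name. A small sketch of how the filter behaves; the key names and the 768-dim embedding size here are illustrative, not taken from a real checkpoint:

    import torch

    # Hypothetical checkpoint contents mixing LoRA weights and one
    # textual-inversion vector saved under a "ti_" key.
    lora_checkpoint_sd = {
        "text_encoder_layers.0.q_proj.lora_A.weight": torch.randn(4, 768),
        "ti_<my-token>": torch.randn(768),
    }
    ti_lora_ds = {
        k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k
    }
    print(list(ti_lora_ds))  # ['<my-token>']

Note that the filter is a substring test ("ti_" in k), so any key merely containing "ti_" would be swept in as well; a prefix check with k.startswith("ti_") would be stricter.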
@@ -268,6 +272,18 @@ def load_lora(pipeline, path):
     pipeline.text_encoder = LoraModel(text_encoder_config, pipeline.text_encoder)
     set_peft_model_state_dict(pipeline.text_encoder, text_encoder_lora_ds)
 
+    tokens = [k for k, _ in ti_lora_ds]
+    token_embeddings = [v for _, v in ti_lora_ds]
+
+    added_tokens, added_ids = load_embeddings(
+        tokenizer=pipeline.tokenizer,
+        embeddings=pipeline.text_encoder.text_model.embeddings,
+        tokens=tokens,
+        token_embeddings=token_embeddings,
+    )
+    pipeline.text_encoder.text_model.embeddings.persist()
+    print(f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}")
+
     return
 
 
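One caveat in the block above: iterating a dict directly yields its keys only, so tokens = [k for k, _ in ti_lora_ds] tries to unpack each key string into two names and raises ValueError for any token name that is not exactly two characters long. The intent is clearly key/value pairs; an equivalent corrected sketch:

    # Iterating the dict itself yields keys; .items() yields (key, value)
    # pairs. keys() and values() are guaranteed to come out in matching
    # order, so the two lists stay aligned.
    tokens = list(ti_lora_ds.keys())
    token_embeddings = list(ti_lora_ds.values())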
@@ -435,7 +451,7 @@ class CmdParse(cmd.Cmd):
             return True
 
         if elements[0] == 'reload_embeddings':
-            load_embeddings(self.pipeline, self.ti_embeddings_dir)
+            load_embeddings_dir(self.pipeline, self.ti_embeddings_dir)
             return
 
         try:
@@ -475,7 +491,7 @@ def main():
 
     pipeline = create_pipeline(args.model, dtype)
 
-    load_embeddings(pipeline, args.ti_embeddings_dir)
+    load_embeddings_dir(pipeline, args.ti_embeddings_dir)
     load_lora(pipeline, args.lora_embedding)
     # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
 