Diffstat (limited to 'infer.py')
-rw-r--r--  infer.py  39
1 file changed, 37 insertions(+), 2 deletions(-)
diff --git a/infer.py b/infer.py
index ed86ab1..93848d7 100644
--- a/infer.py
+++ b/infer.py
@@ -26,6 +26,8 @@ from diffusers import (
     DEISMultistepScheduler,
     UniPCMultistepScheduler
 )
+from peft import LoraConfig, LoraModel, set_peft_model_state_dict
+from safetensors.torch import load_file
 from transformers import CLIPTextModel
 
 from data.keywords import str_to_keywords, keywords_to_str
@@ -43,7 +45,7 @@ default_args = {
43 "model": "stabilityai/stable-diffusion-2-1", 45 "model": "stabilityai/stable-diffusion-2-1",
44 "precision": "fp32", 46 "precision": "fp32",
45 "ti_embeddings_dir": "embeddings_ti", 47 "ti_embeddings_dir": "embeddings_ti",
46 "lora_embeddings_dir": "embeddings_lora", 48 "lora_embedding": None,
47 "output_dir": "output/inference", 49 "output_dir": "output/inference",
48 "config": None, 50 "config": None,
49} 51}
@@ -99,7 +101,7 @@ def create_args_parser():
         type=str,
     )
     parser.add_argument(
-        "--lora_embeddings_dir",
+        "--lora_embedding",
         type=str,
     )
     parser.add_argument(
@@ -236,6 +238,38 @@ def load_embeddings(pipeline, embeddings_dir):
     print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
 
+def load_lora(pipeline, path):
+    if path is None:
+        return
+
+    path = Path(path)
+
+    with open(path / "lora_config.json", "r") as f:
+        lora_config = json.load(f)
+
+    tensor_files = list(path.glob("*_end.safetensors"))
+
+    if len(tensor_files) == 0:
+        return
+
+    lora_checkpoint_sd = load_file(tensor_files[0])
+    unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
+    text_encoder_lora_ds = {
+        k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
+    }
+
+    unet_config = LoraConfig(**lora_config["peft_config"])
+    pipeline.unet = LoraModel(unet_config, pipeline.unet)
+    set_peft_model_state_dict(pipeline.unet, unet_lora_ds)
+
+    if "text_encoder_peft_config" in lora_config:
+        text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
+        pipeline.text_encoder = LoraModel(text_encoder_config, pipeline.text_encoder)
+        set_peft_model_state_dict(pipeline.text_encoder, text_encoder_lora_ds)
+
+    return
+
+
 def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None):
     if scheduler == "plms":
         return PNDMScheduler.from_config(config)
@@ -441,6 +475,7 @@ def main():
     pipeline = create_pipeline(args.model, dtype)
 
     load_embeddings(pipeline, args.ti_embeddings_dir)
+    load_lora(pipeline, args.lora_embedding)
     # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
 
     cmd_parser = create_cmd_parser()
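
For context, a minimal usage sketch of the new --lora_embedding option after this change, assuming the checkpoint directory contains the lora_config.json and *_end.safetensors files that load_lora() looks for; the dtype value and the directory name are illustrative, and running from the repository root is assumed so that infer.py is importable:

# Hedged sketch, not part of the commit: wiring a PEFT LoRA checkpoint into the
# pipeline the same way main() now does after this diff.
import torch

from infer import create_pipeline, load_embeddings, load_lora

# "fp32" default precision assumed to map to torch.float32 in create_pipeline
pipeline = create_pipeline("stabilityai/stable-diffusion-2-1", torch.float32)
load_embeddings(pipeline, "embeddings_ti")        # textual-inversion embeddings, as before
load_lora(pipeline, "embeddings_lora/my-run")     # new: dir with lora_config.json + *_end.safetensors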