diff options
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 39 |
1 files changed, 37 insertions, 2 deletions
@@ -26,6 +26,8 @@ from diffusers import ( | |||
26 | DEISMultistepScheduler, | 26 | DEISMultistepScheduler, |
27 | UniPCMultistepScheduler | 27 | UniPCMultistepScheduler |
28 | ) | 28 | ) |
29 | from peft import LoraConfig, LoraModel, set_peft_model_state_dict | ||
30 | from safetensors.torch import load_file | ||
29 | from transformers import CLIPTextModel | 31 | from transformers import CLIPTextModel |
30 | 32 | ||
31 | from data.keywords import str_to_keywords, keywords_to_str | 33 | from data.keywords import str_to_keywords, keywords_to_str |
@@ -43,7 +45,7 @@ default_args = { | |||
43 | "model": "stabilityai/stable-diffusion-2-1", | 45 | "model": "stabilityai/stable-diffusion-2-1", |
44 | "precision": "fp32", | 46 | "precision": "fp32", |
45 | "ti_embeddings_dir": "embeddings_ti", | 47 | "ti_embeddings_dir": "embeddings_ti", |
46 | "lora_embeddings_dir": "embeddings_lora", | 48 | "lora_embedding": None, |
47 | "output_dir": "output/inference", | 49 | "output_dir": "output/inference", |
48 | "config": None, | 50 | "config": None, |
49 | } | 51 | } |
@@ -99,7 +101,7 @@ def create_args_parser(): | |||
99 | type=str, | 101 | type=str, |
100 | ) | 102 | ) |
101 | parser.add_argument( | 103 | parser.add_argument( |
102 | "--lora_embeddings_dir", | 104 | "--lora_embedding", |
103 | type=str, | 105 | type=str, |
104 | ) | 106 | ) |
105 | parser.add_argument( | 107 | parser.add_argument( |
@@ -236,6 +238,38 @@ def load_embeddings(pipeline, embeddings_dir): | |||
236 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 238 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
237 | 239 | ||
238 | 240 | ||
241 | def load_lora(pipeline, path): | ||
242 | if path is None: | ||
243 | return | ||
244 | |||
245 | path = Path(path) | ||
246 | |||
247 | with open(path / "lora_config.json", "r") as f: | ||
248 | lora_config = json.load(f) | ||
249 | |||
250 | tensor_files = list(path.glob("*_end.safetensors")) | ||
251 | |||
252 | if len(tensor_files) == 0: | ||
253 | return | ||
254 | |||
255 | lora_checkpoint_sd = load_file(path / tensor_files[0]) | ||
256 | unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k} | ||
257 | text_encoder_lora_ds = { | ||
258 | k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k | ||
259 | } | ||
260 | |||
261 | unet_config = LoraConfig(**lora_config["peft_config"]) | ||
262 | pipeline.unet = LoraModel(unet_config, pipeline.unet) | ||
263 | set_peft_model_state_dict(pipeline.unet, unet_lora_ds) | ||
264 | |||
265 | if "text_encoder_peft_config" in lora_config: | ||
266 | text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"]) | ||
267 | pipeline.text_encoder = LoraModel(text_encoder_config, pipeline.text_encoder) | ||
268 | set_peft_model_state_dict(pipeline.text_encoder, text_encoder_lora_ds) | ||
269 | |||
270 | return | ||
271 | |||
272 | |||
239 | def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None): | 273 | def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None): |
240 | if scheduler == "plms": | 274 | if scheduler == "plms": |
241 | return PNDMScheduler.from_config(config) | 275 | return PNDMScheduler.from_config(config) |
@@ -441,6 +475,7 @@ def main(): | |||
441 | pipeline = create_pipeline(args.model, dtype) | 475 | pipeline = create_pipeline(args.model, dtype) |
442 | 476 | ||
443 | load_embeddings(pipeline, args.ti_embeddings_dir) | 477 | load_embeddings(pipeline, args.ti_embeddings_dir) |
478 | load_lora(pipeline, args.lora_embedding) | ||
444 | # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) | 479 | # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) |
445 | 480 | ||
446 | cmd_parser = create_cmd_parser() | 481 | cmd_parser = create_cmd_parser() |