diff options
Diffstat (limited to 'infer.py')
| -rw-r--r-- | infer.py | 22 |
1 files changed, 19 insertions, 3 deletions
| @@ -35,6 +35,7 @@ from models.clip.embeddings import patch_managed_embeddings | |||
| 35 | from models.clip.tokenizer import MultiCLIPTokenizer | 35 | from models.clip.tokenizer import MultiCLIPTokenizer |
| 36 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 36 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 37 | from util.files import load_config, load_embeddings_from_dir | 37 | from util.files import load_config, load_embeddings_from_dir |
| 38 | from util.ti import load_embeddings | ||
| 38 | 39 | ||
| 39 | 40 | ||
| 40 | torch.backends.cuda.matmul.allow_tf32 = True | 41 | torch.backends.cuda.matmul.allow_tf32 = True |
| @@ -229,7 +230,7 @@ def save_args(basepath, args, extra={}): | |||
| 229 | json.dump(info, f, indent=4) | 230 | json.dump(info, f, indent=4) |
| 230 | 231 | ||
| 231 | 232 | ||
| 232 | def load_embeddings(pipeline, embeddings_dir): | 233 | def load_embeddings_dir(pipeline, embeddings_dir): |
| 233 | added_tokens, added_ids = load_embeddings_from_dir( | 234 | added_tokens, added_ids = load_embeddings_from_dir( |
| 234 | pipeline.tokenizer, | 235 | pipeline.tokenizer, |
| 235 | pipeline.text_encoder.text_model.embeddings, | 236 | pipeline.text_encoder.text_model.embeddings, |
| @@ -258,6 +259,9 @@ def load_lora(pipeline, path): | |||
| 258 | text_encoder_lora_ds = { | 259 | text_encoder_lora_ds = { |
| 259 | k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k | 260 | k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k |
| 260 | } | 261 | } |
| 262 | ti_lora_ds = { | ||
| 263 | k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k | ||
| 264 | } | ||
| 261 | 265 | ||
| 262 | unet_config = LoraConfig(**lora_config["peft_config"]) | 266 | unet_config = LoraConfig(**lora_config["peft_config"]) |
| 263 | pipeline.unet = LoraModel(unet_config, pipeline.unet) | 267 | pipeline.unet = LoraModel(unet_config, pipeline.unet) |
| @@ -268,6 +272,18 @@ def load_lora(pipeline, path): | |||
| 268 | pipeline.text_encoder = LoraModel(text_encoder_config, pipeline.text_encoder) | 272 | pipeline.text_encoder = LoraModel(text_encoder_config, pipeline.text_encoder) |
| 269 | set_peft_model_state_dict(pipeline.text_encoder, text_encoder_lora_ds) | 273 | set_peft_model_state_dict(pipeline.text_encoder, text_encoder_lora_ds) |
| 270 | 274 | ||
| 275 | tokens = [k for k, _ in ti_lora_ds] | ||
| 276 | token_embeddings = [v for _, v in ti_lora_ds] | ||
| 277 | |||
| 278 | added_tokens, added_ids = load_embeddings( | ||
| 279 | tokenizer=pipeline.tokenizer, | ||
| 280 | embeddings=pipeline.text_encoder.text_model.embeddings, | ||
| 281 | tokens=tokens, | ||
| 282 | token_embeddings=token_embeddings, | ||
| 283 | ) | ||
| 284 | pipeline.text_encoder.text_model.embeddings.persist() | ||
| 285 | print(f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}") | ||
| 286 | |||
| 271 | return | 287 | return |
| 272 | 288 | ||
| 273 | 289 | ||
| @@ -435,7 +451,7 @@ class CmdParse(cmd.Cmd): | |||
| 435 | return True | 451 | return True |
| 436 | 452 | ||
| 437 | if elements[0] == 'reload_embeddings': | 453 | if elements[0] == 'reload_embeddings': |
| 438 | load_embeddings(self.pipeline, self.ti_embeddings_dir) | 454 | load_embeddings_dir(self.pipeline, self.ti_embeddings_dir) |
| 439 | return | 455 | return |
| 440 | 456 | ||
| 441 | try: | 457 | try: |
| @@ -475,7 +491,7 @@ def main(): | |||
| 475 | 491 | ||
| 476 | pipeline = create_pipeline(args.model, dtype) | 492 | pipeline = create_pipeline(args.model, dtype) |
| 477 | 493 | ||
| 478 | load_embeddings(pipeline, args.ti_embeddings_dir) | 494 | load_embeddings_dir(pipeline, args.ti_embeddings_dir) |
| 479 | load_lora(pipeline, args.lora_embedding) | 495 | load_lora(pipeline, args.lora_embedding) |
| 480 | # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) | 496 | # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) |
| 481 | 497 | ||
