diff options
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 22 |
1 files changed, 19 insertions, 3 deletions
@@ -35,6 +35,7 @@ from models.clip.embeddings import patch_managed_embeddings | |||
35 | from models.clip.tokenizer import MultiCLIPTokenizer | 35 | from models.clip.tokenizer import MultiCLIPTokenizer |
36 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 36 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
37 | from util.files import load_config, load_embeddings_from_dir | 37 | from util.files import load_config, load_embeddings_from_dir |
38 | from util.ti import load_embeddings | ||
38 | 39 | ||
39 | 40 | ||
40 | torch.backends.cuda.matmul.allow_tf32 = True | 41 | torch.backends.cuda.matmul.allow_tf32 = True |
@@ -229,7 +230,7 @@ def save_args(basepath, args, extra={}): | |||
229 | json.dump(info, f, indent=4) | 230 | json.dump(info, f, indent=4) |
230 | 231 | ||
231 | 232 | ||
232 | def load_embeddings(pipeline, embeddings_dir): | 233 | def load_embeddings_dir(pipeline, embeddings_dir): |
233 | added_tokens, added_ids = load_embeddings_from_dir( | 234 | added_tokens, added_ids = load_embeddings_from_dir( |
234 | pipeline.tokenizer, | 235 | pipeline.tokenizer, |
235 | pipeline.text_encoder.text_model.embeddings, | 236 | pipeline.text_encoder.text_model.embeddings, |
@@ -258,6 +259,9 @@ def load_lora(pipeline, path): | |||
258 | text_encoder_lora_ds = { | 259 | text_encoder_lora_ds = { |
259 | k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k | 260 | k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k |
260 | } | 261 | } |
262 | ti_lora_ds = { | ||
263 | k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k | ||
264 | } | ||
261 | 265 | ||
262 | unet_config = LoraConfig(**lora_config["peft_config"]) | 266 | unet_config = LoraConfig(**lora_config["peft_config"]) |
263 | pipeline.unet = LoraModel(unet_config, pipeline.unet) | 267 | pipeline.unet = LoraModel(unet_config, pipeline.unet) |
@@ -268,6 +272,18 @@ def load_lora(pipeline, path): | |||
268 | pipeline.text_encoder = LoraModel(text_encoder_config, pipeline.text_encoder) | 272 | pipeline.text_encoder = LoraModel(text_encoder_config, pipeline.text_encoder) |
269 | set_peft_model_state_dict(pipeline.text_encoder, text_encoder_lora_ds) | 273 | set_peft_model_state_dict(pipeline.text_encoder, text_encoder_lora_ds) |
270 | 274 | ||
275 | tokens = [k for k, _ in ti_lora_ds] | ||
276 | token_embeddings = [v for _, v in ti_lora_ds] | ||
277 | |||
278 | added_tokens, added_ids = load_embeddings( | ||
279 | tokenizer=pipeline.tokenizer, | ||
280 | embeddings=pipeline.text_encoder.text_model.embeddings, | ||
281 | tokens=tokens, | ||
282 | token_embeddings=token_embeddings, | ||
283 | ) | ||
284 | pipeline.text_encoder.text_model.embeddings.persist() | ||
285 | print(f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}") | ||
286 | |||
271 | return | 287 | return |
272 | 288 | ||
273 | 289 | ||
@@ -435,7 +451,7 @@ class CmdParse(cmd.Cmd): | |||
435 | return True | 451 | return True |
436 | 452 | ||
437 | if elements[0] == 'reload_embeddings': | 453 | if elements[0] == 'reload_embeddings': |
438 | load_embeddings(self.pipeline, self.ti_embeddings_dir) | 454 | load_embeddings_dir(self.pipeline, self.ti_embeddings_dir) |
439 | return | 455 | return |
440 | 456 | ||
441 | try: | 457 | try: |
@@ -475,7 +491,7 @@ def main(): | |||
475 | 491 | ||
476 | pipeline = create_pipeline(args.model, dtype) | 492 | pipeline = create_pipeline(args.model, dtype) |
477 | 493 | ||
478 | load_embeddings(pipeline, args.ti_embeddings_dir) | 494 | load_embeddings_dir(pipeline, args.ti_embeddings_dir) |
479 | load_lora(pipeline, args.lora_embedding) | 495 | load_lora(pipeline, args.lora_embedding) |
480 | # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) | 496 | # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) |
481 | 497 | ||