diff options
| author | Volpeon <git@volpeon.ink> | 2023-02-16 09:16:05 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-02-16 09:16:05 +0100 |
| commit | d673760fc671d665aadae3b032f8e99f21ab986d (patch) | |
| tree | 7c14a998742b19ddecac6ee25a669892b41c305e /infer.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-d673760fc671d665aadae3b032f8e99f21ab986d.tar.gz textual-inversion-diff-d673760fc671d665aadae3b032f8e99f21ab986d.tar.bz2 textual-inversion-diff-d673760fc671d665aadae3b032f8e99f21ab986d.zip | |
Integrated WIP UniPC scheduler
Diffstat (limited to 'infer.py')
| -rw-r--r-- | infer.py | 15 |
1 file changed, 12 insertions, 3 deletions
| @@ -29,6 +29,7 @@ from data.keywords import prompt_to_keywords, keywords_to_prompt | |||
| 29 | from models.clip.embeddings import patch_managed_embeddings | 29 | from models.clip.embeddings import patch_managed_embeddings |
| 30 | from models.clip.tokenizer import MultiCLIPTokenizer | 30 | from models.clip.tokenizer import MultiCLIPTokenizer |
| 31 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 31 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 32 | from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler | ||
| 32 | from util import load_config, load_embeddings_from_dir | 33 | from util import load_config, load_embeddings_from_dir |
| 33 | 34 | ||
| 34 | 35 | ||
| @@ -61,6 +62,7 @@ default_cmds = { | |||
| 61 | "batch_num": 1, | 62 | "batch_num": 1, |
| 62 | "steps": 30, | 63 | "steps": 30, |
| 63 | "guidance_scale": 7.0, | 64 | "guidance_scale": 7.0, |
| 65 | "sag_scale": 0.75, | ||
| 64 | "lora_scale": 0.5, | 66 | "lora_scale": 0.5, |
| 65 | "seed": None, | 67 | "seed": None, |
| 66 | "config": None, | 68 | "config": None, |
| @@ -122,7 +124,7 @@ def create_cmd_parser(): | |||
| 122 | parser.add_argument( | 124 | parser.add_argument( |
| 123 | "--scheduler", | 125 | "--scheduler", |
| 124 | type=str, | 126 | type=str, |
| 125 | choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"], | 127 | choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "unipc"], |
| 126 | ) | 128 | ) |
| 127 | parser.add_argument( | 129 | parser.add_argument( |
| 128 | "--template", | 130 | "--template", |
| @@ -175,6 +177,10 @@ def create_cmd_parser(): | |||
| 175 | type=float, | 177 | type=float, |
| 176 | ) | 178 | ) |
| 177 | parser.add_argument( | 179 | parser.add_argument( |
| 180 | "--sag_scale", | ||
| 181 | type=float, | ||
| 182 | ) | ||
| 183 | parser.add_argument( | ||
| 178 | "--lora_scale", | 184 | "--lora_scale", |
| 179 | type=float, | 185 | type=float, |
| 180 | ) | 186 | ) |
| @@ -304,6 +310,8 @@ def generate(output_dir: Path, pipeline, args): | |||
| 304 | pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config) | 310 | pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config) |
| 305 | elif args.scheduler == "kdpm2_a": | 311 | elif args.scheduler == "kdpm2_a": |
| 306 | pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config) | 312 | pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config) |
| 313 | elif args.scheduler == "unipc": | ||
| 314 | pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) | ||
| 307 | 315 | ||
| 308 | for i in range(args.batch_num): | 316 | for i in range(args.batch_num): |
| 309 | pipeline.set_progress_bar_config( | 317 | pipeline.set_progress_bar_config( |
| @@ -322,10 +330,11 @@ def generate(output_dir: Path, pipeline, args): | |||
| 322 | num_images_per_prompt=args.batch_size, | 330 | num_images_per_prompt=args.batch_size, |
| 323 | num_inference_steps=args.steps, | 331 | num_inference_steps=args.steps, |
| 324 | guidance_scale=args.guidance_scale, | 332 | guidance_scale=args.guidance_scale, |
| 333 | sag_scale=args.sag_scale, | ||
| 325 | generator=generator, | 334 | generator=generator, |
| 326 | image=init_image, | 335 | image=init_image, |
| 327 | strength=args.image_noise, | 336 | strength=args.image_noise, |
| 328 | cross_attention_kwargs={"scale": args.lora_scale}, | 337 | # cross_attention_kwargs={"scale": args.lora_scale}, |
| 329 | ).images | 338 | ).images |
| 330 | 339 | ||
| 331 | for j, image in enumerate(images): | 340 | for j, image in enumerate(images): |
| @@ -408,7 +417,7 @@ def main(): | |||
| 408 | pipeline = create_pipeline(args.model, dtype) | 417 | pipeline = create_pipeline(args.model, dtype) |
| 409 | 418 | ||
| 410 | load_embeddings(pipeline, args.ti_embeddings_dir) | 419 | load_embeddings(pipeline, args.ti_embeddings_dir) |
| 411 | pipeline.unet.load_attn_procs(args.lora_embeddings_dir) | 420 | # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) |
| 412 | 421 | ||
| 413 | cmd_parser = create_cmd_parser() | 422 | cmd_parser = create_cmd_parser() |
| 414 | cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser) | 423 | cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser) |
