From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 21 Jun 2023 13:28:49 +0200 Subject: Update --- infer.py | 124 ++++++++++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 83 insertions(+), 41 deletions(-) (limited to 'infer.py') diff --git a/infer.py b/infer.py index 7346de9..3b3b595 100644 --- a/infer.py +++ b/infer.py @@ -24,7 +24,7 @@ from diffusers import ( KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler, DEISMultistepScheduler, - UniPCMultistepScheduler + UniPCMultistepScheduler, ) from peft import LoraConfig, LoraModel, set_peft_model_state_dict from safetensors.torch import load_file @@ -61,7 +61,7 @@ default_cmds = { "negative_prompt": None, "shuffle": False, "image": None, - "image_noise": .7, + "image_noise": 0.7, "width": 768, "height": 768, "batch_size": 1, @@ -69,7 +69,6 @@ default_cmds = { "steps": 30, "guidance_scale": 7.0, "sag_scale": 0, - "brightness_offset": 0, "seed": None, "config": None, } @@ -85,9 +84,7 @@ def merge_dicts(d1, *args): def create_args_parser(): - parser = argparse.ArgumentParser( - description="Simple example of a training script." - ) + parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( "--model", type=str, @@ -118,9 +115,7 @@ def create_args_parser(): def create_cmd_parser(): - parser = argparse.ArgumentParser( - description="Simple example of a training script." - ) + parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( "--project", type=str, @@ -130,13 +125,34 @@ def create_cmd_parser(): parser.add_argument( "--scheduler", type=str, - choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "deis", "unipc"], + choices=[ + "plms", + "ddim", + "klms", + "dpmsm", + "dpmss", + "euler_a", + "kdpm2", + "kdpm2_a", + "deis", + "unipc", + ], ) parser.add_argument( "--subscheduler", type=str, default=None, - choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "deis"], + choices=[ + "plms", + "ddim", + "klms", + "dpmsm", + "dpmss", + "euler_a", + "kdpm2", + "kdpm2_a", + "deis", + ], ) parser.add_argument( "--template", @@ -192,10 +208,6 @@ def create_cmd_parser(): "--sag_scale", type=float, ) - parser.add_argument( - "--brightness_offset", - type=float, - ) parser.add_argument( "--seed", type=int, @@ -214,7 +226,9 @@ def run_parser(parser, defaults, input=None): if args.config is not None: conf_args = load_config(args.config) - conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[0] + conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[ + 0 + ] res = defaults.copy() for dict in [vars(conf_args), vars(args)]: @@ -234,10 +248,12 @@ def load_embeddings_dir(pipeline, embeddings_dir): added_tokens, added_ids = load_embeddings_from_dir( pipeline.tokenizer, pipeline.text_encoder.text_model.embeddings, - Path(embeddings_dir) + Path(embeddings_dir), ) pipeline.text_encoder.text_model.embeddings.persist() - print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") + print( + f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" + ) def load_lora(pipeline, path): @@ -255,9 +271,13 @@ def load_lora(pipeline, path): return lora_checkpoint_sd = load_file(path / tensor_files[0]) - unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k} + unet_lora_ds = { + k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k + } text_encoder_lora_ds = { - k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k + k.replace("text_encoder_", ""): v + for k, v in lora_checkpoint_sd.items() + if "text_encoder_" in k } ti_lora_ds = { k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k @@ -282,7 +302,9 @@ def load_lora(pipeline, path): token_embeddings=token_embeddings, ) pipeline.text_encoder.text_model.embeddings.persist() - print(f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}") + print( + f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}" + ) return @@ -315,17 +337,25 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None) solver_p=create_scheduler(config, subscheduler), ) else: - raise ValueError(f"Unknown scheduler \"{scheduler}\"") + raise ValueError(f'Unknown scheduler "{scheduler}"') def create_pipeline(model, dtype): print("Loading Stable Diffusion pipeline...") - tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) - text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) - vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) - unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) - scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) + tokenizer = MultiCLIPTokenizer.from_pretrained( + model, subfolder="tokenizer", torch_dtype=dtype + ) + text_encoder = CLIPTextModel.from_pretrained( + model, subfolder="text_encoder", torch_dtype=dtype + ) + vae = AutoencoderKL.from_pretrained(model, subfolder="vae", torch_dtype=dtype) + unet = UNet2DConditionModel.from_pretrained( + model, subfolder="unet", torch_dtype=dtype + ) + scheduler = DDIMScheduler.from_pretrained( + model, subfolder="scheduler", torch_dtype=dtype + ) patch_managed_embeddings(text_encoder) @@ -347,7 +377,9 @@ def create_pipeline(model, dtype): def shuffle_prompts(prompts: list[str]) -> list[str]: - return [keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts] + return [ + keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts + ] @torch.inference_mode() @@ -386,12 +418,13 @@ def generate(output_dir: Path, pipeline, args): else: init_image = None - pipeline.scheduler = create_scheduler(pipeline.scheduler.config, args.scheduler, args.subscheduler) + pipeline.scheduler = create_scheduler( + pipeline.scheduler.config, args.scheduler, args.subscheduler + ) for i in range(args.batch_num): pipeline.set_progress_bar_config( - desc=f"Batch {i + 1} of {args.batch_num}", - dynamic_ncols=True + desc=f"Batch {i + 1} of {args.batch_num}", dynamic_ncols=True ) seed = args.seed + i @@ -409,7 +442,6 @@ def generate(output_dir: Path, pipeline, args): generator=generator, image=init_image, strength=args.image_noise, - brightness_offset=args.brightness_offset, ).images for j, image in enumerate(images): @@ -418,7 +450,7 @@ def generate(output_dir: Path, pipeline, args): image.save(dir / f"{basename}.png") image.save(dir / f"{basename}.jpg", quality=85) - with open(dir / f"{basename}.txt", 'w') as f: + with open(dir / f"{basename}.txt", "w") as f: f.write(prompt[j % len(args.prompt)]) if torch.cuda.is_available(): @@ -426,10 +458,12 @@ def generate(output_dir: Path, pipeline, args): class CmdParse(cmd.Cmd): - prompt = 'dream> ' + prompt = "dream> " commands = [] - def __init__(self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser): + def __init__( + self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser + ): super().__init__() self.output_dir = output_dir @@ -447,10 +481,10 @@ class CmdParse(cmd.Cmd): print(str(e)) return - if elements[0] == 'q': + if elements[0] == "q": return True - if elements[0] == 'reload_embeddings': + if elements[0] == "reload_embeddings": load_embeddings_dir(self.pipeline, self.ti_embeddings_dir) return @@ -458,7 +492,7 @@ class CmdParse(cmd.Cmd): args = run_parser(self.parser, default_cmds, elements) if len(args.prompt) == 0: - print('Try again with a prompt!') + print("Try again with a prompt!") return except SystemExit: traceback.print_exc() @@ -471,7 +505,7 @@ class CmdParse(cmd.Cmd): try: generate(self.output_dir, self.pipeline, args) except KeyboardInterrupt: - print('Generation cancelled.') + print("Generation cancelled.") except Exception as e: traceback.print_exc() return @@ -487,7 +521,9 @@ def main(): args = run_parser(args_parser, default_args) output_dir = Path(args.output_dir) - dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] + dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[ + args.precision + ] pipeline = create_pipeline(args.model, dtype) @@ -496,7 +532,13 @@ def main(): # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) cmd_parser = create_cmd_parser() - cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser) + cmd_prompt = CmdParse( + output_dir, + args.ti_embeddings_dir, + args.lora_embeddings_dir, + pipeline, + cmd_parser, + ) cmd_prompt.cmdloop() -- cgit v1.2.3-54-g00ecf