diff options
| author | Volpeon <git@volpeon.ink> | 2022-10-01 11:40:14 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-01 11:40:14 +0200 |
| commit | 5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688 (patch) | |
| tree | a3461a4f1a04fba52ec8fde8b7b07095c7422d85 /infer.py | |
| parent | Added custom SD pipeline + euler_a scheduler (diff) | |
| download | textual-inversion-diff-5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688.tar.gz textual-inversion-diff-5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688.tar.bz2 textual-inversion-diff-5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688.zip | |
Made inference script interactive
Diffstat (limited to 'infer.py')
| -rw-r--r-- | infer.py | 154 |
1 file changed, 109 insertions, 45 deletions
| @@ -1,18 +1,21 @@ | |||
| 1 | import argparse | 1 | import argparse |
| 2 | import datetime | 2 | import datetime |
| 3 | import logging | 3 | import logging |
| 4 | import sys | ||
| 5 | import shlex | ||
| 6 | import cmd | ||
| 4 | from pathlib import Path | 7 | from pathlib import Path |
| 5 | from torch import autocast | 8 | from torch import autocast |
| 6 | import torch | 9 | import torch |
| 7 | import json | 10 | import json |
| 8 | from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler | 11 | from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler |
| 9 | from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor | 12 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor |
| 10 | from slugify import slugify | 13 | from slugify import slugify |
| 11 | from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion | 14 | from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion |
| 12 | from schedulers.scheduling_euler_a import EulerAScheduler | 15 | from schedulers.scheduling_euler_a import EulerAScheduler |
| 13 | 16 | ||
| 14 | 17 | ||
| 15 | def parse_args(): | 18 | def create_args_parser(): |
| 16 | parser = argparse.ArgumentParser( | 19 | parser = argparse.ArgumentParser( |
| 17 | description="Simple example of a training script." | 20 | description="Simple example of a training script." |
| 18 | ) | 21 | ) |
| @@ -22,6 +25,30 @@ def parse_args(): | |||
| 22 | default=None, | 25 | default=None, |
| 23 | ) | 26 | ) |
| 24 | parser.add_argument( | 27 | parser.add_argument( |
| 28 | "--scheduler", | ||
| 29 | type=str, | ||
| 30 | choices=["plms", "ddim", "klms", "euler_a"], | ||
| 31 | default="euler_a", | ||
| 32 | ) | ||
| 33 | parser.add_argument( | ||
| 34 | "--output_dir", | ||
| 35 | type=str, | ||
| 36 | default="output/inference", | ||
| 37 | ) | ||
| 38 | parser.add_argument( | ||
| 39 | "--config", | ||
| 40 | type=str, | ||
| 41 | default=None, | ||
| 42 | ) | ||
| 43 | |||
| 44 | return parser | ||
| 45 | |||
| 46 | |||
| 47 | def create_cmd_parser(): | ||
| 48 | parser = argparse.ArgumentParser( | ||
| 49 | description="Simple example of a training script." | ||
| 50 | ) | ||
| 51 | parser.add_argument( | ||
| 25 | "--prompt", | 52 | "--prompt", |
| 26 | type=str, | 53 | type=str, |
| 27 | default=None, | 54 | default=None, |
| @@ -49,28 +76,17 @@ def parse_args(): | |||
| 49 | parser.add_argument( | 76 | parser.add_argument( |
| 50 | "--batch_num", | 77 | "--batch_num", |
| 51 | type=int, | 78 | type=int, |
| 52 | default=50, | 79 | default=1, |
| 53 | ) | 80 | ) |
| 54 | parser.add_argument( | 81 | parser.add_argument( |
| 55 | "--steps", | 82 | "--steps", |
| 56 | type=int, | 83 | type=int, |
| 57 | default=120, | 84 | default=70, |
| 58 | ) | ||
| 59 | parser.add_argument( | ||
| 60 | "--scheduler", | ||
| 61 | type=str, | ||
| 62 | choices=["plms", "ddim", "klms", "euler_a"], | ||
| 63 | default="euler_a", | ||
| 64 | ) | 85 | ) |
| 65 | parser.add_argument( | 86 | parser.add_argument( |
| 66 | "--guidance_scale", | 87 | "--guidance_scale", |
| 67 | type=int, | 88 | type=int, |
| 68 | default=7.5, | 89 | default=7, |
| 69 | ) | ||
| 70 | parser.add_argument( | ||
| 71 | "--clip_guidance_scale", | ||
| 72 | type=int, | ||
| 73 | default=100, | ||
| 74 | ) | 90 | ) |
| 75 | parser.add_argument( | 91 | parser.add_argument( |
| 76 | "--seed", | 92 | "--seed", |
| @@ -78,21 +94,21 @@ def parse_args(): | |||
| 78 | default=torch.random.seed(), | 94 | default=torch.random.seed(), |
| 79 | ) | 95 | ) |
| 80 | parser.add_argument( | 96 | parser.add_argument( |
| 81 | "--output_dir", | ||
| 82 | type=str, | ||
| 83 | default="output/inference", | ||
| 84 | ) | ||
| 85 | parser.add_argument( | ||
| 86 | "--config", | 97 | "--config", |
| 87 | type=str, | 98 | type=str, |
| 88 | default=None, | 99 | default=None, |
| 89 | ) | 100 | ) |
| 90 | 101 | ||
| 91 | args = parser.parse_args() | 102 | return parser |
| 103 | |||
| 104 | |||
| 105 | def run_parser(parser, input=None): | ||
| 106 | args = parser.parse_known_args(input)[0] | ||
| 107 | |||
| 92 | if args.config is not None: | 108 | if args.config is not None: |
| 93 | with open(args.config, 'rt') as f: | 109 | with open(args.config, 'rt') as f: |
| 94 | args = parser.parse_args( | 110 | args = parser.parse_known_args( |
| 95 | namespace=argparse.Namespace(**json.load(f)["args"])) | 111 | namespace=argparse.Namespace(**json.load(f)["args"]))[0] |
| 96 | 112 | ||
| 97 | return args | 113 | return args |
| 98 | 114 | ||
| @@ -104,24 +120,24 @@ def save_args(basepath, args, extra={}): | |||
| 104 | json.dump(info, f, indent=4) | 120 | json.dump(info, f, indent=4) |
| 105 | 121 | ||
| 106 | 122 | ||
| 107 | def gen(args, output_dir): | 123 | def create_pipeline(model, scheduler, dtype=torch.bfloat16): |
| 108 | tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16) | 124 | print("Loading Stable Diffusion pipeline...") |
| 109 | text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16) | ||
| 110 | clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16) | ||
| 111 | vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16) | ||
| 112 | unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16) | ||
| 113 | feature_extractor = CLIPFeatureExtractor.from_pretrained( | ||
| 114 | "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16) | ||
| 115 | 125 | ||
| 116 | if args.scheduler == "plms": | 126 | tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype) |
| 127 | text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype) | ||
| 128 | vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype) | ||
| 129 | unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype) | ||
| 130 | feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=dtype) | ||
| 131 | |||
| 132 | if scheduler == "plms": | ||
| 117 | scheduler = PNDMScheduler( | 133 | scheduler = PNDMScheduler( |
| 118 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True | 134 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True |
| 119 | ) | 135 | ) |
| 120 | elif args.scheduler == "klms": | 136 | elif scheduler == "klms": |
| 121 | scheduler = LMSDiscreteScheduler( | 137 | scheduler = LMSDiscreteScheduler( |
| 122 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | 138 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" |
| 123 | ) | 139 | ) |
| 124 | elif args.scheduler == "ddim": | 140 | elif scheduler == "ddim": |
| 125 | scheduler = DDIMScheduler( | 141 | scheduler = DDIMScheduler( |
| 126 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False | 142 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False |
| 127 | ) | 143 | ) |
| @@ -135,13 +151,24 @@ def gen(args, output_dir): | |||
| 135 | vae=vae, | 151 | vae=vae, |
| 136 | unet=unet, | 152 | unet=unet, |
| 137 | tokenizer=tokenizer, | 153 | tokenizer=tokenizer, |
| 138 | clip_model=clip_model, | ||
| 139 | scheduler=scheduler, | 154 | scheduler=scheduler, |
| 140 | feature_extractor=feature_extractor | 155 | feature_extractor=feature_extractor |
| 141 | ) | 156 | ) |
| 142 | pipeline.enable_attention_slicing() | 157 | pipeline.enable_attention_slicing() |
| 143 | pipeline.to("cuda") | 158 | pipeline.to("cuda") |
| 144 | 159 | ||
| 160 | print("Pipeline loaded.") | ||
| 161 | |||
| 162 | return pipeline | ||
| 163 | |||
| 164 | |||
| 165 | def generate(output_dir, pipeline, args): | ||
| 166 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | ||
| 167 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}") | ||
| 168 | output_dir.mkdir(parents=True, exist_ok=True) | ||
| 169 | |||
| 170 | save_args(output_dir, args) | ||
| 171 | |||
| 145 | with autocast("cuda"): | 172 | with autocast("cuda"): |
| 146 | for i in range(args.batch_num): | 173 | for i in range(args.batch_num): |
| 147 | generator = torch.Generator(device="cuda").manual_seed(args.seed + i) | 174 | generator = torch.Generator(device="cuda").manual_seed(args.seed + i) |
| @@ -152,7 +179,6 @@ def gen(args, output_dir): | |||
| 152 | negative_prompt=args.negative_prompt, | 179 | negative_prompt=args.negative_prompt, |
| 153 | num_inference_steps=args.steps, | 180 | num_inference_steps=args.steps, |
| 154 | guidance_scale=args.guidance_scale, | 181 | guidance_scale=args.guidance_scale, |
| 155 | clip_guidance_scale=args.clip_guidance_scale, | ||
| 156 | generator=generator, | 182 | generator=generator, |
| 157 | ).images | 183 | ).images |
| 158 | 184 | ||
| @@ -160,18 +186,56 @@ def gen(args, output_dir): | |||
| 160 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) | 186 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) |
| 161 | 187 | ||
| 162 | 188 | ||
| 163 | def main(): | 189 | class CmdParse(cmd.Cmd): |
| 164 | args = parse_args() | 190 | prompt = 'dream> ' |
| 191 | commands = [] | ||
| 165 | 192 | ||
| 166 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 193 | def __init__(self, output_dir, pipeline, parser): |
| 167 | output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}") | 194 | super().__init__() |
| 168 | output_dir.mkdir(parents=True, exist_ok=True) | ||
| 169 | 195 | ||
| 170 | save_args(output_dir, args) | 196 | self.output_dir = output_dir |
| 197 | self.pipeline = pipeline | ||
| 198 | self.parser = parser | ||
| 199 | |||
| 200 | def default(self, line): | ||
| 201 | line = line.replace("'", "\\'") | ||
| 202 | |||
| 203 | try: | ||
| 204 | elements = shlex.split(line) | ||
| 205 | except ValueError as e: | ||
| 206 | print(str(e)) | ||
| 207 | |||
| 208 | if elements[0] == 'q': | ||
| 209 | return True | ||
| 210 | |||
| 211 | try: | ||
| 212 | args = run_parser(self.parser, elements) | ||
| 213 | except SystemExit: | ||
| 214 | self.parser.print_help() | ||
| 215 | |||
| 216 | if len(args.prompt) == 0: | ||
| 217 | print('Try again with a prompt!') | ||
| 218 | |||
| 219 | try: | ||
| 220 | generate(self.output_dir, self.pipeline, args) | ||
| 221 | except KeyboardInterrupt: | ||
| 222 | print('Generation cancelled.') | ||
| 223 | |||
| 224 | def do_exit(self, line): | ||
| 225 | return True | ||
| 226 | |||
| 227 | |||
| 228 | def main(): | ||
| 229 | logging.basicConfig(stream=sys.stdout, level=logging.WARN) | ||
| 171 | 230 | ||
| 172 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | 231 | args_parser = create_args_parser() |
| 232 | args = run_parser(args_parser) | ||
| 233 | output_dir = Path(args.output_dir) | ||
| 173 | 234 | ||
| 174 | gen(args, output_dir) | 235 | pipeline = create_pipeline(args.model, args.scheduler) |
| 236 | cmd_parser = create_cmd_parser() | ||
| 237 | cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) | ||
| 238 | cmd_prompt.cmdloop() | ||
| 175 | 239 | ||
| 176 | 240 | ||
| 177 | if __name__ == "__main__": | 241 | if __name__ == "__main__": |
