| author | Volpeon <git@volpeon.ink> | 2022-10-02 12:56:58 +0200 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-02 12:56:58 +0200 | 
| commit | 49de8142f523aef3f6adfd0c33a9a160aa7400c0 (patch) | |
| tree | 3638e8ca449bc18acf947ebc0cbc2ee4ecf18a61 /infer.py | |
| parent | Fix seed, better progress bar, fix euler_a for batch size > 1 (diff) | |
WIP: img2img
Diffstat (limited to 'infer.py')
| -rw-r--r-- | infer.py | 90 | 
1 file changed, 67 insertions, 23 deletions
```diff
@@ -8,13 +8,47 @@ from pathlib import Path
 from torch import autocast
 import torch
 import json
+from PIL import Image
 from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
 from slugify import slugify
-from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from schedulers.scheduling_euler_a import EulerAScheduler
 
 
+default_args = {
+    "model": None,
+    "scheduler": "euler_a",
+    "output_dir": "output/inference",
+    "config": None,
+}
+
+
+default_cmds = {
+    "prompt": None,
+    "negative_prompt": None,
+    "image": None,
+    "image_strength": .7,
+    "width": 512,
+    "height": 512,
+    "batch_size": 1,
+    "batch_num": 1,
+    "steps": 50,
+    "guidance_scale": 7.0,
+    "seed": None,
+    "config": None,
+}
+
+
+def merge_dicts(d1, *args):
+    d1 = d1.copy()
+
+    for d in args:
+        d1.update({k: v for (k, v) in d.items() if v is not None})
+
+    return d1
+
+
 def create_args_parser():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
```
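The new `merge_dicts` helper added above has no call site in this diff, but it encodes the same rule the reworked option handling further down relies on: later dicts are layered over the first one, and `None` values are skipped so an unset option never overwrites a set one. A minimal illustration (not part of the commit):

```python
def merge_dicts(d1, *args):
    d1 = d1.copy()

    for d in args:
        # only copy keys whose value is actually set
        d1.update({k: v for (k, v) in d.items() if v is not None})

    return d1


# Later dicts override earlier ones, but None never wins over a concrete value.
print(merge_dicts({"steps": 50, "seed": None}, {"steps": 30, "seed": None}))
# -> {'steps': 30, 'seed': None}
```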
```diff
@@ -22,23 +56,19 @@ def create_args_parser():
     parser.add_argument(
         "--model",
         type=str,
-        default=None,
     )
     parser.add_argument(
         "--scheduler",
         type=str,
         choices=["plms", "ddim", "klms", "euler_a"],
-        default="euler_a",
     )
     parser.add_argument(
         "--output_dir",
         type=str,
-        default="output/inference",
     )
     parser.add_argument(
         "--config",
         type=str,
-        default=None,
     )
 
     return parser
@@ -51,66 +81,69 @@ def create_cmd_parser():
     parser.add_argument(
         "--prompt",
         type=str,
-        default=None,
     )
     parser.add_argument(
         "--negative_prompt",
         type=str,
-        default=None,
+    )
+    parser.add_argument(
+        "--image",
+        type=str,
+    )
+    parser.add_argument(
+        "--image_strength",
+        type=float,
     )
     parser.add_argument(
         "--width",
         type=int,
-        default=512,
     )
     parser.add_argument(
         "--height",
         type=int,
-        default=512,
     )
     parser.add_argument(
         "--batch_size",
         type=int,
-        default=1,
     )
     parser.add_argument(
         "--batch_num",
         type=int,
-        default=1,
     )
     parser.add_argument(
         "--steps",
         type=int,
-        default=70,
     )
     parser.add_argument(
         "--guidance_scale",
-        type=int,
-        default=7,
+        type=float,
     )
     parser.add_argument(
         "--seed",
         type=int,
-        default=None,
     )
     parser.add_argument(
         "--config",
         type=str,
-        default=None,
     )
 
     return parser
 
 
-def run_parser(parser, input=None):
+def run_parser(parser, defaults, input=None):
     args = parser.parse_known_args(input)[0]
+    conf_args = argparse.Namespace()
 
     if args.config is not None:
         with open(args.config, 'rt') as f:
-            args = parser.parse_known_args(
+            conf_args = parser.parse_known_args(
                 namespace=argparse.Namespace(**json.load(f)["args"]))[0]
 
-    return args
+    res = defaults.copy()
+    for dict in [vars(conf_args), vars(args)]:
+        res.update({k: v for (k, v) in dict.items() if v is not None})
+
+    return argparse.Namespace(**res)
 
 
 def save_args(basepath, args, extra={}):
```
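With the `default=` values removed from the argparse definitions above, unset flags parse to `None`, and `run_parser` now resolves every option in three layers: the hard-coded defaults dict, then values from the JSON file passed via `--config`, then flags given on the command line. A self-contained sketch of that precedence; the `resolve` helper and the tiny parser here are simplified stand-ins for `run_parser` and the actual `infer.py` parsers, not code from the commit:

```python
import argparse
import json
import tempfile


def resolve(parser, defaults, argv):
    """Sketch of run_parser's precedence: defaults < --config JSON < CLI flags."""
    args = parser.parse_known_args(argv)[0]
    conf_args = argparse.Namespace()

    if args.config is not None:
        # Pre-populate a namespace with the JSON values, then let CLI flags refine it.
        with open(args.config, "rt") as f:
            conf_args = parser.parse_known_args(
                argv, namespace=argparse.Namespace(**json.load(f)["args"]))[0]

    res = defaults.copy()
    for source in (vars(conf_args), vars(args)):
        res.update({k: v for k, v in source.items() if v is not None})
    return argparse.Namespace(**res)


parser = argparse.ArgumentParser()
parser.add_argument("--steps", type=int)   # no default=..., so an unset flag stays None
parser.add_argument("--config", type=str)

defaults = {"steps": 50, "config": None}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"args": {"steps": 30}}, f)

print(resolve(parser, defaults, []).steps)                                     # 50: defaults
print(resolve(parser, defaults, ["--config", f.name]).steps)                   # 30: config file
print(resolve(parser, defaults, ["--config", f.name, "--steps", "20"]).steps)  # 20: CLI wins
```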
```diff
@@ -146,7 +179,7 @@ def create_pipeline(model, scheduler, dtype=torch.bfloat16):
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
         )
 
-    pipeline = CLIPGuidedStableDiffusion(
+    pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
         vae=vae,
         unet=unet,
@@ -154,7 +187,7 @@ def create_pipeline(model, scheduler, dtype=torch.bfloat16):
         scheduler=scheduler,
         feature_extractor=feature_extractor
     )
-    pipeline.enable_attention_slicing()
+    # pipeline.enable_attention_slicing()
    pipeline.to("cuda")
 
     print("Pipeline loaded.")
@@ -171,6 +204,13 @@ def generate(output_dir, pipeline, args):
 
     save_args(output_dir, args)
 
+    if args.image:
+        init_image = Image.open(args.image)
+        if not init_image.mode == "RGB":
+            init_image = init_image.convert("RGB")
+    else:
+        init_image = None
+
     with autocast("cuda"):
         for i in range(args.batch_num):
             pipeline.set_progress_bar_config(desc=f"Batch {i + 1} of {args.batch_num}")
@@ -184,11 +224,15 @@ def generate(output_dir, pipeline, args):
                 num_inference_steps=args.steps,
                 guidance_scale=args.guidance_scale,
                 generator=generator,
+                latents=init_image,
             ).images
 
             for j, image in enumerate(images):
                 image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg"))
 
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
 
 class CmdParse(cmd.Cmd):
     prompt = 'dream> '
```
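The `generate()` changes load `--image` as an RGB PIL image and, for now, simply pass it to the pipeline through the `latents` keyword; how the WIP `VlpnStableDiffusion` pipeline turns that image into starting latents, and how `--image_strength` will be applied, is not part of this diff. For orientation only, a typical Stable Diffusion img2img preparation looks roughly like the sketch below; the VAE encoding, the 0.18215 scaling factor, and the strength-based noising are common diffusers conventions, not code from this commit:

```python
import numpy as np
import torch


def image_to_latents(init_image, vae, generator=None):
    """Illustrative only: encode a PIL RGB image into Stable Diffusion latents."""
    pixels = np.array(init_image).astype(np.float32) / 127.5 - 1.0    # HWC in [-1, 1]
    pixels = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)   # -> 1x3xHxW
    latents = vae.encode(pixels).latent_dist.sample(generator=generator)
    return 0.18215 * latents                                          # SD latent scaling factor


def noised_start(latents, scheduler, strength, num_inference_steps, generator=None):
    """Illustrative only: noise the encoded image to the point in the schedule chosen
    by strength (0 < strength <= 1); denoising then runs only the remaining steps.
    Exact add_noise semantics vary slightly between scheduler implementations."""
    scheduler.set_timesteps(num_inference_steps)
    init_timestep = int(num_inference_steps * strength)
    t_start = scheduler.timesteps[num_inference_steps - init_timestep]
    noise = torch.randn(latents.shape, generator=generator)
    return scheduler.add_noise(latents, noise, t_start)
```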
```diff
@@ -213,7 +257,7 @@ class CmdParse(cmd.Cmd):
             return True
 
         try:
-            args = run_parser(self.parser, elements)
+            args = run_parser(self.parser, default_cmds, elements)
         except SystemExit:
             self.parser.print_help()
 
@@ -233,7 +277,7 @@ def main():
     logging.basicConfig(stream=sys.stdout, level=logging.WARN)
 
     args_parser = create_args_parser()
-    args = run_parser(args_parser)
+    args = run_parser(args_parser, default_args)
     output_dir = Path(args.output_dir)
 
     pipeline = create_pipeline(args.model, args.scheduler)
```
