 infer.py                                                                                                                     | 90
 pipelines/stable_diffusion/vlpn_stable_diffusion.py (renamed from pipelines/stable_diffusion/clip_guided_stable_diffusion.py) | 80
 2 files changed, 131 insertions(+), 39 deletions(-)
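In short: the pipeline class is renamed from CLIPGuidedStableDiffusion to VlpnStableDiffusion and gains a basic img2img path. A PIL image can now be passed through the existing `latents` argument together with a `strength` value; the image is encoded through the VAE and noised to the timestep implied by `strength`. The sketch below shows how the new entry point might be exercised; it leans on create_pipeline() from infer.py, and the model path, image file, prompt, and strength value are placeholders rather than part of this commit.

# Rough usage sketch, not part of the commit. create_pipeline() is the helper
# in infer.py shown below; paths, prompt and strength value are placeholders.
import torch
from PIL import Image

from infer import create_pipeline

pipeline = create_pipeline("path/to/stable-diffusion-model", "euler_a")
init_image = Image.open("init.png").convert("RGB")

with torch.autocast("cuda"):
    images = pipeline(
        prompt="a placeholder prompt",
        latents=init_image,   # a PIL image is now accepted and encoded via the VAE
        strength=0.7,         # 0..1, how far to move away from the init image
        num_inference_steps=50,
        guidance_scale=7.0,
        generator=torch.Generator(device="cuda").manual_seed(1),
    ).images

images[0].save("out.jpg")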
diff --git a/infer.py b/infer.py
index d917239..b440cb6 100644
--- a/infer.py
+++ b/infer.py
@@ -8,13 +8,47 @@ from pathlib import Path
 from torch import autocast
 import torch
 import json
+from PIL import Image
 from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
 from slugify import slugify
-from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from schedulers.scheduling_euler_a import EulerAScheduler
 
 
+default_args = {
+    "model": None,
+    "scheduler": "euler_a",
+    "output_dir": "output/inference",
+    "config": None,
+}
+
+
+default_cmds = {
+    "prompt": None,
+    "negative_prompt": None,
+    "image": None,
+    "image_strength": .7,
+    "width": 512,
+    "height": 512,
+    "batch_size": 1,
+    "batch_num": 1,
+    "steps": 50,
+    "guidance_scale": 7.0,
+    "seed": None,
+    "config": None,
+}
+
+
+def merge_dicts(d1, *args):
+    d1 = d1.copy()
+
+    for d in args:
+        d1.update({k: v for (k, v) in d.items() if v is not None})
+
+    return d1
+
+
 def create_args_parser():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
@@ -22,23 +56,19 @@ def create_args_parser():
     parser.add_argument(
         "--model",
         type=str,
-        default=None,
     )
     parser.add_argument(
         "--scheduler",
         type=str,
         choices=["plms", "ddim", "klms", "euler_a"],
-        default="euler_a",
     )
     parser.add_argument(
         "--output_dir",
         type=str,
-        default="output/inference",
     )
     parser.add_argument(
         "--config",
         type=str,
-        default=None,
     )
 
     return parser
@@ -51,66 +81,69 @@ def create_cmd_parser():
     parser.add_argument(
         "--prompt",
         type=str,
-        default=None,
     )
     parser.add_argument(
         "--negative_prompt",
         type=str,
-        default=None,
+    )
+    parser.add_argument(
+        "--image",
+        type=str,
+    )
+    parser.add_argument(
+        "--image_strength",
+        type=float,
     )
     parser.add_argument(
         "--width",
         type=int,
-        default=512,
     )
     parser.add_argument(
         "--height",
         type=int,
-        default=512,
     )
     parser.add_argument(
         "--batch_size",
         type=int,
-        default=1,
     )
     parser.add_argument(
         "--batch_num",
         type=int,
-        default=1,
     )
     parser.add_argument(
         "--steps",
         type=int,
-        default=70,
     )
     parser.add_argument(
         "--guidance_scale",
-        type=int,
-        default=7,
+        type=float,
     )
     parser.add_argument(
         "--seed",
         type=int,
-        default=None,
     )
     parser.add_argument(
         "--config",
         type=str,
-        default=None,
     )
 
     return parser
 
 
-def run_parser(parser, input=None):
+def run_parser(parser, defaults, input=None):
     args = parser.parse_known_args(input)[0]
+    conf_args = argparse.Namespace()
 
     if args.config is not None:
         with open(args.config, 'rt') as f:
-            args = parser.parse_known_args(
+            conf_args = parser.parse_known_args(
                 namespace=argparse.Namespace(**json.load(f)["args"]))[0]
 
-    return args
+    res = defaults.copy()
+    for dict in [vars(conf_args), vars(args)]:
+        res.update({k: v for (k, v) in dict.items() if v is not None})
+
+    return argparse.Namespace(**res)
 
 
 def save_args(basepath, args, extra={}):
@@ -146,7 +179,7 @@ def create_pipeline(model, scheduler, dtype=torch.bfloat16):
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
     )
 
-    pipeline = CLIPGuidedStableDiffusion(
+    pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
         vae=vae,
         unet=unet,
@@ -154,7 +187,7 @@ def create_pipeline(model, scheduler, dtype=torch.bfloat16):
         scheduler=scheduler,
         feature_extractor=feature_extractor
     )
-    pipeline.enable_attention_slicing()
+    # pipeline.enable_attention_slicing()
     pipeline.to("cuda")
 
     print("Pipeline loaded.")
@@ -171,6 +204,13 @@ def generate(output_dir, pipeline, args):
 
     save_args(output_dir, args)
 
+    if args.image:
+        init_image = Image.open(args.image)
+        if not init_image.mode == "RGB":
+            init_image = init_image.convert("RGB")
+    else:
+        init_image = None
+
     with autocast("cuda"):
         for i in range(args.batch_num):
             pipeline.set_progress_bar_config(desc=f"Batch {i + 1} of {args.batch_num}")
@@ -184,11 +224,15 @@ def generate(output_dir, pipeline, args):
                 num_inference_steps=args.steps,
                 guidance_scale=args.guidance_scale,
                 generator=generator,
+                latents=init_image,
             ).images
 
             for j, image in enumerate(images):
                 image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg"))
 
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
 
 class CmdParse(cmd.Cmd):
     prompt = 'dream> '
@@ -213,7 +257,7 @@ class CmdParse(cmd.Cmd):
             return True
 
         try:
-            args = run_parser(self.parser, elements)
+            args = run_parser(self.parser, default_cmds, elements)
         except SystemExit:
             self.parser.print_help()
 
@@ -233,7 +277,7 @@ def main():
     logging.basicConfig(stream=sys.stdout, level=logging.WARN)
 
     args_parser = create_args_parser()
-    args = run_parser(args_parser)
+    args = run_parser(args_parser, default_args)
     output_dir = Path(args.output_dir)
 
     pipeline = create_pipeline(args.model, args.scheduler)
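For reference, the precedence the reworked run_parser() implements can be seen in isolation below: hard-coded defaults are overridden by values from a --config JSON file, which are in turn overridden by anything given explicitly on the command line; values still set to None count as unset. This is a self-contained sketch using a small subset of default_cmds, not additional code from the commit.

import argparse
import json

defaults = {"steps": 50, "guidance_scale": 7.0, "seed": None}

parser = argparse.ArgumentParser()
parser.add_argument("--steps", type=int)
parser.add_argument("--guidance_scale", type=float)
parser.add_argument("--seed", type=int)
parser.add_argument("--config", type=str)


def run_parser(parser, defaults, input=None):
    # Same merging logic as the new run_parser() in infer.py.
    args = parser.parse_known_args(input)[0]
    conf_args = argparse.Namespace()

    if args.config is not None:
        with open(args.config, 'rt') as f:
            conf_args = parser.parse_known_args(
                namespace=argparse.Namespace(**json.load(f)["args"]))[0]

    res = defaults.copy()
    for d in [vars(conf_args), vars(args)]:
        res.update({k: v for (k, v) in d.items() if v is not None})

    return argparse.Namespace(**res)


# Only --steps is given; guidance_scale and seed fall back to the defaults.
print(run_parser(parser, defaults, ["--steps", "30"]))
# -> Namespace(guidance_scale=7.0, seed=None, steps=30)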
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index eff74b5..4c793a8 100644
--- a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -2,22 +2,29 @@ import inspect
 import warnings
 from typing import List, Optional, Union
 
+import numpy as np
 import torch
-from torch import nn
-from torch.nn import functional as F
+import PIL
 
 from diffusers.configuration_utils import FrozenDict
 from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
-from torchvision import transforms
-from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer
 from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class CLIPGuidedStableDiffusion(DiffusionPipeline):
+def preprocess(image, w, h):
+    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+    return 2.0 * image - 1.0
+
+
+class VlpnStableDiffusion(DiffusionPipeline):
     def __init__(
         self,
         vae: AutoencoderKL,
@@ -83,13 +90,14 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         self,
         prompt: Union[str, List[str]],
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        strength: float = 0.8,
         height: Optional[int] = 512,
         width: Optional[int] = 512,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
     ):
@@ -99,6 +107,12 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         Args:
             prompt (`str` or `List[str]`):
                 The prompt or prompts to guide the image generation.
+            strength (`float`, *optional*, defaults to 0.8):
+                Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
+                `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+                noise will be maximum and the denoising process will run for the full number of iterations specified in
+                `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
             height (`int`, *optional*, defaults to 512):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to 512):
@@ -158,6 +172,42 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
+        if strength < 0 or strength > 1:
+            raise ValueError(f"`strength` should in [0.0, 1.0] but is {strength}")
+
+        # set timesteps
+        self.scheduler.set_timesteps(num_inference_steps)
+
+        offset = self.scheduler.config.get("steps_offset", 0)
+
+        if latents is not None and isinstance(latents, PIL.Image.Image):
+            latents = preprocess(latents, width, height)
+            latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
+            latents = latent_dist.sample(generator=generator)
+            latents = 0.18215 * latents
+            latents = torch.cat([latents] * batch_size)
+
+            # get the original timestep using init_timestep
+            init_timestep = int(num_inference_steps * strength) + offset
+            init_timestep = min(init_timestep, num_inference_steps)
+
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                timesteps = torch.tensor(
+                    [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
+                )
+            elif isinstance(self.scheduler, EulerAScheduler):
+                timesteps = self.scheduler.timesteps[-init_timestep]
+                timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+            else:
+                timesteps = self.scheduler.timesteps[-init_timestep]
+                timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+
+            # add noise to latents using the timesteps
+            noise = torch.randn(latents.shape, generator=generator, device=self.device)
+            latents = self.scheduler.add_noise(latents, noise, timesteps)
+        else:
+            init_timestep = num_inference_steps + offset
+
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -213,15 +263,11 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
         latents = latents.to(self.device)
 
-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
+        t_start = max(num_inference_steps - init_timestep + offset, 0)
 
         # Some schedulers like PNDM have timesteps as arrays
         # It's more optimzed to move all timesteps to correct device beforehand
-        if torch.is_tensor(self.scheduler.timesteps):
-            timesteps_tensor = self.scheduler.timesteps.to(self.device)
-        else:
-            timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
+        timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
 
         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if isinstance(self.scheduler, LMSDiscreteScheduler):
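To make the timestep bookkeeping above concrete, here is the arithmetic in isolation: with the example values num_inference_steps=50, strength=0.7 and a scheduler steps_offset of 1 (all placeholders), the pipeline skips the first t_start timesteps and denoises over the remaining ones. This is a stand-alone sketch of the formulas added in this commit, not extra code from it.

# Example values only; the formulas mirror the img2img branch above.
num_inference_steps = 50
strength = 0.7
offset = 1  # self.scheduler.config.get("steps_offset", 0)

init_timestep = int(num_inference_steps * strength) + offset    # 36
init_timestep = min(init_timestep, num_inference_steps)         # 36
t_start = max(num_inference_steps - init_timestep + offset, 0)  # 15

print(num_inference_steps - t_start)  # 35 denoising steps actually run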
@@ -244,10 +290,12 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
             extra_step_kwargs["generator"] = generator
 
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+            t_index = t_start + i
+
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
+                sigma = self.scheduler.sigmas[t_index]
                 # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
 
@@ -270,10 +318,10 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
 
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
             elif isinstance(self.scheduler, EulerAScheduler):
-                if i < self.scheduler.timesteps.shape[0] - 1: # avoid out of bound error
-                    t_prev = self.scheduler.timesteps[i+1]
+                if t_index < self.scheduler.timesteps.shape[0] - 1: # avoid out of bound error
+                    t_prev = self.scheduler.timesteps[t_index+1]
                     latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample
             else:
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample