 infer.py                                            | 18 +++++++++++++-----
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 18 ++++++++++--------
 schedulers/scheduling_euler_a.py                    |  9 +++++----
 3 files changed, 28 insertions(+), 17 deletions(-)
diff --git a/infer.py b/infer.py
index f2c380f..b15b17f 100644
--- a/infer.py
+++ b/infer.py
@@ -19,6 +19,7 @@ from schedulers.scheduling_euler_a import EulerAScheduler
 default_args = {
     "model": None,
     "scheduler": "euler_a",
+    "precision": "bf16",
     "output_dir": "output/inference",
     "config": None,
 }
@@ -28,7 +29,7 @@ default_cmds = {
28 "prompt": None, 29 "prompt": None,
29 "negative_prompt": None, 30 "negative_prompt": None,
30 "image": None, 31 "image": None,
31 "image_strength": .3, 32 "image_noise": .7,
32 "width": 512, 33 "width": 512,
33 "height": 512, 34 "height": 512,
34 "batch_size": 1, 35 "batch_size": 1,
@@ -63,6 +64,11 @@ def create_args_parser():
63 choices=["plms", "ddim", "klms", "euler_a"], 64 choices=["plms", "ddim", "klms", "euler_a"],
64 ) 65 )
65 parser.add_argument( 66 parser.add_argument(
67 "--precision",
68 type=str,
69 choices=["fp32", "fp16", "bf16"],
70 )
71 parser.add_argument(
66 "--output_dir", 72 "--output_dir",
67 type=str, 73 type=str,
68 ) 74 )
@@ -91,7 +97,7 @@ def create_cmd_parser():
         type=str,
     )
     parser.add_argument(
-        "--image_strength",
+        "--image_noise",
         type=float,
     )
     parser.add_argument(
@@ -153,7 +159,7 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)


-def create_pipeline(model, scheduler, dtype=torch.bfloat16):
+def create_pipeline(model, scheduler, dtype):
     print("Loading Stable Diffusion pipeline...")

     tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype)
@@ -225,7 +231,7 @@ def generate(output_dir, pipeline, args):
             guidance_scale=args.guidance_scale,
             generator=generator,
             latents=init_image,
-            strength=args.image_strength,
+            strength=args.image_noise,
         ).images

         for j, image in enumerate(images):
@@ -279,9 +285,11 @@ def main():

     args_parser = create_args_parser()
     args = run_parser(args_parser, default_args)
+
     output_dir = Path(args.output_dir)
+    dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]

-    pipeline = create_pipeline(args.model, args.scheduler)
+    pipeline = create_pipeline(args.model, args.scheduler, dtype)
     cmd_parser = create_cmd_parser()
     cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()
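
For reference, a standalone sketch of the precision handling introduced above; the helper name resolve_dtype is illustrative and not part of the repo. Because argparse restricts --precision to the dict's keys, the lookup cannot raise KeyError:

import torch

# Mirrors the mapping added in main(); argparse's choices=["fp32", "fp16", "bf16"]
# guarantees every accepted value is a key of this dict.
DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

def resolve_dtype(precision: str) -> torch.dtype:
    return DTYPE_MAP[precision]

assert resolve_dtype("bf16") is torch.bfloat16
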
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index b4c85e9..8fbe5f9 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -223,15 +223,16 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_device = "cpu" if self.device.type == "mps" else self.device
+        latents_dtype = text_embeddings.dtype
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
         if latents is None:
-            latents = torch.randn(
-                latents_shape,
-                generator=generator,
-                device=latents_device,
-                dtype=text_embeddings.dtype,
-            )
+            if self.device.type == "mps":
+                # randn does not exist on mps
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                    self.device
+                )
+            else:
+                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
         elif isinstance(latents, PIL.Image.Image):
             latents = preprocess(latents, width, height)
             latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
@@ -259,7 +260,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(self.device)
+            if latents.device != self.device:
+                raise ValueError(f"Unexpected latents device, got {latents.device}, expected {self.device}")

         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if ensure_sigma:
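
The mps branch above works around the fact that, at the time of this change, seeded torch.randn was not supported on Apple's mps backend: latents are sampled on the CPU with the caller's generator, then moved to the target device. A minimal sketch of the same pattern in isolation (the helper name randn_on is hypothetical, not from the repo):

import torch

def randn_on(shape, generator, device, dtype):
    # Hypothetical helper mirroring the pipeline change: sample on CPU when
    # targeting mps (where seeded randn is unavailable), then move the result.
    if device.type == "mps":
        return torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
    return torch.randn(shape, generator=generator, device=device, dtype=dtype)
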
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index a2d0e9f..d7fea85 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -36,7 +36,7 @@ def get_sigmas(sigmas, n=None):
     if n is None:
         return append_zero(sigmas.flip(0))
     t_max = len(sigmas) - 1  # = 999
-    t = torch.linspace(t_max, 0, n, device=sigmas.device)
+    t = torch.linspace(t_max, 0, n, device=sigmas.device, dtype=sigmas.dtype)
     return append_zero(t_to_sigma(t, sigmas))

 # from k_samplers utils.py
@@ -91,9 +91,10 @@ def DSsigma_to_t(sigma, quantize=False, DSsigmas=None):


 def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs):
-    sigma = sigma.to(Unet.device)
-    DSsigmas = DSsigmas.to(Unet.device)
+    sigma = sigma.to(dtype=input.dtype, device=Unet.device)
+    DSsigmas = DSsigmas.to(dtype=input.dtype, device=Unet.device)
     c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)]
+    # print(f">>>>>>>>>>> {input.dtype} {c_in.dtype} {sigma.dtype} {DSsigmas.dtype}")
     eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas),
                encoder_hidden_states=kwargs['cond']).sample
     return input + eps * c_out
@@ -226,7 +227,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         d = to_d(latents, s, model_output)
         dt = sigma_down - s
         latents = latents + d * dt
-        latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device,
+        latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype,
                                         generator=generator) * sigma_up

         return SchedulerOutput(prev_sample=latents)
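
The common thread in these scheduler changes is dtype propagation: torch.linspace and torch.randn default to float32, and any float32 tensor that meets fp16 values silently upcasts the result, defeating half-precision inference. A small standalone illustration of the failure mode being avoided (not code from the repo):

import torch

x = torch.ones(4, dtype=torch.float16)  # stand-in for fp16 latents
sigma = torch.ones(1)                   # float32 by default, like an unconverted sigma

print((x * sigma).dtype)                # torch.float32 -- silent upcast
print((x * sigma.to(x.dtype)).dtype)    # torch.float16 -- stays in half precision
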