-rw-r--r--  infer.py                                             | 18
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 18
-rw-r--r--  schedulers/scheduling_euler_a.py                     |  9
3 files changed, 28 insertions, 17 deletions
diff --git a/infer.py b/infer.py
@@ -19,6 +19,7 @@ from schedulers.scheduling_euler_a import EulerAScheduler
 default_args = {
     "model": None,
     "scheduler": "euler_a",
+    "precision": "bf16",
     "output_dir": "output/inference",
     "config": None,
 }
@@ -28,7 +29,7 @@ default_cmds = {
     "prompt": None,
     "negative_prompt": None,
     "image": None,
-    "image_strength": .3,
+    "image_noise": .7,
     "width": 512,
     "height": 512,
     "batch_size": 1,
@@ -63,6 +64,11 @@ def create_args_parser():
         choices=["plms", "ddim", "klms", "euler_a"],
     )
     parser.add_argument(
+        "--precision",
+        type=str,
+        choices=["fp32", "fp16", "bf16"],
+    )
+    parser.add_argument(
         "--output_dir",
         type=str,
     )
@@ -91,7 +97,7 @@ def create_cmd_parser():
         type=str,
     )
     parser.add_argument(
-        "--image_strength",
+        "--image_noise",
         type=float,
     )
     parser.add_argument(
@@ -153,7 +159,7 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)


-def create_pipeline(model, scheduler, dtype=torch.bfloat16):
+def create_pipeline(model, scheduler, dtype):
     print("Loading Stable Diffusion pipeline...")

     tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype)
@@ -225,7 +231,7 @@ def generate(output_dir, pipeline, args):
        guidance_scale=args.guidance_scale,
        generator=generator,
        latents=init_image,
-       strength=args.image_strength,
+       strength=args.image_noise,
    ).images

    for j, image in enumerate(images):
@@ -279,9 +285,11 @@ def main():

     args_parser = create_args_parser()
     args = run_parser(args_parser, default_args)
+
     output_dir = Path(args.output_dir)
+    dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]

-    pipeline = create_pipeline(args.model, args.scheduler)
+    pipeline = create_pipeline(args.model, args.scheduler, dtype)
     cmd_parser = create_cmd_parser()
     cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()
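
Note: the new --precision flag above is wired through main() as a plain dict lookup from flag value to torch dtype. Below is a minimal standalone sketch of that mapping; the names DTYPE_MAP and resolve_dtype are hypothetical and not part of this commit, which inlines the dict directly in main().

    import torch

    # CLI precision string -> torch dtype, same mapping main() inlines.
    DTYPE_MAP = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }

    def resolve_dtype(precision: str) -> torch.dtype:
        # argparse's choices=["fp32", "fp16", "bf16"] already rejects
        # anything else, so a plain lookup is safe here.
        return DTYPE_MAP[precision]

    # Usage: pipeline = create_pipeline(args.model, args.scheduler,
    #                                   resolve_dtype(args.precision))
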
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index b4c85e9..8fbe5f9 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -223,15 +223,16 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_device = "cpu" if self.device.type == "mps" else self.device
+        latents_dtype = text_embeddings.dtype
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
         if latents is None:
-            latents = torch.randn(
-                latents_shape,
-                generator=generator,
-                device=latents_device,
-                dtype=text_embeddings.dtype,
-            )
+            if self.device.type == "mps":
+                # randn does not exist on mps
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                    self.device
+                )
+            else:
+                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
         elif isinstance(latents, PIL.Image.Image):
             latents = preprocess(latents, width, height)
             latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
@@ -259,7 +260,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(self.device)
+            if latents.device != self.device:
+                raise ValueError(f"Unexpected latents device, got {latents.device}, expected {self.device}")

         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if ensure_sigma:
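
Note: the first hunk above does two things: generated latents now inherit text_embeddings.dtype rather than being tied to a hardcoded precision, and mps devices draw their noise on the CPU before moving it over (per the in-diff comment, seeded randn is unavailable on mps). A standalone sketch of the same branch, where make_latents is a hypothetical name introduced only for illustration:

    import torch

    def make_latents(shape, generator, device, dtype):
        if device.type == "mps":
            # Draw on CPU, then move to the device: mirrors the pipeline's
            # workaround for randn being unavailable on mps.
            return torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
        return torch.randn(shape, generator=generator, device=device, dtype=dtype)

The second hunk tightens the contract for caller-supplied latents: instead of silently calling .to(self.device), a device mismatch now raises, matching the existing shape check.
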
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index a2d0e9f..d7fea85 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -36,7 +36,7 @@ def get_sigmas(sigmas, n=None):
     if n is None:
         return append_zero(sigmas.flip(0))
     t_max = len(sigmas) - 1  # = 999
-    t = torch.linspace(t_max, 0, n, device=sigmas.device)
+    t = torch.linspace(t_max, 0, n, device=sigmas.device, dtype=sigmas.dtype)
     return append_zero(t_to_sigma(t, sigmas))

 # from k_samplers utils.py
@@ -91,9 +91,10 @@ def DSsigma_to_t(sigma, quantize=False, DSsigmas=None):


 def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs):
-    sigma = sigma.to(Unet.device)
-    DSsigmas = DSsigmas.to(Unet.device)
+    sigma = sigma.to(dtype=input.dtype, device=Unet.device)
+    DSsigmas = DSsigmas.to(dtype=input.dtype, device=Unet.device)
     c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)]
+    # print(f">>>>>>>>>>> {input.dtype} {c_in.dtype} {sigma.dtype} {DSsigmas.dtype}")
     eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas),
                encoder_hidden_states=kwargs['cond']).sample
     return input + eps * c_out
@@ -226,7 +227,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         d = to_d(latents, s, model_output)
         dt = sigma_down - s
         latents = latents + d * dt
-        latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device,
+        latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype,
                                         generator=generator) * sigma_up

         return SchedulerOutput(prev_sample=latents)
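
Note: the scheduler edits share one theme: every tensor the Euler-ancestral step touches (the timestep linspace, sigma and DSsigmas, the fresh noise) is cast to the working dtype. Without that, float32 sigmas silently promote half-precision latents back to float32, defeating the precision selected via --precision. A self-contained illustration under assumed arbitrary values:

    import torch

    latents = torch.randn(1, 4, 64, 64, dtype=torch.bfloat16)
    sigma = torch.full((1, 1, 1, 1), 14.6146)          # float32, broadcastable

    # Mixed dtypes: the product is promoted to float32.
    assert (latents * sigma).dtype == torch.float32

    # Casting sigma first, as the diff now does, keeps everything in bf16.
    sigma = sigma.to(dtype=latents.dtype)
    assert (latents * sigma).dtype == torch.bfloat16
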