 infer.py                                            | 18 +++++++++++++-----
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 18 ++++++++++--------
 schedulers/scheduling_euler_a.py                    |  9 +++++----
 3 files changed, 28 insertions(+), 17 deletions(-)
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -19,6 +19,7 @@ from schedulers.scheduling_euler_a import EulerAScheduler
 default_args = {
     "model": None,
     "scheduler": "euler_a",
+    "precision": "bf16",
     "output_dir": "output/inference",
     "config": None,
 }
@@ -28,7 +29,7 @@ default_cmds = {
     "prompt": None,
     "negative_prompt": None,
     "image": None,
-    "image_strength": .3,
+    "image_noise": .7,
     "width": 512,
     "height": 512,
     "batch_size": 1,
@@ -63,6 +64,11 @@ def create_args_parser():
         choices=["plms", "ddim", "klms", "euler_a"],
     )
     parser.add_argument(
+        "--precision",
+        type=str,
+        choices=["fp32", "fp16", "bf16"],
+    )
+    parser.add_argument(
         "--output_dir",
         type=str,
     )
@@ -91,7 +97,7 @@ def create_cmd_parser():
         type=str,
     )
     parser.add_argument(
-        "--image_strength",
+        "--image_noise",
         type=float,
     )
     parser.add_argument(
@@ -153,7 +159,7 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def create_pipeline(model, scheduler, dtype=torch.bfloat16):
+def create_pipeline(model, scheduler, dtype):
     print("Loading Stable Diffusion pipeline...")
 
     tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype)
@@ -225,7 +231,7 @@ def generate(output_dir, pipeline, args):
             guidance_scale=args.guidance_scale,
             generator=generator,
             latents=init_image,
-            strength=args.image_strength,
+            strength=args.image_noise,
         ).images
 
     for j, image in enumerate(images):
@@ -279,9 +285,11 @@ def main():
 
     args_parser = create_args_parser()
     args = run_parser(args_parser, default_args)
+
    output_dir = Path(args.output_dir)
+    dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
 
-    pipeline = create_pipeline(args.model, args.scheduler)
+    pipeline = create_pipeline(args.model, args.scheduler, dtype)
     cmd_parser = create_cmd_parser()
     cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()
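
Note: the infer.py changes above replace the hard-coded torch.bfloat16 default with a user-selectable --precision flag that is resolved to a torch.dtype exactly once, in main(), and then passed explicitly to create_pipeline(). A minimal standalone sketch of that resolution step (the resolve_dtype helper is hypothetical, not part of the diff):

    import torch

    # Same mapping as in main(); argparse already restricts the choices,
    # so the lookup cannot raise KeyError for parsed arguments.
    DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

    def resolve_dtype(precision: str) -> torch.dtype:
        return DTYPE_MAP[precision]

    assert resolve_dtype("bf16") is torch.bfloat16  # matches the new default_args entry
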
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index b4c85e9..8fbe5f9 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -223,15 +223,16 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_device = "cpu" if self.device.type == "mps" else self.device
+        latents_dtype = text_embeddings.dtype
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
         if latents is None:
-            latents = torch.randn(
-                latents_shape,
-                generator=generator,
-                device=latents_device,
-                dtype=text_embeddings.dtype,
-            )
+            if self.device.type == "mps":
+                # randn does not exist on mps
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                    self.device
+                )
+            else:
+                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
         elif isinstance(latents, PIL.Image.Image):
             latents = preprocess(latents, width, height)
             latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
@@ -259,7 +260,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(self.device)
+            if latents.device != self.device:
+                raise ValueError(f"Unexpected latents device, got {latents.device}, expected {self.device}")
 
         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if ensure_sigma:
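
Note: the two hunks above change how the pipeline handles latents. The latents dtype now follows text_embeddings (and therefore the precision the models were loaded with), seeded noise on mps is generated on the CPU and then moved over, and caller-supplied latent tensors on the wrong device now raise instead of being silently moved, presumably so mismatches surface as errors rather than hidden copies. A minimal sketch of the same branching, with make_latents as a hypothetical free-standing helper:

    from typing import Optional
    import torch

    def make_latents(shape, device: torch.device, dtype: torch.dtype,
                     generator: Optional[torch.Generator] = None) -> torch.Tensor:
        if device.type == "mps":
            # Generator-seeded randn is not supported on mps: sample on CPU, then move.
            return torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
        return torch.randn(shape, generator=generator, device=device, dtype=dtype)
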
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index a2d0e9f..d7fea85 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -36,7 +36,7 @@ def get_sigmas(sigmas, n=None):
     if n is None:
         return append_zero(sigmas.flip(0))
     t_max = len(sigmas) - 1  # = 999
-    t = torch.linspace(t_max, 0, n, device=sigmas.device)
+    t = torch.linspace(t_max, 0, n, device=sigmas.device, dtype=sigmas.dtype)
     return append_zero(t_to_sigma(t, sigmas))
 
 # from k_samplers utils.py
@@ -91,9 +91,10 @@ def DSsigma_to_t(sigma, quantize=False, DSsigmas=None):
 
 
 def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs):
-    sigma = sigma.to(Unet.device)
-    DSsigmas = DSsigmas.to(Unet.device)
+    sigma = sigma.to(dtype=input.dtype, device=Unet.device)
+    DSsigmas = DSsigmas.to(dtype=input.dtype, device=Unet.device)
     c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)]
+    # print(f">>>>>>>>>>> {input.dtype} {c_in.dtype} {sigma.dtype} {DSsigmas.dtype}")
     eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas),
                encoder_hidden_states=kwargs['cond']).sample
     return input + eps * c_out
@@ -226,7 +227,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         d = to_d(latents, s, model_output)
         dt = sigma_down - s
         latents = latents + d * dt
-        latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device,
+        latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype,
                                         generator=generator) * sigma_up
 
         return SchedulerOutput(prev_sample=latents)
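
Note: the scheduler edits serve the same goal as the pipeline changes: every tensor the sampler fabricates (the linspace over timesteps, the sigma tensors, the ancestral noise) now inherits the working dtype instead of falling back to PyTorch's fp32 default. Without that, half-precision runs either fail on dtype mismatches or silently promote back to fp32, as this self-contained check illustrates:

    import torch

    latents = torch.zeros(1, 4, 64, 64, dtype=torch.float16)

    # Factory functions default to torch.get_default_dtype() (fp32)...
    noise = torch.randn(latents.shape, device=latents.device)
    assert noise.dtype is torch.float32

    # ...and fp16 + fp32 type-promotes the sum back to fp32.
    assert (latents + noise).dtype is torch.float32

    # Passing dtype=latents.dtype, as the diff now does, keeps everything in fp16.
    noise = torch.randn(latents.shape, device=latents.device, dtype=latents.dtype)
    assert (latents + noise).dtype is torch.float16
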
