 infer.py                                             |  3 +-
 pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 71 ++++++++++----------
 schedulers/scheduling_euler_a.py                     | 24 ----------
 3 files changed, 38 insertions(+), 60 deletions(-)
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -28,7 +28,7 @@ default_cmds = {
     "prompt": None,
     "negative_prompt": None,
     "image": None,
-    "image_strength": .7,
+    "image_strength": .3,
     "width": 512,
     "height": 512,
     "batch_size": 1,
@@ -225,6 +225,7 @@ def generate(output_dir, pipeline, args):
             guidance_scale=args.guidance_scale,
             generator=generator,
             latents=init_image,
+            strength=args.image_strength,
         ).images
 
         for j, image in enumerate(images):
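
For intuition on the new default, here is a minimal worked sketch of how `strength` maps to the number of denoising steps run on the init image, using the `init_timestep`/`t_start` formulas from the pipeline diff below. The values `num_inference_steps = 50` and `offset = 0` are illustrative assumptions, not taken from the commit.

    # Illustrative only: mirrors init_timestep / t_start from the pipeline change.
    num_inference_steps = 50  # assumed example value
    offset = 0                # assumed steps_offset

    def steps_run_on_init_image(strength: float) -> int:
        init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)
        t_start = max(num_inference_steps - init_timestep + offset, 0)
        return num_inference_steps - t_start

    print(steps_run_on_init_image(0.7))  # 35 steps -> result drifts further from the input image
    print(steps_run_on_init_image(0.3))  # 15 steps -> result stays much closer to the input image
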
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index a8ecedf..b4c85e9 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -178,35 +178,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
 
-        offset = self.scheduler.config.get("steps_offset", 0)
-
-        if latents is not None and isinstance(latents, PIL.Image.Image):
-            latents = preprocess(latents, width, height)
-            latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
-            latents = latent_dist.sample(generator=generator)
-            latents = 0.18215 * latents
-
-            # expand init_latents for batch_size
-            latents = torch.cat([latents] * batch_size)
-
-            # get the original timestep using init_timestep
-            init_timestep = int(num_inference_steps * strength) + offset
-            init_timestep = min(init_timestep, num_inference_steps)
-
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                timesteps = torch.tensor(
-                    [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
-                )
-            else:
-                timesteps = self.scheduler.timesteps[-init_timestep]
-                timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
-
-            # add noise to latents using the timesteps
-            noise = torch.randn(latents.shape, generator=generator, device=self.device)
-            latents = self.scheduler.add_noise(latents, noise, timesteps)
-        else:
-            init_timestep = num_inference_steps + offset
-
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -243,6 +214,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
             # to avoid doing two forward passes
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
+        offset = self.scheduler.config.get("steps_offset", 0)
+        init_timestep = num_inference_steps + offset
+        ensure_sigma = not isinstance(latents, PIL.Image.Image)
+
         # get the initial random noise unless the user supplied it
 
         # Unlike in other pipelines, latents need to be generated in the target device
@@ -257,23 +232,48 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 device=latents_device,
                 dtype=text_embeddings.dtype,
             )
+        elif isinstance(latents, PIL.Image.Image):
+            latents = preprocess(latents, width, height)
+            latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
+            latents = latent_dist.sample(generator=generator)
+            latents = 0.18215 * latents
+
+            # expand init_latents for batch_size
+            latents = torch.cat([latents] * batch_size)
+
+            # get the original timestep using init_timestep
+            init_timestep = int(num_inference_steps * strength) + offset
+            init_timestep = min(init_timestep, num_inference_steps)
+
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                timesteps = torch.tensor(
+                    [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
+                )
+            else:
+                timesteps = self.scheduler.timesteps[-init_timestep]
+                timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+
+            # add noise to latents using the timesteps
+            noise = torch.randn(latents.shape, generator=generator, device=self.device)
+            latents = self.scheduler.add_noise(latents, noise, timesteps)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
             latents = latents.to(self.device)
 
+        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+        if ensure_sigma:
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                latents = latents * self.scheduler.sigmas[0]
+            elif isinstance(self.scheduler, EulerAScheduler):
+                latents = latents * self.scheduler.sigmas[0]
+
         t_start = max(num_inference_steps - init_timestep + offset, 0)
 
         # Some schedulers like PNDM have timesteps as arrays
         # It's more optimzed to move all timesteps to correct device beforehand
         timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
 
-        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents * self.scheduler.sigmas[0]
-        elif isinstance(self.scheduler, EulerAScheduler):
-            latents = latents * self.scheduler.sigmas[0]
-
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
@@ -292,6 +292,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
             if isinstance(self.scheduler, LMSDiscreteScheduler):
                 sigma = self.scheduler.sigmas[t_index]
                 # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
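
The key behavioral point in the reordered block above is the new `ensure_sigma` flag: only latents that did not originate from an input image are scaled by the scheduler's initial sigma, because image-derived latents already receive their noise level from `scheduler.add_noise` at the strength-dependent timestep. A minimal sketch of that decision follows; the helper name and example values are hypothetical, not part of the pipeline.

    import torch

    def scale_initial_latents(latents: torch.Tensor, sigmas: torch.Tensor, from_image: bool) -> torch.Tensor:
        # Pure-noise latents start at the largest sigma, so scale them up;
        # image-derived latents were already noised to their start timestep by add_noise.
        return latents if from_image else latents * sigmas[0]

    # Usage sketch with made-up shapes and sigma values:
    sigmas = torch.tensor([14.6, 10.0, 5.0, 1.0, 0.0])
    noise_latents = scale_initial_latents(torch.randn(1, 4, 64, 64), sigmas, from_image=False)
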
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 1b1c9cf..a2d0e9f 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -191,27 +191,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device)
         self.timesteps = np.arange(0, self.num_inference_steps)
 
-    def add_noise_to_input(
-        self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
-    ) -> Tuple[torch.FloatTensor, float]:
-        """
-        Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
-        higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
-
-        TODO Args:
-        """
-        if self.config.s_min <= sigma <= self.config.s_max:
-            gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1)
-        else:
-            gamma = 0
-
-        # sample eps ~ N(0, S_noise^2 * I)
-        eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
-        sigma_hat = sigma + gamma * sigma
-        sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
-
-        return sample_hat, sigma_hat
-
     def step(
         self,
         model_output: torch.FloatTensor,
@@ -219,9 +198,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         timestep_prev: int,
         sample: torch.FloatTensor,
         generator: None,
-        # ,sigma_hat: float,
-        # sigma_prev: float,
-        # sample_hat: torch.FloatTensor,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
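
For reference, the deleted `add_noise_to_input` implemented the stochastic "churn" step described in its own docstring: raise the noise level to sigma_hat = sigma + gamma * sigma and add just enough Gaussian noise to match. A standalone sketch of that computation is below; the default `s_churn`, `s_noise`, and `num_steps` values are assumptions for illustration, and the original additionally gated `gamma` on `s_min <= sigma <= s_max`.

    import torch

    def churn(sample: torch.Tensor, sigma: float, s_churn: float = 80.0,
              s_noise: float = 1.0, num_steps: int = 50) -> tuple[torch.Tensor, float]:
        # gamma bounds the per-step noise increase, as in the removed helper
        gamma = min(s_churn / num_steps, 2 ** 0.5 - 1)
        sigma_hat = sigma + gamma * sigma
        eps = s_noise * torch.randn_like(sample)
        # add exactly enough noise to move the sample from level sigma to sigma_hat
        return sample + (sigma_hat ** 2 - sigma ** 2) ** 0.5 * eps, sigma_hat
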
