diff options
-rw-r--r-- | infer.py | 3 | ||||
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 71 | ||||
-rw-r--r-- | schedulers/scheduling_euler_a.py | 24 |
3 files changed, 38 insertions, 60 deletions
@@ -28,7 +28,7 @@ default_cmds = { | |||
28 | "prompt": None, | 28 | "prompt": None, |
29 | "negative_prompt": None, | 29 | "negative_prompt": None, |
30 | "image": None, | 30 | "image": None, |
31 | "image_strength": .7, | 31 | "image_strength": .3, |
32 | "width": 512, | 32 | "width": 512, |
33 | "height": 512, | 33 | "height": 512, |
34 | "batch_size": 1, | 34 | "batch_size": 1, |
@@ -225,6 +225,7 @@ def generate(output_dir, pipeline, args): | |||
225 | guidance_scale=args.guidance_scale, | 225 | guidance_scale=args.guidance_scale, |
226 | generator=generator, | 226 | generator=generator, |
227 | latents=init_image, | 227 | latents=init_image, |
228 | strength=args.image_strength, | ||
228 | ).images | 229 | ).images |
229 | 230 | ||
230 | for j, image in enumerate(images): | 231 | for j, image in enumerate(images): |
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index a8ecedf..b4c85e9 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -178,35 +178,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
178 | # set timesteps | 178 | # set timesteps |
179 | self.scheduler.set_timesteps(num_inference_steps) | 179 | self.scheduler.set_timesteps(num_inference_steps) |
180 | 180 | ||
181 | offset = self.scheduler.config.get("steps_offset", 0) | ||
182 | |||
183 | if latents is not None and isinstance(latents, PIL.Image.Image): | ||
184 | latents = preprocess(latents, width, height) | ||
185 | latent_dist = self.vae.encode(latents.to(self.device)).latent_dist | ||
186 | latents = latent_dist.sample(generator=generator) | ||
187 | latents = 0.18215 * latents | ||
188 | |||
189 | # expand init_latents for batch_size | ||
190 | latents = torch.cat([latents] * batch_size) | ||
191 | |||
192 | # get the original timestep using init_timestep | ||
193 | init_timestep = int(num_inference_steps * strength) + offset | ||
194 | init_timestep = min(init_timestep, num_inference_steps) | ||
195 | |||
196 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
197 | timesteps = torch.tensor( | ||
198 | [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device | ||
199 | ) | ||
200 | else: | ||
201 | timesteps = self.scheduler.timesteps[-init_timestep] | ||
202 | timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) | ||
203 | |||
204 | # add noise to latents using the timesteps | ||
205 | noise = torch.randn(latents.shape, generator=generator, device=self.device) | ||
206 | latents = self.scheduler.add_noise(latents, noise, timesteps) | ||
207 | else: | ||
208 | init_timestep = num_inference_steps + offset | ||
209 | |||
210 | # get prompt text embeddings | 181 | # get prompt text embeddings |
211 | text_inputs = self.tokenizer( | 182 | text_inputs = self.tokenizer( |
212 | prompt, | 183 | prompt, |
@@ -243,6 +214,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
243 | # to avoid doing two forward passes | 214 | # to avoid doing two forward passes |
244 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) | 215 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) |
245 | 216 | ||
217 | offset = self.scheduler.config.get("steps_offset", 0) | ||
218 | init_timestep = num_inference_steps + offset | ||
219 | ensure_sigma = not isinstance(latents, PIL.Image.Image) | ||
220 | |||
246 | # get the initial random noise unless the user supplied it | 221 | # get the initial random noise unless the user supplied it |
247 | 222 | ||
248 | # Unlike in other pipelines, latents need to be generated in the target device | 223 | # Unlike in other pipelines, latents need to be generated in the target device |
@@ -257,23 +232,48 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
257 | device=latents_device, | 232 | device=latents_device, |
258 | dtype=text_embeddings.dtype, | 233 | dtype=text_embeddings.dtype, |
259 | ) | 234 | ) |
235 | elif isinstance(latents, PIL.Image.Image): | ||
236 | latents = preprocess(latents, width, height) | ||
237 | latent_dist = self.vae.encode(latents.to(self.device)).latent_dist | ||
238 | latents = latent_dist.sample(generator=generator) | ||
239 | latents = 0.18215 * latents | ||
240 | |||
241 | # expand init_latents for batch_size | ||
242 | latents = torch.cat([latents] * batch_size) | ||
243 | |||
244 | # get the original timestep using init_timestep | ||
245 | init_timestep = int(num_inference_steps * strength) + offset | ||
246 | init_timestep = min(init_timestep, num_inference_steps) | ||
247 | |||
248 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
249 | timesteps = torch.tensor( | ||
250 | [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device | ||
251 | ) | ||
252 | else: | ||
253 | timesteps = self.scheduler.timesteps[-init_timestep] | ||
254 | timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) | ||
255 | |||
256 | # add noise to latents using the timesteps | ||
257 | noise = torch.randn(latents.shape, generator=generator, device=self.device) | ||
258 | latents = self.scheduler.add_noise(latents, noise, timesteps) | ||
260 | else: | 259 | else: |
261 | if latents.shape != latents_shape: | 260 | if latents.shape != latents_shape: |
262 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") | 261 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") |
263 | latents = latents.to(self.device) | 262 | latents = latents.to(self.device) |
264 | 263 | ||
264 | # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas | ||
265 | if ensure_sigma: | ||
266 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
267 | latents = latents * self.scheduler.sigmas[0] | ||
268 | elif isinstance(self.scheduler, EulerAScheduler): | ||
269 | latents = latents * self.scheduler.sigmas[0] | ||
270 | |||
265 | t_start = max(num_inference_steps - init_timestep + offset, 0) | 271 | t_start = max(num_inference_steps - init_timestep + offset, 0) |
266 | 272 | ||
267 | # Some schedulers like PNDM have timesteps as arrays | 273 | # Some schedulers like PNDM have timesteps as arrays |
268 | # It's more optimized to move all timesteps to correct device beforehand | 274 | # It's more optimized to move all timesteps to correct device beforehand |
269 | timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) | 275 | timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) |
270 | 276 | ||
271 | # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas | ||
272 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
273 | latents = latents * self.scheduler.sigmas[0] | ||
274 | elif isinstance(self.scheduler, EulerAScheduler): | ||
275 | latents = latents * self.scheduler.sigmas[0] | ||
276 | |||
277 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature | 277 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature |
278 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | 278 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. |
279 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 | 279 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 |
@@ -292,6 +292,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
292 | 292 | ||
293 | # expand the latents if we are doing classifier free guidance | 293 | # expand the latents if we are doing classifier free guidance |
294 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | 294 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
295 | |||
295 | if isinstance(self.scheduler, LMSDiscreteScheduler): | 296 | if isinstance(self.scheduler, LMSDiscreteScheduler): |
296 | sigma = self.scheduler.sigmas[t_index] | 297 | sigma = self.scheduler.sigmas[t_index] |
297 | # the model input needs to be scaled to match the continuous ODE formulation in K-LMS | 298 | # the model input needs to be scaled to match the continuous ODE formulation in K-LMS |
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 1b1c9cf..a2d0e9f 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py | |||
@@ -191,27 +191,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
191 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) | 191 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) |
192 | self.timesteps = np.arange(0, self.num_inference_steps) | 192 | self.timesteps = np.arange(0, self.num_inference_steps) |
193 | 193 | ||
194 | def add_noise_to_input( | ||
195 | self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None | ||
196 | ) -> Tuple[torch.FloatTensor, float]: | ||
197 | """ | ||
198 | Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a | ||
199 | higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. | ||
200 | |||
201 | TODO Args: | ||
202 | """ | ||
203 | if self.config.s_min <= sigma <= self.config.s_max: | ||
204 | gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1) | ||
205 | else: | ||
206 | gamma = 0 | ||
207 | |||
208 | # sample eps ~ N(0, S_noise^2 * I) | ||
209 | eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device) | ||
210 | sigma_hat = sigma + gamma * sigma | ||
211 | sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) | ||
212 | |||
213 | return sample_hat, sigma_hat | ||
214 | |||
215 | def step( | 194 | def step( |
216 | self, | 195 | self, |
217 | model_output: torch.FloatTensor, | 196 | model_output: torch.FloatTensor, |
@@ -219,9 +198,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
219 | timestep_prev: int, | 198 | timestep_prev: int, |
220 | sample: torch.FloatTensor, | 199 | sample: torch.FloatTensor, |
221 | generator: None, | 200 | generator: None, |
222 | # ,sigma_hat: float, | ||
223 | # sigma_prev: float, | ||
224 | # sample_hat: torch.FloatTensor, | ||
225 | return_dict: bool = True, | 201 | return_dict: bool = True, |
226 | ) -> Union[SchedulerOutput, Tuple]: | 202 | ) -> Union[SchedulerOutput, Tuple]: |
227 | """ | 203 | """ |